Winner solution of Mobile AI (CVPRW 2021).
A winning solution for the MAI 2021 Competition (CVPR 2021 Workshop). Our model outperforms the other participants by a large margin in terms of both inference speed and reconstruction performance.
Challenge report: Mobile AI 2021 Real-Time Image Super-Resolution Challenge.
Our paper: Anchor-based Plain Net for Mobile Image Super-Resolution.
We conduct an experiment on meta-node latency by decomposing lightweight SR architectures, which determines the portable operations we can utilize. This step is crucially important if you want to deploy your model across mobile devices.
For full-integer quantization, which means all weights and activations are int8, it is clearly better to learn the residual (always close to zero) than to directly map the low-resolution image to the high-resolution image. In existing methods, residual learning can be divided into two categories: (1) image space residual learning, which passes the interpolated input (bilinear, bicubic) to the network output; (2) feature space residual learning, which passes the output of a shallow convolutional layer to the network output. For the float32 model, feature space residual learning is slightly better (+0.08dB). For the int8 quantized model, image space residual learning is always better (+0.3dB) because it forces the whole network to learn subtle changes, so a set of continuous real-valued numbers can be represented more accurately by a fixed discrete set of numbers. However, bilinear and nearest-neighbor resizing are really slow on mobile devices due to the pixel-wise multiplications required for coordinate mapping. Our anchor-based residual learning enjoys the good properties of image space residual learning while being as fast as feature space residual learning. The core operation is repeating the input nine times (for the x3 scale) and adding it to the feature map right before depth-to-space. See our architecture in model.
After deep feature extraction, existing methods use one convolution to map the features back to the original image space, followed by a depth-to-space (PixelShuffle in PyTorch) layer. We find that adding one more convolution in image space significantly improves performance compared with adding one more convolution in the deep feature extraction stage (+0.11dB).
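The snippet below is a minimal TensorFlow/Keras sketch of these two ideas, not the exact network in this repository: the depth and width (4 convolutions, 28 channels) are only placeholders guessed from the base7_D4C28 naming, and the point is the anchor-based residual plus the extra image-space convolution before depth-to-space.

import tensorflow as tf

def sr_model_sketch(scale=3, num_feats=28, num_convs=4):
    # Plain 3x3 convolutions for deep feature extraction (depth/width are placeholders).
    lr = tf.keras.Input(shape=(None, None, 3))
    x = lr
    for _ in range(num_convs):
        x = tf.keras.layers.Conv2D(num_feats, 3, padding='same', activation='relu')(x)
    # Map features to image space (scale**2 * 3 = 27 channels for x3) ...
    x = tf.keras.layers.Conv2D(3 * scale * scale, 3, padding='same', activation='relu')(x)
    # ... then one more convolution in image space (the +0.11dB trick described above).
    x = tf.keras.layers.Conv2D(3 * scale * scale, 3, padding='same')(x)
    # Anchor-based residual: repeat the LR input nine times along channels and add it,
    # so the network only learns the near-zero residual, which is int8-friendly.
    anchor = tf.concat([lr] * (scale * scale), axis=-1)
    x = x + anchor
    # Depth-to-space (PixelShuffle) rearranges channels into the high-resolution image.
    hr = tf.nn.depth_to_space(x, scale)
    return tf.keras.Model(lr, hr)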
It should be noted that the TensorFlow version matters a lot: old versions don't include some layers such as depth-to-space, so make sure your TensorFlow version is newer than 2.4.0. Another important point is that only tf-nightly newer than 2.5.0 can perform quantization with an arbitrary input shape. I provide two conda environments, tf.yaml for training and tfnightly.yaml for Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT); a rough sketch of what full-integer PTQ looks like follows the setup commands below. You can use the following commands to create the two separate conda environments.
conda env create -f tf.yaml
conda env create -f tfnightly.yaml
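For reference, full-integer PTQ with the TFLite converter looks roughly like the sketch below; the checkpoint path, patch size, and random calibration data are placeholders for illustration only, and generate_tflite.py remains the supported conversion script for this repository.

import numpy as np
import tensorflow as tf

# Placeholder path: load your trained float model (actual path/format comes from training).
model = tf.keras.models.load_model('experiment/base7_D4C28_bs16ps64_lr1e-3/best_status')

def representative_dataset():
    # In practice, yield real DIV2K low-resolution patches; random data only keeps this sketch self-contained.
    for _ in range(100):
        yield [np.random.rand(1, 64, 64, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

with open('TFMODEL/base7_int8.tflite', 'wb') as f:
    f.write(converter.convert())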
Download DIV2K and put it in the data folder. The structure should then look like this:
data
  DIV2K
    DIV2K_train_HR
      0001.png
      ...
      0900.png
    DIV2K_train_LR_bicubic
      X2
        0001x2.png
        ...
        0900x2.png
python train.py --opt options/train/base7.yaml --name base7_D4C28_bs16ps64_lr1e-3 --scale 3 --bs 16 --ps 64 --lr 1e-3 --gpu_ids 0
Note: the argument --name specifies the following save paths:
log/{name}.log
experiment/{name}/best_status/
Tensorboard/{name}/
You can use TensorBoard to monitor the training and validation process with:
tensorboard --logdir Tensorboard
If you haven't worked with TensorFlow Lite and network quantization before, please refer to the official guidelines. Quantization-Aware Training inserts fake quantization nodes so that the weights become aware that they will be quantized. For this model, you can simply use the following script to perform QAT:
python train.py --opt options/train/base7_qat.yaml --name base7_D4C28_bs16ps64_lr1e-3_qat --scale 3 --bs 16 --ps 64 --lr 1e-3 --gpu_ids 0 --qat --qat_path experiment/base7_D4C28_bs16ps64_lr1e-3/best_status
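Under the hood, the fake-quantization idea can be illustrated with the tensorflow_model_optimization package, as in the hedged sketch below; the checkpoint path and training details are placeholders, and this is only an illustration, not necessarily how train.py implements the --qat flag.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Placeholder path: start from the float model trained in the previous step.
float_model = tf.keras.models.load_model('experiment/base7_D4C28_bs16ps64_lr1e-3/best_status')

# quantize_model wraps layers with FakeQuant nodes that simulate int8 rounding in the
# forward pass, so the weights learn to compensate before the real TFLite conversion.
qat_model = tfmot.quantization.keras.quantize_model(float_model)
qat_model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='mae')
# qat_model.fit(train_dataset, ...)  # fine-tune for a few epochs as usual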
python generate_tflite.py
Then the converted TFLite models will be saved in TFMODEL/. TFMODEL/{name}.tflite is used for predicting the high-resolution image (an arbitrary low-resolution input shape is allowed), while TFMODEL/{name}_time.tflite fixes the model input shape to [1, 360, 640, 3] for measuring inference time.
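As a usage example, the fixed-shape timing model can be loaded with the TFLite Python interpreter roughly as follows; the file name depends on your --name argument, and the loop is only a crude host-side latency check, not an on-device measurement.

import time
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='TFMODEL/base7_D4C28_bs16ps64_lr1e-3_qat_time.tflite')
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# Dummy input matching the fixed [1, 360, 640, 3] shape (dtype depends on the quantization).
dummy = np.random.randint(0, 256, size=inp['shape']).astype(inp['dtype'])
interpreter.set_tensor(inp['index'], dummy)
interpreter.invoke()  # warm-up run

runs, start = 10, time.time()
for _ in range(runs):
    interpreter.invoke()
print('average latency: %.1f ms' % ((time.time() - start) / runs * 1000))

sr = interpreter.get_tensor(out['index'])  # high-resolution output, e.g. [1, 1080, 1920, 3] for x3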
:) If you have any questions, feel free to contact [email protected]