ETNet: Error Transition Network for Arbitrary Style Transfer (NeurIPS 2019)
This repository contains the code (in TensorFlow) for the paper:
ETNet: Error Transition Network for Arbitrary Style Transfer
Chunjin Song*, Zhijie Wu*, Yang Zhou, Minglun Gong, Hui Huang (* equal contribution, in alphabetical order)
NeurIPS 2019
This repository contains the official implementation of ETNet: Error Transition Network for Arbitrary Style Transfer. To improve stylization results, we introduce an iterative error-correction mechanism that breaks the stylization process into multiple refinements using a Laplacian pyramid strategy. Given an insufficiently stylized image at a refinement step, we compute what is wrong with the current estimate and then propagate the error information to the whole image. The motivation is simple: the detected errors measure the residuals with respect to the ground truth, and can therefore guide the refinement effectively. With this design, ETNet produces more adaptive style patterns while better preserving high-level content structure.
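Conceptually, the coarse-to-fine refinement loop can be pictured as in the sketch below. This is illustrative pseudocode only: stylize, upsample and the dummy networks are hypothetical stand-ins, not the repository's actual API.

import numpy as np

def upsample(x):
    # Nearest-neighbour 2x upsampling (a stand-in for the real operator).
    return x.repeat(2, axis=0).repeat(2, axis=1)

def stylize(content_pyramid, style_pyramid, refine_nets):
    # Pyramids are ordered coarse to fine; refine_nets holds one
    # error-transition network per level.
    estimate = refine_nets[0](content_pyramid[0], style_pyramid[0], None)
    for content, style, net in zip(content_pyramid[1:], style_pyramid[1:], refine_nets[1:]):
        # Upsample the current estimate, let the network evaluate its
        # errors against the content/style targets, and add the
        # predicted residual to refine it.
        estimate = upsample(estimate)
        estimate = estimate + net(content, style, estimate)
    return estimate

# Toy usage with dummy networks that predict zero residuals:
sizes = [16, 32, 64]
contents = [np.zeros((s, s, 3)) for s in sizes]
styles = [np.zeros((s, s, 3)) for s in sizes]
nets = [lambda c, s, e: np.zeros_like(c)] * 3
print(stylize(contents, styles, nets).shape)  # (64, 64, 3)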
If you have any questions, please feel free to contact Zhijie Wu ([email protected]).
In order to train a model to transfer styles, we should set some essential options in train.yaml:
GPU_ID: [0,1,2,3] # GPU ids used for training
...
Incremental: 2 # Layer index of the network being trained, starting from 0
layers_num: 3 # Number of levels in the current model
...
dir_checkpoint: [[0, models/layer1/checkpoints],[1, models/layer2/checkpoints]] # Directory of pretrained models: [[layer_0, model_0],[layer_1, model_1],...]
vgg_weights: models/vgg_normalised.t7
data: # Directories for input training images
  dir_content: image/input/content
  dir_style: image/input/style
  dir_out: image/output/train # Root directory for results
output: # Subdirectories for results
  dir_log: logs
  dir_config: configs
  dir_sample: samples
  dir_checkpoint: checkpoints
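For reference, this configuration can be loaded and inspected with PyYAML; a minimal sketch, assuming PyYAML is installed (the repository's own config loading may differ):

import yaml

with open("train.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["GPU_ID"])       # e.g. [0, 1, 2, 3]
print(cfg["Incremental"])  # layer index to train, e.g. 2
print(cfg["layers_num"])   # number of levels, e.g. 3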
Generally, train.yaml is written in YAML format and contains the following options.
- GPU_ID: a list of GPU IDs. By default we use four GPUs (e.g. [0, 1, 2, 3]) for training the network at the third level.
- Incremental: the layer index of the network to train, starting from 0.
- layers_num: the number of networks (levels) in the current model.
- weight: the weights balancing the content and style terms of the perceptual loss. When training the networks at different levels, we assign different values to style: for a model with three levels, style is set to 2, 5 and 8 when training the networks at the first, second and third level respectively.
- dir_checkpoint: a list of paths to pretrained networks. Each element is itself a list pairing a layer index with the corresponding model path.
- vgg_weights: the path to the pretrained VGG model.
- data: the directory holding all the training data. Specifically, dir_content and dir_style are the two directories storing the content and style images, e.g. image/input/content for content images and image/input/style for style images.
- dir_out: the root directory for output results. During training, the model writes the initial configuration, logs, stylized samples and checkpoints into dir_config, dir_log, dir_sample and dir_checkpoint respectively.

Please note that, when training the networks at different levels, we should update GPU_ID, Incremental, layers_num, weight and dir_checkpoint correspondingly, so as to indicate which network is to be trained, adjust the balance between content and style, and preserve the same training batch size. For the network at the first level, these keys should be set as:
GPU_ID: [0]
Incremental: 0
layers_num: 1
weight:
  content: 1.0
  style: 2.0
dir_checkpoint: None
But to finetune this network, dir_checkpoint should instead be set as:
dir_checkpoint: [[0, models/layer1/checkpoints]] # models/layer1/checkpoints is the path of a pretrained model
For the network at second level, we update the configuration as:
GPU_ID: [0,1]
Incremental: 1
layers_num: 2
weight:
  content: 1.0
  style: 5.0
dir_checkpoint: [[0, models/layer1/checkpoints]]
Similarly, we can append another entry to dir_checkpoint to include an extra pretrained model for further finetuning.
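To make the [[layer_index, model_path]] convention concrete, resolving such entries to actual checkpoint files could look like the hypothetical helper below; it only assumes standard TensorFlow checkpoint directories, and the repository's own restore logic may differ.

import tensorflow as tf

def resolve_checkpoints(dir_checkpoint):
    # Map each [layer_index, checkpoint_dir] pair, e.g.
    # [[0, "models/layer1/checkpoints"]], to its latest checkpoint file.
    resolved = {}
    for layer_idx, ckpt_dir in dir_checkpoint:
        ckpt = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt is None:
            raise ValueError("no checkpoint found in " + ckpt_dir)
        resolved[layer_idx] = ckpt
    return resolved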
As for the network at third level, the file can be configured as:
GPU_ID: [0,1,2,3]
Incremental: 2
layers_num: 3
weight:
  content: 1.0
  style: 8.0
dir_checkpoint: [[0, models/layer1/checkpoints],[1, models/layer2/checkpoints]]
After train.yaml has been set up, we can directly run train.py to start model training:
python train.py
Before transferring styles onto content images with the trained model, we should configure test.yaml as:
GPU_ID: 0 # the gpu for testing
layers_num: 3
checkpoint_dir: [[0, models/layer1/checkpoints],[1, models/layer2/checkpoints],[2, models/layer3/checkpoints]]
vgg_weights: models/vgg_normalised.t7
data: # Directories for testing images
  dir_content: image/input/content
  dir_style: image/input/style
  dir_out: image/output/test # Root directory for output results
output: # Subdirectory for outputs
  dir_result: results
This file (test.yaml) has the following options:
- GPU_ID: the GPU to be used for testing.
- layers_num: the number of levels in the model.
- checkpoint_dir: a list of paths to the pretrained models used for generation; see the training section above for the format.
- vgg_weights: the path to the pretrained VGG model.
- data: the directory of the input images. Specifically, dir_content and dir_style indicate the paths of the content and style images respectively.
- dir_out: the output directory for the results, e.g. image/output/test for the synthesized images.

For a model with three levels, after setting test.yaml as shown above, we can start testing by running stylize.py:
python stylize.py
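With the settings above, the stylized images should end up under dir_out joined with dir_result (e.g. image/output/test/results). A quick sanity check of that composition, assuming PyYAML and that dir_out nests under data as in the listing:

import os
import yaml

with open("test.yaml") as f:
    cfg = yaml.safe_load(f)

# With the configuration above this resolves to image/output/test/results.
result_dir = os.path.join(cfg["data"]["dir_out"], cfg["output"]["dir_result"])
print("stylized images are written to:", result_dir)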
If you use our code/model/data, please cite our paper:
@misc{song2019etnet,
  title={ETNet: Error Transition Network for Arbitrary Style Transfer},
  author={Chunjin Song and Zhijie Wu and Yang Zhou and Minglun Gong and Hui Huang},
  year={2019},
  eprint={1910.12056},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
MIT License