🖼 Training StyleGAN2 on TPUs in JAX
This implementation is adapted from the stylegan2 codebase by Matthias Wright.
This research is part of the technology underlying our AI-generated photography platform Nyx.gallery:

This food does not exist! Click to see more samples 🍪🍰🍣🍹🍔

Specifically, the features we've added allow for better scaling of StyleGAN2 training on TPUs:
- Training data can be stored directly on GCS (in `--data_dir` or in a `tfrecords` subdirectory). Still requires `dataset_info.json` in the `--data_dir` location (containing `width`, `height`, `num_examples`, and a list of `classes` if class-conditional).
- `--load_from_pkl` => `--load_from_ckpt`
- `--num_steps` argument to specify a fixed number of steps to run
- `--early_stopping_after_steps` argument to stop after n steps of no FID improvement
- `--bf16` flag and consolidation with `--mixed_precision`
- `--freeze_g` and `--freeze_d` arguments
- `--fmap_max` argument, in order to have better control over feature map dimensions
- Checkpointing options (`--save_every` and `--keep_n_checkpoints`)
- `--metric_cache_location` in order to cache dataset statistics (currently for FID only)
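For illustration, several of these arguments can be combined in a single training run. A minimal sketch; the GCS bucket and all values below are placeholders:

```
# Hypothetical invocation; bucket path and values are placeholders.
python main.py \
    --data_dir gs://your-bucket/your-dataset \
    --num_steps 100000 \
    --early_stopping_after_steps 5000 \
    --bf16 \
    --save_every 2000 \
    --keep_n_checkpoints 3
```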
Clone the repository and install the dependencies:

```
git clone https://github.com/nyx-ai/stylegan2-flax-tpu.git
cd stylegan2-flax-tpu
pip install -r requirements.txt
```
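Optionally, verify that JAX can see your accelerator before going further:

```
# Expect a list of TPU devices rather than only a CPU device.
python -c "import jax; print(jax.devices())"
```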
We released four 256x256 as well as 512x512 models. Download them from the latest release.
```
python generate_images.py \
    --checkpoint checkpoints/cookie-256.pkl \
    --seeds 0 42 420 666 \
    --truncation_psi 0.7 \
    --out_path generated_images
```
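Lower `--truncation_psi` values trade sample diversity for fidelity (1.0 disables truncation). As a sketch, a small sweep over truncation values for a single seed (output paths are arbitrary):

```
# Generate the same seed at several truncation strengths.
for psi in 0.5 0.7 1.0; do
    python generate_images.py \
        --checkpoint checkpoints/cookie-256.pkl \
        --seeds 0 \
        --truncation_psi $psi \
        --out_path generated_images/psi_$psi
done
```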
Check the Colab notebook for more examples.
Add your images into a folder `/path/to/image_dir`:

```
/path/to/image_dir/
    0.jpg
    1.jpg
    2.jpg
    4.jpg
    ...
```
and create a TFRecord dataset:
```
python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
```
For more detailed instructions, please refer to this README.
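If you need to create the `dataset_info.json` described in the changes list yourself, a minimal sketch might look as follows. The values are placeholders, and the exact key for the class list is an assumption:

```
{
    "width": 256,
    "height": 256,
    "num_examples": 10000,
    "classes": ["class_a", "class_b"]
}
```

Per the changes list, the class list is only needed for class-conditional training.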
The following command trains at 128x128 resolution with a batch size of 8.
```
python main.py --data_dir /path/to/tfrecord
```
Read more about suitable training parameters here.
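The arguments from the changes list also apply here. For example, a hypothetical fine-tuning run starting from a released checkpoint; the paths are placeholders, and the exact form of `--freeze_d` should be checked against `python main.py --help`:

```
# Hypothetical fine-tuning invocation; paths are placeholders.
python main.py \
    --data_dir /path/to/tfrecord \
    --load_from_ckpt checkpoints/cookie-256.pkl \
    --freeze_d
```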
Our experiments have been run and tested on TPU VMs (generations v2 to v4). At the time of writing, Colab offers an older generation of TPUs, so training (and especially compilation) may be significantly slower. If you still wish to train on Colab, the following may get you started:
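A minimal sketch for attaching JAX to a Colab TPU runtime, assuming JAX's standard `jax.tools.colab_tpu` helper:

```
# Run inside a Colab notebook with the TPU runtime selected.
import jax
import jax.tools.colab_tpu

# Point JAX at the Colab-hosted TPU.
jax.tools.colab_tpu.setup_tpu()

# A v2-8 runtime should report eight TPU devices.
print(jax.devices())
```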