Invert and perturb GAN images for test-time ensembling
Ensembling with Deep Generative Views.
Lucy Chai, Jun-Yan Zhu, Eli Shechtman, Phillip Isola, Richard Zhang
CVPR 2021
A Colab notebook demonstrates the basic latent code perturbation and classification procedure in a simplified setting on the aligned cat dataset.
Clone the repository:
git clone https://github.com/chail/gan-ensembling.git
cd gan-ensembling
Install dependencies:
An environment.yml file lists the dependencies. You can create the Conda environment using:
conda env create -f environment.yml
Download resources:
bash resources/download_resources.sh
Additionally, download styleganinv_ffhq256_encoder.pth and styleganinv_ffhq256_generator.pth and place them in models/pretrain.
Download external datasets:
Place the CelebA-HQ images in dataset/celebahq/images/images and the Stanford Cars images in dataset/cars/images/images, with the devkit in dataset/cars/devkit. An example of the directory organization is below:
dataset/celebahq/
    images/images/
        000004.png
        000009.png
        000014.png
        ...
    latents/
    latents_idinvert/
dataset/cars/
    devkit/
        cars_meta.mat
        cars_test_annos.mat
        cars_train_annos.mat
        ...
    images/images/
        00001.jpg
        00002.jpg
        00003.jpg
        ...
    latents/
dataset/catface/
    images/
    latents/
dataset/cifar10/
    cifar-10-batches-py/
    latents/
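Before running the pipeline it can help to confirm that this layout is in place. The following is a minimal sketch and not part of the repository: the helper name missing_dataset_dirs and the EXPECTED_DIRS table are hypothetical, transcribed from the example tree above, and only directories (not individual image or devkit files) are checked.

```python
from pathlib import Path

# Hypothetical table transcribed from the example directory layout.
# Only directories are checked; individual files are not.
EXPECTED_DIRS = {
    "celebahq": ["images/images", "latents", "latents_idinvert"],
    "cars": ["devkit", "images/images", "latents"],
    "catface": ["images", "latents"],
    "cifar10": ["cifar-10-batches-py", "latents"],
}


def missing_dataset_dirs(root="dataset", datasets=None):
    """Return the expected directories under `root` that do not exist yet."""
    root = Path(root)
    names = datasets if datasets is not None else EXPECTED_DIRS.keys()
    missing = []
    for name in names:
        for sub in EXPECTED_DIRS[name]:
            path = root / name / sub
            if not path.is_dir():
                missing.append(str(path))
    return missing
```

For example, calling missing_dataset_dirs(datasets=["celebahq"]) from the repository root would list any CelebA-HQ directories that still need to be created or populated.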
Once the datasets and precomputed resources are downloaded, the following code snippet demonstrates how to perturb GAN images. Additional examples are in notebooks/demo.ipynb.
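import data
from networks import domain_generator

# Dataset, generator, and the attribute used as the classification label
dataset_name = 'celebahq'
generator_name = 'stylegan2'
attribute_name = 'Smiling'

# Validation split, loading the precomputed latent codes (load_w=True)
val_transform = data.get_transform(dataset_name, 'imval')
dset = data.get_dataset(dataset_name, 'val', attribute_name, load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name)

# Item 0 is the image and item 1 its latent code; [None] adds a batch
# dimension and .cuda() moves the tensors to the GPU
index = 100
original_image = dset[index][0][None].cuda()
latent = dset[index][1][None].cuda()

# Decode the latent code to obtain the GAN reconstruction of the image
gan_reconstruction = generator.decode(latent)

# Sample 4 random latents (seed 0) and style-mix them into the 'fine' layers
mix_latent = generator.seed2w(n=4, seed=0)
perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=4)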
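The perturb_stylemix call relies on StyleGAN style mixing. As a rough illustration of the general idea, and not the repository's implementation, the NumPy sketch below treats a latent as a W+ code with one vector per generator layer and swaps in the later ("fine") layers from randomly sampled codes. The function name stylemix_fine, the layer count (14), and the cut point (8) are assumptions for illustration only; the perturbation ranges the repository actually uses are specified per dataset and generator in networks/perturb_settings.py.

```python
import numpy as np


def stylemix_fine(latent, mix_latents, fine_start):
    """Copy layers [fine_start:] from each mixing latent into `latent`.

    latent:      (num_layers, dim) W+ code of the inverted image
    mix_latents: (n, num_layers, dim) W+ codes of random samples
    returns:     (n, num_layers, dim) perturbed W+ codes, one per sample
    """
    n = mix_latents.shape[0]
    # Start from n independent copies of the original latent ...
    mixed = np.repeat(latent[None], n, axis=0)
    # ... and overwrite only the later ("fine") layers with the sampled styles.
    mixed[:, fine_start:] = mix_latents[:, fine_start:]
    return mixed


rng = np.random.default_rng(0)
num_layers, dim = 14, 512  # assumption: W+ size for a 256x256 StyleGAN2
latent = rng.standard_normal((num_layers, dim))
mix_latents = rng.standard_normal((4, num_layers, dim))
views = stylemix_fine(latent, mix_latents, fine_start=8)  # hypothetical cut point
print(views.shape)  # (4, 14, 512)
```

Because only the later layers change, decoded views would be expected to differ mainly in fine details such as color and texture, which is the usual interpretation of StyleGAN's fine styles.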
Important: before running the notebooks, first set up the required symlinks:
bash notebooks/setup_notebooks.sh
then add the conda environment to the Jupyter kernels:
python -m ipykernel install --user --name gan-ensembling
The provided notebooks are:
notebooks/demo.ipynb: basic usage example
notebooks/evaluate_ensemble.ipynb: plots classification test accuracy as a function of the ensemble weight
notebooks/plot_precomputed_evaluations.ipynb: generates the figures in the paper
The full pipeline contains three main parts: optimizing latent codes, training classifiers, and evaluating the ensemble.
Examples for each step of the pipeline are contained in the following scripts:
bash scripts/optimize_latent/examples.sh
bash scripts/train_classifier/examples.sh
bash scripts/eval_ensemble/examples.sh
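The evaluation step combines classifier predictions on the original image with predictions on its GAN-generated views, and notebooks/evaluate_ensemble.ipynb varies the ensemble weight. The NumPy sketch below shows one plausible weighting under a stated assumption: the weight interpolates between the softmax output on the original image and the mean softmax output over the views. This may differ from the exact scheme in the repository's evaluation scripts, and the function names are hypothetical.

```python
import numpy as np


def softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def ensemble_predict(orig_logits, view_logits, weight):
    """Blend the prediction on the original image with the mean over GAN views.

    orig_logits: (num_classes,) classifier logits for the original image
    view_logits: (n_views, num_classes) logits for the GAN-generated views
    weight:      in [0, 1]; 1.0 uses only the original image
    """
    p_orig = softmax(orig_logits)
    p_views = softmax(view_logits).mean(axis=0)
    probs = weight * p_orig + (1.0 - weight) * p_views
    return int(probs.argmax()), probs


orig = np.array([2.0, 0.0])
views = np.array([[0.0, 3.0], [0.0, 3.0]])
for w in (1.0, 0.5, 0.0):
    label, _ = ensemble_predict(orig, views, w)
    print(w, label)  # prints "1.0 0", then "0.5 1", then "0.0 1"
```

Averaging probabilities rather than hard labels lets a confident view outweigh an uncertain original prediction; sweeping the weight from 1.0 down to 0.0 traces the trade-off between trusting the original image and trusting the generated views.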
To add to the pipeline:
Data: in the data/ directory, add the dataset in data/__init__.py and create the dataset class and transformation functions. See data/data_*.py for examples.
Generator: modify networks/domain_generator.py to add the generator in domain_generator.define_generator. The perturbation ranges for each dataset and generator are specified in networks/perturb_settings.py.
Classifier: modify networks/domain_classifiers.py to add the classifier in domain_classifiers.define_classifier.
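As an illustration of the dataset step, a new dataset class needs to return the latent code alongside the image when latents are loaded. The sketch below is hypothetical: it follows only the indexing visible in the quick-start snippet (item 0 is the image and item 1 the latent code when load_w=True), while the class name, constructor arguments, and label position are assumptions. A real implementation would typically subclass torch.utils.data.Dataset and load files from disk; in-memory lists keep the sketch self-contained.

```python
class LatentImageDataset:
    """Hypothetical dataset yielding (image, latent, label) triples.

    Mirrors the quick-start indexing, where dset[i][0] is the image and
    dset[i][1] the latent code when latents are loaded. Placing the label
    last is an assumption.
    """

    def __init__(self, images, labels, latents=None, transform=None):
        self.images = images
        self.labels = labels
        self.latents = latents  # None unless latent codes were loaded
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.latents is not None:
            return image, self.latents[idx], self.labels[idx]
        return image, self.labels[idx]
```

For example, LatentImageDataset(["a", "b"], [0, 1], latents=["wa", "wb"], transform=str.upper)[1] returns ("B", "wb", 1), with strings standing in for image tensors and latent codes.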
We thank the authors of these repositories:
If you use this code for your research, please cite our paper:
@inproceedings{chai2021ensembling,
title={Ensembling with Deep Generative Views.},
author={Chai, Lucy and Zhu, Jun-Yan and Shechtman, Eli and Isola, Phillip and Zhang, Richard},
booktitle={CVPR},
year={2021}
}