DCGAN and WGAN Implementations in Keras for Bird Generation
I created this repository to familiarize myself with DCGANs and their peculiarities. The code uses the Keras library and the Caltech-UCSD Birds-200-2011 dataset.
This dataset has 11,788 images of 200 different bird species. It also includes additional information such as segmentations, attributes and bounding boxes, which are not used in this project. Here are some of the different images in this dataset:
I was looking for a different dataset for my personal experiments, and when I found this one, it seemed like a good idea to try to generate birds from it.
I tried to implement a DCGAN (Deep Convolutional GAN), since most GANs are at least loosely based on the DCGAN architecture. The original DCGAN architecture, proposed by Radford et al. (2015), is shown in the next image:
Like traditional GANs, DCGANs consist of a discriminator, which tries to classify images as real or fake, and a generator, which tries to produce samples that will be fed to the discriminator in an attempt to mislead it. The term DCGAN usually refers to the specific style of architecture shown in the previous figure, since GANs proposed before this work were also deep and convolutional.
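To make the two opposing objectives concrete, here is a minimal pure-Python sketch of the losses involved. It is an illustration and not the code in this repository: the function names and the example probabilities are hypothetical, and it assumes the common setup where the discriminator outputs the probability that an image is real and the generator is trained with the target label 1 on its own samples.

```python
import math

def binary_cross_entropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)

def discriminator_loss(probs_real, probs_fake):
    """Discriminator targets: 1 for real images, 0 for generated images."""
    labels = [1] * len(probs_real) + [0] * len(probs_fake)
    return binary_cross_entropy(labels, list(probs_real) + list(probs_fake))

def generator_loss(probs_fake):
    """Generator target (assumed): make the discriminator output 1 on fakes."""
    return binary_cross_entropy([1] * len(probs_fake), probs_fake)

# Hypothetical discriminator outputs (probability of "real").
confident = discriminator_loss([0.9, 0.8], [0.1, 0.2])  # fakes are spotted
fooled = discriminator_loss([0.9, 0.8], [0.7, 0.6])     # fakes look real
```

When the generator succeeds in misleading the discriminator, the discriminator loss goes up and the generator loss goes down, which is the tension the two networks are trained under.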
From the original paper we can find some tips to make DCGAN work:
Although I tried different architectures based on other papers and repositories, the best result was achieved using a traditional DCGAN architecture (see the code for more details).
In each training iteration:
Every epoch, 3 batches of generated images are saved to disk as .png files, and every 5 epochs the models are saved to disk. The loss plot is also saved at the end of each epoch.
The folders generated images and models must be in the same directory as the trainDCGAN.py file.
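If you prefer not to create the folders by hand, a small helper along these lines could do it before training starts. This is a sketch under assumptions: the folder names `generated_images` and `models` are hypothetical spellings (use whatever names the training script expects), and the epoch counter is assumed to start at 1.

```python
import os

# Hypothetical folder names; adjust to the names the training script uses.
OUTPUT_DIRS = ("generated_images", "models")

def ensure_output_dirs(base_dir, names=OUTPUT_DIRS):
    """Create the output folders next to the training script if missing."""
    paths = [os.path.join(base_dir, name) for name in names]
    for path in paths:
        os.makedirs(path, exist_ok=True)  # no error if it already exists
    return paths

def should_save_models(epoch, every=5):
    """True on every 5th epoch, assuming epochs are counted from 1."""
    return epoch % every == 0
```

Calling `ensure_output_dirs(os.path.dirname(os.path.abspath(__file__)))` at the top of a script would create the folders alongside that script.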
After training the DCGAN, I wanted to compare the results with the WGAN proposed by Arjovsky et al. (2017), so I also implemented code for WGAN training in Keras.
This paper proposes a meaningful loss function that helps with debugging GANs and tuning their parameters and architectures. This loss function, the Wasserstein loss (or Earth-Mover loss), is shown to correlate with image quality. The paper also reports that it helps avoid mode collapse (the generator producing very similar images), even when fully connected GANs are used or batch normalization is omitted.
In this paper the word discriminator is replaced by critic, since its purpose is no longer to classify directly whether an image is real or fake; this is no longer a binary classification problem, and the critic does not need to output a probability. When training the critic, the labels are 1 for real images and -1 for fake images.
Fake and real images are interpreted as samples from two probability distributions. The WGAN tries to measure the difference between these distributions and minimize it (by minimizing the loss). Since computing the Wasserstein distance exactly is intractable, the paper uses an approximation:
K.mean(y_true * y_predicted)
This only works if the network used is K-Lipschitz. For this to hold, the critic's weights are clamped to a small range around zero. The authors admit this is not the ideal solution, but it works in practice (note that the improved version of the WGAN addresses this problem in a more interesting way, using a gradient penalty instead of clamping).
In traditional DCGANs, getting good images requires balancing the training of the generator and the discriminator. With the Wasserstein loss this balancing is much less of an issue: since the Wasserstein distance is differentiable almost everywhere, the critic can be trained to convergence before each generator update, and the better trained the critic is, the more accurate the gradients it provides.
Regarding training, a few points are worth mentioning:
The loss curve is expected to be similar to the one shown in the paper:
Additionally, here are some interesting things I discovered that might help someone:
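As a worked example of the loss expression and the weight clamping, here is a pure-Python sketch. It mirrors `K.mean(y_true * y_predicted)` on plain lists rather than tensors; the critic scores and weights are made-up values, and the clipping range of 0.01 is the default from the WGAN paper, not necessarily the value used in this repository.

```python
def wasserstein_loss(y_true, y_pred):
    """Pure-Python equivalent of K.mean(y_true * y_predicted)."""
    return sum(t * p for t, p in zip(y_true, y_pred)) / len(y_true)

def clip_weights(weights, clip_value=0.01):
    """Clamp every weight to the range [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, w)) for w in weights]

# Labels as described above: 1 for real images, -1 for fake images.
labels = [1, 1, -1, -1]
critic_scores = [-2.0, -1.0, 1.5, 0.5]  # hypothetical critic outputs

loss = wasserstein_loss(labels, critic_scores)
# (-2.0 - 1.0 - 1.5 - 0.5) / 4 = -1.25

clipped = clip_weights([0.5, -0.003, -0.2])
# -> [0.01, -0.003, -0.01]
```

With this sign convention, minimizing the loss drives the critic's scores for real and fake images apart. In Keras, the same clamping would typically be applied after each critic update by reading each layer's weights with `get_weights()`, clipping them (for example with `np.clip`), and writing them back with `set_weights()`.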
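The resulting training order can be sketched as a simple schedule: several critic updates for every generator update. This is only an outline of the ordering, with the actual training calls left as comments; `n_critic=5` is the default from the WGAN paper and is an assumption here, not necessarily what train_wgan.py uses.

```python
def wgan_schedule(generator_iterations, n_critic=5):
    """Yield the update order: n_critic critic steps, then one generator step."""
    for gen_iter in range(generator_iterations):
        for critic_iter in range(n_critic):
            # Real code would train the critic on one real and one fake batch
            # here, then clamp the critic's weights.
            yield ("critic", gen_iter, critic_iter)
        # Real code would update the generator through the frozen critic here.
        yield ("generator", gen_iter, None)

steps = list(wgan_schedule(generator_iterations=2, n_critic=3))
order = [name for name, _, _ in steps]
```

Raising `n_critic` moves the critic closer to convergence before each generator update, at the cost of a proportionally longer training run.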
Evolution of generated images along 190 epochs:
Here are some of the images generated at the end of training:
And here is the generator and discriminator loss evolution during training:
The training was done on a laptop running Ubuntu 16.04 with an NVIDIA GTX 960M and took about a dozen hours. I am quite satisfied with the result. The images are a little blurred, but it seems the generator learnt how to make artistic bird images.
The results could be better, but you can say what you want: these are my mutant birds. Still, I'll try to improve the DCGAN in the future when I have some time.
I followed the same architecture used for the DCGAN in order to have a proper comparison of the results. The weights were initialized differently so that they would not all be clamped at the beginning of training.
Notice that in train_wgan.py the results are not saved to a png file. I decided to use TensorBoard for this one, so the plots are saved in a log file that you can open with TensorBoard (assuming you're using TensorFlow as the Keras backend).
The training was performed for almost 20000 generator iterations and took almost 3 days on the same setup described earlier. The loss curves follow the convergence behavior described in the original paper:
The peaks observed in the loss curves are probably one of the problems described by the authors and mentioned earlier: the critic probably wasn't trained to optimality. I tried to follow the authors' suggestions, changing the learning rates and increasing the number of critic iterations, but the peaks continued to occur. This was probably one of the reasons why I did not see improvements in the generated images. Here is one example after 20k iterations:
Even so, I think the new loss function is quite useful for telling whether a given GAN is converging, despite being a lot slower than traditional DCGANs. With a little more investigation and time (let's not forget it took 3 days to train the WGAN), I think the peaks could be removed by training the critic more.
After taking a deeper look at the original PyTorch code, I found out that the bias of the layers was set to False. I could not find out why this is done, but it may also have contributed to the results of my WGAN, since I left the biases enabled, which is the default.
If I have some time later I will probably try some of these:
Of course, I based my code on other projects and repositories. I would like to give credit to some of them: