Implementation of Temporal Ensembling for Semi-Supervised Learning by Laine et al. with TensorFlow eager execution
This repository includes an implementation of Temporal Ensembling for Semi-Supervised Learning by Laine et al. with TensorFlow eager execution.
When I was reading "Realistic Evaluation of Deep Semi-Supervised Learning Algorithms" by Oliver et al. (2018), I realized I had never played enough with semi-supervised learning, so I came across this paper and thought it would be interesting to play with. (I highly recommend reading the paper by Oliver et al., one of my favorite recent papers.) Additionally, eager execution will be the default when TensorFlow 2.0 comes out, so I thought I should use it in this repository.
Semi-supervised learning algorithms try to improve traditional supervised ones by also using unlabeled samples. This is very interesting because in the real world there are many problems where a lot of the available data is unlabeled. There is no reason why this data cannot be used to learn the general structure of the dataset and support the supervised training process.
The paper proposes two implementations of self-ensembling, i.e. forming different ensemble predictions during training under different conditions of regularization (dropout) and augmentation.
The two methods are the Π-Model and temporal ensembling. Let's dive a little into each one of them.
In the Π-Model the training inputs are evaluated twice, under different conditions of dropout regularization and augmentation, resulting in two outputs $z_i$ and $\tilde{z}_i$.
The loss function here includes two components: a standard supervised cross-entropy term, evaluated on the labeled inputs only, and an unsupervised consistency term that penalizes (via mean squared error) the difference between the two predictions $z_i$ and $\tilde{z}_i$, evaluated on all inputs.
These two components are combined by summing them and scaling the unsupervised one with a time-dependent weighting function $w(t)$. According to the paper, asking $z_i$ and $\tilde{z}_i$ to be close is a much more stringent requirement than the traditional supervised cross-entropy loss.
The augmentations and dropout are always randomly different for each input, resulting in different output predictions. In addition to the augmentations, the inputs are also combined with random Gaussian noise to increase the variability.
The weighting function $w(t)$ will be described later, since it is also used by temporal ensembling, but it ramps up from zero and increases the contribution of the unsupervised component, reaching its maximum at 80 epochs. This means that initially the loss and the gradients are dominated by the supervised component (the authors found that this slow ramp-up is important to keep the gradients stable).
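As a concrete illustration, here is a minimal sketch of how the two components can be combined (function and argument names are illustrative, not the repository's exact API; see pi_model.py for the actual implementation):

```python
import tensorflow as tf

def pi_model_loss(X_labeled, y_labeled, X_unlabeled, model, unsupervised_weight):
    # Two stochastic forward passes: different dropout masks (and, in the
    # full pipeline, different augmentations) for the same inputs
    z_labeled = model(X_labeled, training=True)
    z_labeled_tilde = model(X_labeled, training=True)
    z_unlabeled = model(X_unlabeled, training=True)
    z_unlabeled_tilde = model(X_unlabeled, training=True)

    # Supervised component: cross-entropy on the labeled samples only
    supervised_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_labeled,
                                                   logits=z_labeled))

    # Unsupervised component: the two passes should agree on every sample
    unsupervised_loss = tf.reduce_mean(tf.square(z_labeled - z_labeled_tilde)) + \
        tf.reduce_mean(tf.square(z_unlabeled - z_unlabeled_tilde))

    # w(t) ramps up from zero, so early training is dominated by the
    # supervised component
    return supervised_loss + unsupervised_weight * unsupervised_loss
```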
One big difficulty of the Π-Model is that it relies on output predictions that can be quite unstable during training. To combat this instability, the authors propose temporal ensembling, which aggregates the network's predictions from past epochs into an ensemble prediction.
Instead of evaluating each input twice, the predictions are forced to be close to an ensemble prediction $\tilde{z}_i$ that is based on previous predictions of the network. The algorithm stores each prediction vector $z_i$, and at the end of each epoch these are accumulated into the ensemble vector $Z_i$ using the formula $Z_i \leftarrow \alpha Z_i + (1 - \alpha) z_i$,
where $\alpha$ is a term that controls how far past predictions influence the temporal ensemble. This vector contains a weighted average of previous predictions for all instances, with recent ones having a higher weight. The ensemble training targets $\tilde{z}_i$, to be comparable to $z_i$, need to be corrected for the startup bias by dividing $Z_i$ by $(1 - \alpha^t)$, where $t$ is the current epoch.
In this algorithm, $Z_i$ and $\tilde{z}_i$ are zero on the first epoch, since no past predictions exist.
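A minimal sketch of the end-of-epoch accumulation and startup-bias correction (variable names are illustrative):

```python
import numpy as np

def update_ensemble(Z, z_epoch, alpha, epoch):
    """Z: accumulated ensemble outputs; z_epoch: this epoch's predictions.
    Both have shape (num_train_samples, num_classes) and start at zero."""
    # Accumulate this epoch's predictions into the ensemble
    Z = alpha * Z + (1.0 - alpha) * z_epoch
    # Startup-bias correction (epoch is 0-based, so t = epoch + 1)
    z_tilde = Z / (1.0 - alpha ** (epoch + 1))
    return Z, z_tilde
```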
The advantages of this algorithm when compared to the Π-Model are:

- Training is approximately 2x faster, since each input is evaluated only once per epoch.
- The training targets $\tilde{z}$ are less noisy, since they are based on an ensemble of previous predictions.
The disadvantages are the following:

- The predictions need to be stored across epochs, requiring memory proportional to the number of training samples times the number of classes.
- A new hyperparameter $\alpha$ needs to be tuned.
To test the two types of ensembling, the authors used a CNN with the following architecture:
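As a rough sketch, the network from the paper can be written with plain Keras layers as below. This is my reading of the paper's architecture table and omits the weight normalization and mean-only batch normalization wrappers that the repository actually uses (described in the Code Details section):

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_network(num_classes=10):
    """Sketch of the conv-net from the paper using plain Keras layers."""
    leaky = lambda x: tf.nn.leaky_relu(x, alpha=0.1)
    return tf.keras.Sequential([
        layers.GaussianNoise(0.15, input_shape=(32, 32, 3)),
        layers.Conv2D(128, 3, padding='same', activation=leaky),
        layers.Conv2D(128, 3, padding='same', activation=leaky),
        layers.Conv2D(128, 3, padding='same', activation=leaky),
        layers.MaxPool2D(2),
        layers.Dropout(0.5),
        layers.Conv2D(256, 3, padding='same', activation=leaky),
        layers.Conv2D(256, 3, padding='same', activation=leaky),
        layers.Conv2D(256, 3, padding='same', activation=leaky),
        layers.MaxPool2D(2),
        layers.Dropout(0.5),
        layers.Conv2D(512, 3, padding='valid', activation=leaky),
        layers.Conv2D(256, 1, activation=leaky),
        layers.Conv2D(128, 1, activation=leaky),
        layers.GlobalAveragePooling2D(),
        layers.Dense(num_classes),  # logits; softmax is applied in the loss
    ])
```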
The datasets tested were CIFAR-10, CIFAR-100, and SVHN. I will focus on SVHN, since it is currently the only dataset I have tested on.
SVHN Dataset
The Street View House Numbers (SVHN) dataset consists of digit images taken from house numbers in natural scenes. The authors used the MNIST-like 32x32 images centered around a single character, with the task of classifying the center digit. It has 73257 digits for training, 26032 digits for testing, and 531131 extra digits (not used currently).
Unsupervised Weight Function
As described before, both algorithms use a time-dependent weighting function $w(t)$. In the paper the authors use a Gaussian ramp-up curve that grows from 0 to 1 during the first 80 epochs and remains constant for the rest of training:
The function that describes this ramp-up is $w(t) = \exp\left(-5\left(1 - \frac{t}{80}\right)^2\right)$ for $t < 80$, and $w(t) = 1$ afterwards.
Notice that the final weight in each epoch corresponds to $w(t) \cdot w_{max} \cdot \frac{M}{N}$, where $M$ is the number of labeled samples used, $N$ is the total number of training samples, and $w_{max}$ is a constant that varies across problems and datasets.
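A sketch of this ramp-up and of the resulting unsupervised weight (the repository's `ramp_up_function` may differ slightly in signature; $w_{max} = 30$ is the SVHN value used in the training scripts below):

```python
import numpy as np

def ramp_up_function(epoch, ramp_up_length=80):
    """Gaussian ramp-up: grows from ~0 to 1 during the first
    `ramp_up_length` epochs and stays at 1 afterwards."""
    if epoch >= ramp_up_length:
        return 1.0
    p = 1.0 - epoch / ramp_up_length
    return float(np.exp(-5.0 * p * p))

# Example: unsupervised weight at epoch 50 with 1000 labeled samples
# out of 73257 training samples (w_max = 30 for SVHN in this repository)
weight = ramp_up_function(50) * 30 * 1000 / 73257
```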
Training used the Adam optimizer, and the learning rate also undergoes a ramp-up during the first 80 epochs and a ramp-down during the last 50 epochs (the ramp-down is similar to the Gaussian ramp-up function, but has a scaling constant of 12.5 instead of 5):
The learning rate in each epoch corresponds to the multiplication of this temporal weight function by a maximum learning rate (a hyperparameter). Adam's $\beta_1$ was also annealed using this function, but instead of tending to 0 it converges to 0.5 over the last 50 epochs.
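A sketch of the resulting schedules, reusing the `ramp_up_function` above. The exact argument of the ramp-down exponential is my assumption based on the description (scaling constant 12.5 over the last 50 epochs):

```python
import numpy as np

def ramp_down_function(epoch, num_epochs, ramp_down_length=50):
    # Assumed Gaussian ramp-down: stays at 1.0 until the last
    # `ramp_down_length` epochs, then decays towards exp(-12.5) ~ 0
    if epoch < num_epochs - ramp_down_length:
        return 1.0
    p = 1.0 - (num_epochs - epoch) / ramp_down_length
    return float(np.exp(-12.5 * p * p))

# Example: schedules at epoch 280 of 300 with max_learning_rate = 0.001
rd = ramp_down_function(280, 300)
learning_rate = 0.001 * ramp_up_function(280) * rd  # ramp_up_function from above
beta_1 = rd * (0.9 - 0.5) + 0.5                     # anneals from 0.9 to 0.5
```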
Training parameters
The training hyperparameters described in the paper differ between the two algorithms and across the different datasets:
There are some differences between this repository and the paper:
Code Details
train_pi_model.py includes the main function for training the Π-Model on the SVHN dataset. You can edit the variables described before at the beginning of the main function. The default parameters are the ones described in the paper:
# Editable variables
num_labeled_samples = 1000
num_validation_samples = 200
batch_size = 50
epochs = 300
max_learning_rate = 0.001
initial_beta1 = 0.9
final_beta1 = 0.5
checkpoint_directory = './checkpoints/PiModel'
tensorboard_logs_directory = './logs/PiModel'
Similarly, train_temporal_ensembling_model.py includes the main function for training the temporal ensembling model on SVHN:
# Editable variables
num_labeled_samples = 3000
num_validation_samples = 1000
num_train_unlabeled_samples = NUM_TRAIN_SAMPLES - num_labeled_samples - num_validation_samples
batch_size = 150
epochs = 300
max_learning_rate = 0.0002  # 0.001 as recommended in the paper leads to unstable training.
initial_beta1 = 0.9
final_beta1 = 0.5
alpha = 0.6
max_unsupervised_weight = 30 * num_labeled_samples / (NUM_TRAIN_SAMPLES - num_validation_samples)
checkpoint_directory = './checkpoints/TemporalEnsemblingModel'
tensorboard_logs_directory = './logs/TemporalEnsemblingModel'
svnh_loader.py and tfrecord_loader.py contain helper classes for downloading the dataset and saving it as TFRecords so it can be loaded as a tf.data.TFRecordDataset.
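As an illustration, loading such records might look like this (the feature names and encoding are assumptions, not necessarily what tfrecord_loader.py writes):

```python
import tensorflow as tf

def parse_example(serialized):
    # Each record is assumed to hold the raw image bytes, the label,
    # and the sample's index in the training set
    features = tf.parse_single_example(serialized, {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'index': tf.FixedLenFeature([], tf.int64),
    })
    image = tf.reshape(tf.decode_raw(features['image'], tf.float32), (32, 32, 3))
    return image, features['label'], features['index']

dataset = tf.data.TFRecordDataset('svhn_train.tfrecords').map(parse_example).batch(50)
```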
pi_model.py is where the model is defined as a tf.keras.Model and where some training functions are defined, such as the ramp-up and ramp-down functions and the loss and gradient functions.
The folder weight_norm_layers contains edited tensorflow.layers wrappers that allow weight normalization and mean-only batch normalization in Conv2D and Dense layers, as used in the paper.
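For intuition, here is a minimal sketch of the weight normalization reparameterization (Salimans & Kingma, 2016) on a Dense layer. This is not the repository's wrapper; mean-only batch normalization would additionally center the pre-activations with a running batch mean:

```python
import tensorflow as tf

class WeightNormDense(tf.keras.layers.Layer):
    """Minimal sketch: the kernel is reparameterized as g * v / ||v||."""
    def __init__(self, units):
        super(WeightNormDense, self).__init__()
        self.units = units

    def build(self, input_shape):
        in_dim = int(input_shape[-1])
        self.v = self.add_weight('v', shape=(in_dim, self.units))
        self.g = self.add_weight('g', shape=(self.units,), initializer='ones')
        self.b = self.add_weight('b', shape=(self.units,), initializer='zeros')

    def call(self, x):
        # Normalize each output column of v and scale by the learned gain g
        kernel = self.g * self.v / tf.norm(self.v, axis=0)
        return tf.matmul(x, kernel) + self.b
```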
The code also saves TensorBoard logs with loss curves, mean accuracies, and the evolution of the unsupervised weight and the learning rate. In the case of temporal ensembling, histograms of the temporal ensemble predictions and of the normalized training targets are also saved.
Important Notes
Besides the images and labels, the dataset iterators also return the index of each sample:

X_labeled_train, y_labeled_train, labeled_indexes = train_labeled_iterator.get_next()
X_unlabeled_train, _, unlabeled_indexes = train_unlabeled_iterator.get_next()

This is only relevant in the temporal ensembling case, where the indexes are needed to read and update the per-sample ensemble predictions.
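The reason the indexes matter: they select the rows of the per-sample ensemble tensors that correspond to the current batch. A small illustration (shapes and names are hypothetical), using `unlabeled_indexes` from the snippet above:

```python
import numpy as np
import tensorflow as tf

# Bias-corrected ensemble targets for the whole training set
num_train_samples, num_classes = 73257, 10
ensemble_targets = np.zeros((num_train_samples, num_classes), dtype=np.float32)

# Pick out the targets (and later the slots to update) for this batch only
batch_targets = tf.gather(ensemble_targets, unlabeled_indexes)
```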
The ramp-up length is passed as the second argument of the ramp-up function, e.g. to ramp up over 40 epochs:

rampup_value = ramp_up_function(epoch, 40)
If you find any bugs or have suggestions, feel free to send me an email or create an issue in the repository!
I would like to give credit to some repositories that I found while reading the paper and that helped me with my implementation.