We got a stew going!
The trend in normalizing flow (NF) literature has been to devise deeper, more complex transformations to achieve greater flexibility.
We propose an alternative: Gradient Boosted Normalizing Flows (GBNF) model a density by successively adding new NF components with gradient boosting. Under the boosting framework, each new NF component optimizes a sample-weighted likelihood objective, so that new components are fit to the residuals of the previously trained components.
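The sample-weighted objective can be illustrated with a minimal sketch that makes several simplifying assumptions. It weights each training point by the inverse of the current model density, 1/G_{k-1}(x), so that points the previous components explain poorly dominate the next fit. It also uses a 1-D Gaussian fitted by weighted maximum likelihood as a stand-in for training a new flow component. The helper names (gaussian_logpdf, boosting_weights, fit_weighted_gaussian) are hypothetical and not part of this repository, and the paper's exact weighting may include additional terms.

```python
import math


def gaussian_logpdf(x, mean, std):
    # log N(x; mean, std^2) for a scalar x
    return -0.5 * math.log(2 * math.pi * std ** 2) - 0.5 * ((x - mean) / std) ** 2


def boosting_weights(data, prev_logpdf):
    """Weight each sample by 1 / G_{k-1}(x), normalized to sum to one.

    Assumption: inverse-density weights, so low-density (poorly fit)
    samples receive large weights.
    """
    log_w = [-prev_logpdf(x) for x in data]
    m = max(log_w)  # subtract the max before exponentiating for stability
    w = [math.exp(lw - m) for lw in log_w]
    total = sum(w)
    return [wi / total for wi in w]


def fit_weighted_gaussian(data, weights):
    """Weighted maximum likelihood for a 1-D Gaussian.

    Stand-in for training a new flow component g on the objective
    sum_i w_i * log g(x_i); weights are assumed to sum to one.
    """
    mean = sum(w * x for w, x in zip(weights, data))
    var = sum(w * (x - mean) ** 2 for w, x in zip(weights, data))
    return mean, math.sqrt(var)


if __name__ == "__main__":
    # The previous component covers the cluster near -2 but not the points near 3.
    data = [-2.1, -1.9, -2.0, 3.0, 3.2]
    weights = boosting_weights(data, lambda x: gaussian_logpdf(x, -2.0, 0.5))
    new_mean, new_std = fit_weighted_gaussian(data, weights)
    print(weights, new_mean, new_std)
```

With these weights the new component lands on the points near 3, which the previous component left unexplained.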
The GBNF formulation results in a mixture model structure whose flexibility increases as more components are added. Moreover, GBNFs offer a wider, as opposed to strictly deeper, approach that improves existing NFs at the cost of additional training rather than more complex transformations.
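The mixture structure can be sketched as follows. The sketch assumes that each new component g_k enters with a step size rho, so that G_k(x) = (1 - rho) G_{k-1}(x) + rho g_k(x). The BoostedMixture class and the hand-picked rho are illustrative assumptions, not this repository's API, and Gaussian log-densities again stand in for flow components.

```python
import math


def gaussian_logpdf(x, mean, std):
    # log N(x; mean, std^2) for a scalar x
    return -0.5 * math.log(2 * math.pi * std ** 2) - 0.5 * ((x - mean) / std) ** 2


def logsumexp(values):
    # Numerically stable log(sum(exp(v) for v in values))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


class BoostedMixture:
    """Hypothetical mixture updated as G_k = (1 - rho) * G_{k-1} + rho * g_k."""

    def __init__(self, first_component_logpdf):
        self.log_pdfs = [first_component_logpdf]
        self.weights = [1.0]

    def add_component(self, component_logpdf, rho):
        # Shrink the existing weights by (1 - rho) and give the new component rho,
        # so the weights always sum to one.
        if not 0.0 < rho < 1.0:
            raise ValueError("rho must lie in (0, 1)")
        self.weights = [w * (1.0 - rho) for w in self.weights]
        self.weights.append(rho)
        self.log_pdfs.append(component_logpdf)

    def log_prob(self, x):
        # log G(x) = log sum_k w_k g_k(x), evaluated in log space
        terms = [math.log(w) + lp(x) for w, lp in zip(self.weights, self.log_pdfs)]
        return logsumexp(terms)


if __name__ == "__main__":
    mix = BoostedMixture(lambda x: gaussian_logpdf(x, -2.0, 0.5))
    print(mix.log_prob(3.0))
    mix.add_component(lambda x: gaussian_logpdf(x, 3.0, 0.5), rho=0.5)
    print(mix.weights, mix.log_prob(3.0))
```

Working in log space with logsumexp avoids underflow when individual component densities are tiny, which is common for high-dimensional data.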
Link to paper:
Gradient Boosted Normalizing Flows by Robert Giaquinto and Arindam Banerjee. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
The code is compatible with:
PyTorch 1.1.0
Python 3.6+ (it should also work with Python 2.7 if you include from __future__ import print_function)
It is recommended that you create a virtual environment with the correct Python version and dependencies. After cloning the repository, change into its directory and run the following commands to create a virtual environment:
python -m venv ./venv
source ./venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
(The code assumes python refers to Python 3.6+; if not, use python3.)
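Once the environment is active, a quick sanity check can confirm the interpreter version and whether PyTorch is importable. The check_environment helper below is hypothetical and is not a script shipped with the repository. It only reports the installed PyTorch version and does not check for the 1.1.0 version listed above.

```python
import sys


def check_environment(min_python=(3, 6)):
    """Return (python_ok, torch_version), where torch_version is None if missing."""
    python_ok = sys.version_info[:2] >= min_python
    try:
        import torch  # only imported to read its version string
        torch_version = torch.__version__
    except ImportError:
        torch_version = None
    return python_ok, torch_version


if __name__ == "__main__":
    ok, torch_version = check_environment()
    print("Python %d.%d meets minimum: %s" % (sys.version_info[0], sys.version_info[1], ok))
    print("PyTorch version: %s" % (torch_version or "not installed"))
```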
The experiments can be run on the following image datasets:
Additionally, density estimation experiments can be run on datasets from the UCI repository, which can be downloaded with:
./download_datasets.sh
The scripts folder includes examples for running the GBNF model on two toy problems, a density estimation experiment, and the Caltech 101 Silhouettes dataset:
Toy problem: match 2-moons energy function with Boosted Real-NVPs
./scripts/getting_started_toy_matching_gbnf.sh &
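The energy-matching objective can be sketched in a few lines. This sketch assumes the 2-moons target is the potential U1 from Rezende and Mohamed (2015), a common choice in flow toy experiments; the repository's exact definition may differ, and the function names are hypothetical. Matching the energy means minimizing a Monte Carlo estimate of E_q[log q(z) + U(z)], which equals KL(q || p) - log Z for the target p(z) = exp(-U(z)) / Z.

```python
import math


def energy_two_moons(z):
    """Assumed 2-moons energy U(z) for a 2-D point z = (z1, z2).

    U(z) = 0.5 * ((||z|| - 2) / 0.4)^2
           - log(exp(-0.5 * ((z1 - 2) / 0.6)^2) + exp(-0.5 * ((z1 + 2) / 0.6)^2))
    """
    z1, z2 = z
    ring = 0.5 * ((math.hypot(z1, z2) - 2.0) / 0.4) ** 2
    left_right = math.exp(-0.5 * ((z1 - 2.0) / 0.6) ** 2) + math.exp(-0.5 * ((z1 + 2.0) / 0.6) ** 2)
    return ring - math.log(left_right)


def reverse_kl_objective(samples, log_q):
    # Monte Carlo estimate of E_q[log q(z) + U(z)] from samples z ~ q;
    # minimizing it over q drives q toward p(z) proportional to exp(-U(z)).
    return sum(log_q(z) + energy_two_moons(z) for z in samples) / len(samples)


if __name__ == "__main__":
    # Low energy on the ring near z1 = +/-2, high energy at the origin.
    print(energy_two_moons((2.0, 0.0)), energy_two_moons((0.0, 0.0)))
```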
Toy problem: density estimation on the 8-Gaussians with Boosted Real-NVPs
./scripts/getting_started_toy_estimation_gbnf.sh &
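The 8-Gaussians data is usually generated as a mixture of eight isotropic Gaussians with equally spaced means on a circle. The sketch below follows that construction. The radius, noise level, and the sample_eight_gaussians name are assumptions and may not match the repository's data generator.

```python
import math
import random


def sample_eight_gaussians(n, radius=4.0, std=0.5, seed=0):
    """Draw n 2-D points from 8 isotropic Gaussians centered on a circle."""
    rng = random.Random(seed)  # local generator for reproducible samples
    centers = [
        (radius * math.cos(2 * math.pi * k / 8), radius * math.sin(2 * math.pi * k / 8))
        for k in range(8)
    ]
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)  # pick a mixture component uniformly at random
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points


if __name__ == "__main__":
    for x, y in sample_eight_gaussians(5):
        print("%.3f %.3f" % (x, y))
```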
Density estimation of MINIBOONE dataset with Boosted Glow
./scripts/getting_started_density_estimation_gbnf.sh &
Generative modeling of Caltech 101 Silhouettes images with Boosted Real-NVPs
./scripts/getting_started_vae_gbnf.sh &