Library for the training and evaluation of object-centric models (ICML 2022)
Code accompanying our paper:
Generalization and Robustness Implications in Object-Centric Learning
Andrea Dittadi, Samuele Papa, Michele De Vita, Bernhard Schölkopf, Ole Winther, Francesco Locatello
ICML 2022
Summary of out-of-the-box functionalities (see Using the library):
The image below showcases the datasets (top row) and the distribution shifts on CLEVR (bottom row) that were used in the experimental study in our paper.
Visualizations of a few object-centric models trained in our study on the datasets shown above:
Example full visualization of a single trained model, including separate slot reconstructions:
Visualizations of a few object-centric models on the distribution shifts on CLEVR:
The library can be extended with more models, datasets, distribution shifts, evaluation metrics, and downstream tasks.
Compared to the original library used in our paper, the current version includes the ClevrTex dataset.
Install requirements from requirements.txt.
Example installation with conda:
conda create --name object_centric_lib python=3.8
conda activate object_centric_lib
# Optionally install PyTorch with a custom CUDA version. Example:
# pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
pip install -r requirements.txt
Note: PyTorch might have to be installed separately before installing the requirements, depending on the required CUDA version (see the PyTorch installation instructions).
Python 3.8 recommended (≥3.8 required).
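Before moving on, it can help to confirm the interpreter version and the PyTorch/CUDA setup. The snippet below is a minimal sketch and not part of the library. The function names check_python_version and torch_status are hypothetical, and only the Python ≥3.8 requirement comes from this README.

```python
import sys


def check_python_version(minimum=(3, 8)):
    """Return True if the interpreter is at least `minimum` (major, minor)."""
    return sys.version_info[:2] >= minimum


def torch_status():
    """Describe the installed PyTorch build, or report that it is missing."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed"
    return f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}"


if __name__ == "__main__":
    if not check_python_version():
        sys.exit("Python >= 3.8 is required")
    print(torch_status())
```

If CUDA is reported as unavailable on a GPU machine, reinstalling PyTorch with the matching CUDA build (see the commented pip line above) is the usual fix.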
Set the environment variable OBJECT_CENTRIC_LIB_DATA to the folder where the datasets should be stored.
For example, on Linux or MacOS, add the following line to ~/.bashrc (or ~/.zshrc, depending on your shell):
export OBJECT_CENTRIC_LIB_DATA=/path/to/datasets
Then, restart the shell or run . ~/.bashrc (or . ~/.zshrc).
Download the datasets with download_data.py.
# Download all datasets
python download_data.py -d all
# Download all datasets, including style transfer versions
python download_data.py -d all --include-style
# Download only some datasets, without style transfer
python download_data.py -d multidsprites clevr
Each dataset is a .hdf5 file, and its metadata is in a corresponding ${DATASET_NAME}_metadata.npy file.
Custom datasets may override these defaults.
Check the integrity of the dataset files by running python check_data.py.
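check_data.py is the authoritative integrity check. As a rough illustration of the naming convention above, the sketch below lists .hdf5 files that lack a matching metadata file. It assumes, hypothetically, that a dataset stored as NAME.hdf5 keeps its metadata next to it as NAME_metadata.npy. The real folder layout may differ.

```python
from pathlib import Path


def find_missing_metadata(data_root):
    """Return dataset names whose .hdf5 file has no NAME_metadata.npy beside it.

    Hypothetical layout: NAME.hdf5 and NAME_metadata.npy in the same folder.
    """
    missing = []
    for h5_path in sorted(Path(data_root).rglob("*.hdf5")):
        meta_path = h5_path.with_name(f"{h5_path.stem}_metadata.npy")
        if not meta_path.exists():
            missing.append(h5_path.stem)
    return missing
```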
Train a model with default parameters:
python train_object_discovery.py model=monet dataset=multidsprites
This saves the model and the logs by default in outputs/runs/${MODEL}-${DATASET}-${DATETIME}.
Resume training of a run, given the path to the root folder ${RUN_ROOT} of the run:
python train_object_discovery.py model=monet dataset=multidsprites hydra.run.dir=${RUN_ROOT} allow_resume=true
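The default run folder and the resume command can also be built programmatically, e.g. from a job scheduler. This is a sketch: the helper names are hypothetical, and the timestamp format is inferred from the YYYY-MM-DD_HH-MM-SS placeholder used in the examples below.

```python
from datetime import datetime
from pathlib import Path


def default_run_dir(model, dataset, when=None, base="outputs/runs"):
    """Run folder such as outputs/runs/monet-multidsprites-2022-07-01_12-30-00."""
    when = when or datetime.now()
    return Path(base) / f"{model}-{dataset}-{when.strftime('%Y-%m-%d_%H-%M-%S')}"


def resume_command(model, dataset, run_root):
    """Argument list equivalent to the resume command above."""
    return [
        "python", "train_object_discovery.py",
        f"model={model}", f"dataset={dataset}",
        f"hydra.run.dir={run_root}", "allow_resume=true",
    ]
```

For example, subprocess.run(resume_command("monet", "multidsprites", run_root), check=True), launched from the repository root, should resume the run stored at run_root.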
Evaluate reconstruction and segmentation metrics, given ${RUN_ROOT} (the path to the root folder of the run):
python eval_metrics.py checkpoint_path=outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS
Run the downstream object property prediction task (training + evaluation):
python eval_downstream_prediction.py downstream_model=linear checkpoint_path=outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS
Save visualizations (reconstructions, masks, slot reconstructions):
python eval_qualitative.py checkpoint_path=outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS
All evaluation results are saved in ${RUN_ROOT}/evaluation, e.g., outputs/runs/monet-multidsprites-YYYY-MM-DD_HH-MM-SS/evaluation.
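The three evaluation steps can be chained for a single run. This is a minimal sketch that uses only the commands listed above. eval_commands and run_all_evaluations are hypothetical helpers, and the scripts are assumed to be launched from the repository root.

```python
import subprocess
from pathlib import Path


def eval_commands(run_root, downstream_model="linear"):
    """The metrics, downstream, and qualitative evaluation commands for one run."""
    ckpt = f"checkpoint_path={run_root}"
    return [
        ["python", "eval_metrics.py", ckpt],
        ["python", "eval_downstream_prediction.py", f"downstream_model={downstream_model}", ckpt],
        ["python", "eval_qualitative.py", ckpt],
    ]


def run_all_evaluations(run_root):
    """Run each evaluation in turn, stopping at the first failing script."""
    for cmd in eval_commands(run_root):
        subprocess.run(cmd, check=True)
    return Path(run_root) / "evaluation"  # where the results end up
```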
Currently, the library includes the following models:
genesis
monet
slot-attention
space
baseline_vae_mlp
baseline_vae_broadcast
and the following datasets:
clevr (the original dataset has 10 objects: to train on CLEVR6, add +dataset.variant=6 to the command line)
multidsprites
objects_room
shapestacks
tetrominoes
clevrtex
Read the following sections for further details.
python train_object_discovery.py model=${MODEL} dataset=${DATASET}
This command trains the specified model on the specified dataset, with default parameters defined by the hydra configuration files in config/.
The base config file for this script is config/train_object_discovery.yaml.
The run folder is handled by hydra, and by default it is outputs/runs/${MODEL}-${DATASET}-${DATETIME}.
This can be customized using hydra by adding, e.g., hydra.run.dir=outputs/runs/${model.name}-${dataset.name} to the command line.
The model and dataset correspond to config files: e.g., model=slot-attention reads the model config from config/model/slot-attention.yaml, and dataset=multidsprites reads the dataset config from config/dataset/multidsprites.yaml.
In some cases, we define custom parameters for specific combinations of dataset and model: these are defined in the folder config/special_cases.
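The mapping from a config-group selection to its YAML file can be written out explicitly. This sketch only illustrates the convention described above for top-level groups such as model= and dataset=. It is not how hydra resolves configs, and it ignores config/special_cases. config_file_for is a hypothetical name.

```python
from pathlib import Path


def config_file_for(override, config_dir="config"):
    """Map a selection like 'model=slot-attention' to config/model/slot-attention.yaml."""
    group, sep, option = override.partition("=")
    if not sep or not group or not option:
        raise ValueError(f"expected GROUP=OPTION, got {override!r}")
    if group.startswith("+") or "." in group:
        # Overrides such as +dataset.variant=6 set values; they are not handled here.
        raise ValueError(f"not a top-level group selection: {override!r}")
    return Path(config_dir) / group / f"{option}.yaml"
```

To inspect the fully merged configuration, hydra's own initialize and compose functions (with config_name="train_object_discovery" and the same overrides) are the more faithful route.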
Dataset variants can define dataset filters or transforms to test robustness to distribution shifts.
A variant is picked by adding +dataset.variant=${VARIANT} to the command line: e.g., CLEVR6 is dataset=clevr +dataset.variant=6, and Tetrominoes with occlusions is dataset=tetrominoes +dataset.variant=occlusion.
For more information on dataset variants, see config/dataset/variants/readme.md.
All models are configured through hydra, including the training setup. The default parameters are defined in the model's YAML file, and these can be overridden from the command line. E.g., we can change the foreground sigma, the MLP hidden size, and the learning rate schedule of MONet as follows:
python train_object_discovery.py model=monet dataset=shapestacks model.fg_sigma=0.15 model.encoder_params.mlp_hidden_size=128 trainer.exp_decay_rate=0.8
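When launching many runs from Python, the same overrides can be assembled into an argument list. This is a sketch with a hypothetical train_command helper. Keyword arguments cannot contain dots, so the example uses a double underscore in place of each dot. That is a convention of this sketch, not of hydra.

```python
import shlex


def train_command(model, dataset, **overrides):
    """Argument list for train_object_discovery.py with hydra-style key=value overrides.

    A double underscore in a keyword name stands for a dot in the config key.
    """
    args = ["python", "train_object_discovery.py", f"model={model}", f"dataset={dataset}"]
    for key, value in overrides.items():
        args.append(f"{key.replace('__', '.')}={value}")
    return args


if __name__ == "__main__":
    cmd = train_command(
        "monet", "shapestacks",
        model__fg_sigma=0.15,
        model__encoder_params__mlp_hidden_size=128,
        trainer__exp_decay_rate=0.8,
    )
    print(shlex.join(cmd))  # the MONet command shown above
```

The resulting list can be passed directly to subprocess.run, which avoids shell quoting issues.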
There are some common flags that can be used with every model and dataset:
batch_size: batch size (default given by the model config).
trainer.steps: number of training steps (default given by the model config).
data_sizes: size of the train, validation, and test sets (defaults given by the dataset config).
trainer.optimizer_config: by default, the optimizer class, learning rate, and other optimizer parameters can be provided here (see e.g. config/model/monet.yaml). We can also implement a custom _make_optimizers() method that handles more complex settings, e.g., where we need multiple optimizers: see for example config/model/space.yaml and models/space/trainer.py.
trainer.clip_grad_norm: float value for gradient norm clipping, or None for no clipping.
trainer.logweights_steps, trainer.logimages_steps, trainer.logloss_steps, trainer.checkpoint_steps, trainer.logvalid_steps: intervals, in training steps, for logging weights, images, and losses, saving checkpoints, and running validation, respectively.
allow_resume: if the directory of the run exists, this flag controls whether the script loads an existing checkpoint and resumes training, or throws an exception.
num_workers: number of workers for PyTorch data loaders.
dataset.skip_loading: if true, dummy data is loaded instead of the specified dataset (for debugging).
seed: random seed.
debug: if true, it launches a minimal run.
device: cpu or cuda (default: cuda).
python eval_metrics.py checkpoint_path=/path/to/run/folder
This command evaluates the reconstruction error (MSE) and three segmentation metrics (ARI, SC, mSC).
Typically, no customization is necessary, but see config/eval_metrics.yaml.
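For reference, the ARI and segmentation covering metrics can be computed from flattened per-pixel labelings as shown here. This is a sketch of the standard definitions, not the library's implementation. ARI comes from the pair-counting contingency table. SC weights each ground-truth segment's best IoU with a predicted segment by its size, and mSC averages those IoUs uniformly. The library may differ, for example in how background pixels or batches are handled.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(true_labels, pred_labels):
    """Adjusted Rand Index between two labelings of the same pixels."""
    n = len(true_labels)
    if n != len(pred_labels) or n < 2:
        raise ValueError("need two labelings of equal length >= 2")
    # Pair counts from the contingency table and from its row/column sums.
    index = sum(comb(c, 2) for c in Counter(zip(true_labels, pred_labels)).values())
    sum_true = sum(comb(c, 2) for c in Counter(true_labels).values())
    sum_pred = sum(comb(c, 2) for c in Counter(pred_labels).values())
    expected = sum_true * sum_pred / comb(n, 2)
    max_index = (sum_true + sum_pred) / 2
    if max_index == expected:
        # Only for identical trivial partitions (one segment, or all singletons).
        return 1.0
    return (index - expected) / (max_index - expected)


def segmentation_covering(true_labels, pred_labels, mean=False):
    """SC (size-weighted, mean=False) or mSC (unweighted, mean=True)."""
    true_regions, pred_regions = {}, {}
    for pixel, (t, p) in enumerate(zip(true_labels, pred_labels)):
        true_regions.setdefault(t, set()).add(pixel)
        pred_regions.setdefault(p, set()).add(pixel)
    # Best IoU of each ground-truth segment with any predicted segment.
    best_iou = {
        t: max(len(region & r) / len(region | r) for r in pred_regions.values())
        for t, region in true_regions.items()
    }
    if mean:
        return sum(best_iou.values()) / len(best_iou)
    n_pixels = sum(len(region) for region in true_regions.values())
    return sum(len(true_regions[t]) * iou for t, iou in best_iou.items()) / n_pixels
```

Both metrics are invariant to how segments are numbered, which is why a model's slot order does not affect them.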
The variant_types flag allows evaluating the metrics on different variants of the original training dataset: this is used by default to evaluate generalization (see the list of default variants in config/eval_metrics.yaml).
The overwrite flag allows overwriting the result folder for this evaluation, and is False by default.
The seed, debug, and device flags are also available here, with the same behavior as in train_object_discovery.py.
python eval_downstream_prediction.py checkpoint_path=/path/to/run/folder downstream_model=linear
This command trains and evaluates a downstream linear model to predict (from the representations
of the upstream model) the properties of the objects in a scene.
This is configured by config/eval_downstream_prediction.yaml; see the comments in the file for more information.
Note that a results subfolder is created specifically for each combination of
matching, downstream model, and dataset variant.
Typically useful flags (see the config file for more):
downstream_model: the type of downstream model, such as linear or MLP3.
matching: method for matching objects with model slots.
variant_types: for each of the specified variant types, train a downstream model and then test it on all variant types (including the one it was trained on).
steps
batch_size
learning_rate
train_size
validation_size
test_size
The seed, debug, overwrite, and device flags are also available here, with the same behavior as in eval_metrics.py.
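Matching objects to slots amounts to an assignment problem. The sketch below solves it by brute force over slot permutations, which is only practical for a handful of slots. The cost definition (e.g. 1 minus mask IoU) is a hypothetical example, and the library's matching options may use different criteria.

```python
from itertools import permutations


def match_objects_to_slots(cost):
    """Assign each object (row) to a distinct slot (column) at minimum total cost.

    cost[i][j] is a hypothetical cost of pairing object i with slot j.
    Returns (slot index per object, total cost).
    """
    n_objects = len(cost)
    n_slots = len(cost[0]) if cost else 0
    if n_objects > n_slots:
        raise ValueError("need at least as many slots as objects")
    best_perm, best_cost = None, float("inf")
    for perm in permutations(range(n_slots), n_objects):
        total = sum(cost[i][j] for i, j in enumerate(perm))
        if total < best_cost:
            best_perm, best_cost = perm, total
    return list(best_perm), best_cost
```

For larger problems, scipy.optimize.linear_sum_assignment(cost) returns an optimal assignment for rectangular cost matrices in polynomial time.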
python eval_qualitative.py checkpoint_path=/path/to/run/folder
This command saves model visualizations, and typically does not require customization.
The seed, debug, overwrite, and device flags are also available here.
To run many experiments in a structured sweep over parameters and/or settings, the library has a "sweep" functionality.
For example, to train all object-centric models in the study in our paper, we defined a sweep in sweeps/configs/sweep_object_centric.py.
This creates a sweep called "object_centric", which maps a model number to a specific configuration of command line arguments.
The first model in the sweep is trained as follows:
python sweep_train.py --sweep-name object_centric --model-num 0
Since in this case we have 10 seeds, 4 models, and 5 datasets, any model number up to 199 would be valid.
This script internally calls train_object_discovery.py with the appropriate arguments as prescribed by the sweep, and uses outputs/sweeps/sweep_${SWEEP_NAME}/${MODEL_NUMBER}/ as the output folder.
Use python -m sweeps.sweep_progress SWEEP_NAME
to get an overview of the overall progress of the sweep.
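To make the model-number arithmetic concrete: with 10 seeds, 4 models, and 5 datasets, the sweep has 10 × 4 × 5 = 200 configurations, numbered 0 to 199. The sketch below enumerates such a grid. The model and dataset names and their ordering are assumptions for illustration. The actual mapping is defined in sweeps/configs/sweep_object_centric.py, which may order configurations differently or pass additional arguments.

```python
from itertools import product

# Hypothetical grid with the sizes quoted above; names and order are assumptions.
MODELS = ["monet", "genesis", "slot-attention", "space"]
DATASETS = ["multidsprites", "objects_room", "tetrominoes", "clevr", "shapestacks"]
SEEDS = range(10)


def sweep_configs():
    """All (model, dataset, seed) combinations, with the seed varying fastest."""
    return list(product(MODELS, DATASETS, SEEDS))


def args_for_model_num(model_num):
    """Command line overrides for one model number of this hypothetical sweep."""
    configs = sweep_configs()
    if not 0 <= model_num < len(configs):
        raise IndexError(f"model number must be in [0, {len(configs) - 1}]")
    model, dataset, seed = configs[model_num]
    return [f"model={model}", f"dataset={dataset}", f"seed={seed}"]
```

With this ordering, model number m * 50 + d * 10 + s selects model m, dataset d, and seed s.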
The library can easily be extended with new models, datasets, dataset variants, evaluation metrics, and downstream tasks. Feel free to reach out with questions at:
andrea [đöt] dittadi [åt] gmail [đöt] com
If you use this library in your own work, please consider citing our paper as follows:
@inproceedings{dittadi2022generalization,
title={Generalization and Robustness Implications in Object-Centric Learning},
author={Dittadi, Andrea and Papa, Samuele and De Vita, Michele and Sch{\"o}lkopf, Bernhard and Winther, Ole and Locatello, Francesco},
booktitle={International Conference on Machine Learning},
year={2022},
}
In a follow-up paper, we use this library to investigate inductive biases in unsupervised object-centric learning when the objects in the training set have complex textures:
Inductive Biases for Object-Centric Representations in the Presence of Complex Textures
Samuele Papa, Ole Winther, Andrea Dittadi
UAI workshop on Causal Representation Learning, 2022