Code for CRATE (Coding RAte reduction TransformEr).
This repository is the official PyTorch implementation of the CRATE papers (cited at the end of this README).
We have also released a longer, journal-length overview paper of this line of research, which contains a superset of the results in those papers, along with additional results in NLP and vision SSL.
CRATE (Coding RAte reduction TransformEr) is a white-box (mathematically interpretable) transformer architecture, where each layer performs a single step of an alternating minimization algorithm to optimize the sparse rate reduction objective
$$\max_{f}\; \mathbb{E}_{\boldsymbol{Z}}\big[\Delta R(\boldsymbol{Z} \mid \boldsymbol{U}_{[K]}) - \lambda \|\boldsymbol{Z}\|_{0}\big] \;=\; \max_{f}\; \mathbb{E}_{\boldsymbol{Z}}\big[R(\boldsymbol{Z}) - R^{c}(\boldsymbol{Z} \mid \boldsymbol{U}_{[K]}) - \lambda \|\boldsymbol{Z}\|_{0}\big],$$
where $R$ and $R^{c}$ are different coding rates for the input representations w.r.t. different codebooks, and the $\ell^{0}$-norm promotes the sparsity of the final token representations $\boldsymbol{Z} = f(\boldsymbol{X})$. The function $f$ is defined as $$f=f^{L} \circ f^{L-1} \circ \cdots \circ f^{1} \circ f^{\mathrm{pre}},$$ where $f^{\mathrm{pre}}$ is the pre-processing mapping, and $f^{\ell}$ is the $\ell$-th layer forward mapping that transforms the token distribution to optimize the above sparse rate reduction objective incrementally. More specifically, $f^{\ell}$ transforms the $\ell$-th layer token representations $\boldsymbol{Z}^{\ell}$ to $\boldsymbol{Z}^{\ell+1}$ via the $\texttt{MSSA}$ (Multi-Head Subspace Self-Attention) block and the $\texttt{ISTA}$ (Iterative Shrinkage-Thresholding Algorithms) block, i.e., $$\boldsymbol{Z}^{\ell+1} = f^{\ell}(\boldsymbol{Z}^{\ell}) = \texttt{ISTA}(\boldsymbol{Z}^{\ell} + \texttt{MSSA}(\boldsymbol{Z}^{\ell})).$$
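For intuition, here is a minimal PyTorch sketch of this layer structure. It is an illustrative simplification, not the repository's implementation (see `model/crate.py` for that): the `MSSA` block below uses a single shared projection per head, and the `ISTA` block is written as one generic proximal-gradient (soft-thresholding) step with respect to a learned dictionary, rather than the exact parameterization used in the paper.

```python
import torch
import torch.nn as nn

class MSSA(nn.Module):
    """Multi-head subspace self-attention (sketch): each head projects tokens onto a
    subspace and attends using that single shared projection for query, key, and value."""
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        self.heads, self.dim_head = heads, dim_head
        self.scale = dim_head ** -0.5
        self.proj_in = nn.Linear(dim, heads * dim_head, bias=False)   # the stacked subspace bases
        self.proj_out = nn.Linear(heads * dim_head, dim, bias=False)

    def forward(self, z):                                             # z: (batch, tokens, dim)
        b, n, _ = z.shape
        w = self.proj_in(z).view(b, n, self.heads, self.dim_head).transpose(1, 2)
        attn = torch.softmax(w @ w.transpose(-2, -1) * self.scale, dim=-1)
        out = (attn @ w).transpose(1, 2).reshape(b, n, -1)
        return self.proj_out(out)

class ISTA(nn.Module):
    """One generic ISTA (proximal gradient) step promoting sparse token codes
    with respect to a learned dictionary D (sketch only)."""
    def __init__(self, dim, step_size=0.1, lam=0.1):
        super().__init__()
        self.D = nn.Parameter(torch.randn(dim, dim) / dim ** 0.5)
        self.step_size, self.lam = step_size, lam

    def forward(self, z):                                             # z: (batch, tokens, dim)
        # One gradient step on 1/2 * ||z - x D||^2 starting from x = z, then soft-thresholding.
        grad = (z @ self.D - z) @ self.D.T
        x = z - self.step_size * grad
        return torch.sign(x) * torch.clamp(x.abs() - self.step_size * self.lam, min=0.0)

class CRATELayer(nn.Module):
    """Z^{l+1} = ISTA(Z^l + MSSA(Z^l))."""
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        self.mssa = MSSA(dim, heads, dim_head)
        self.ista = ISTA(dim)

    def forward(self, z):
        z = z + self.mssa(z)   # compression step (MSSA block, residual form)
        return self.ista(z)    # sparsification step (ISTA block)
```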
The following figure presents an overview of the pipeline for our proposed CRATE architecture:
The following figure shows the overall architecture of one layer of CRATE as the composition of $\texttt{MSSA}$ and $\texttt{ISTA}$ blocks.
In the following figure, we measure the compression term [ $R^{c}(\boldsymbol{Z}^{\ell+1/2})$ ] and the sparsity term [ $\|\boldsymbol{Z}^{\ell+1}\|_{0}$ ] defined in the sparse rate reduction objective, and we find that each layer of CRATE indeed optimizes the targeted objectives, showing that our white-box theoretical design is predictive of practice.
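As a rough guide to how such quantities can be computed, the sketch below evaluates a coding rate $R(\boldsymbol{Z}) = \tfrac{1}{2}\log\det\big(\boldsymbol{I} + \tfrac{d}{n\epsilon^{2}}\boldsymbol{Z}\boldsymbol{Z}^{*}\big)$, a compression term as a sum of coding rates of subspace projections, and sparsity as an entry count above a small tolerance. The value of $\epsilon$, the tolerance, and the random subspace bases `U_list` are placeholders for illustration; the figures in the paper use the layer representations $\boldsymbol{Z}^{\ell+1/2}$ and the learned subspaces from the attention heads, and the exact normalization may differ.

```python
import torch

def coding_rate(Z, eps=0.1):
    """R(Z) = 1/2 * logdet(I + d/(n * eps^2) * Z Z^T) for Z of shape (d, n)."""
    d, n = Z.shape
    I = torch.eye(d, dtype=Z.dtype, device=Z.device)
    return 0.5 * torch.logdet(I + (d / (n * eps ** 2)) * Z @ Z.T)

def compression_term(Z, U_list, eps=0.1):
    """R^c(Z; U_[K]) ~ sum_k R(U_k^T Z), with each U_k of shape (d, p) spanning subspace k."""
    return sum(coding_rate(U.T @ Z, eps) for U in U_list)

def sparsity(Z, tol=1e-8):
    """A practical stand-in for ||Z||_0: the number of entries above a small tolerance."""
    return int((Z.abs() > tol).sum())

# Illustrative usage on random data (d = 384 features, n = 196 tokens, K = 6 subspaces).
Z = torch.randn(384, 196)
U_list = [torch.linalg.qr(torch.randn(384, 64)).Q for _ in range(6)]
print(coding_rate(Z).item(), compression_term(Z, U_list).item(), sparsity(Z))
```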
In the following figure, we visualize self-attention maps from a supervised CRATE model with 8x8 patches (similar to the ones shown in DINO :t-rex:).
We also discover a surprising empirical phenomenon where each attention head in CRATE retains its own semantics.
We can also use our theory to build a principled autoencoder, which has the following architecture.
It has many of the same empirical properties as the base CRATE model, such as segmented attention maps and amenability to layer-wise analysis. We train it on the masked autoencoding task (calling this model CRATE-MAE), and it achieves comparable performance in linear probing and reconstruction quality as the base ViT-MAE.
A CRATE model can be defined using the following code (the parameters below are specified for CRATE-Tiny):
```python
from model.crate import CRATE

dim = 384
n_heads = 6
depth = 12

model = CRATE(image_size=224,
              patch_size=16,
              num_classes=1000,
              dim=dim,
              depth=depth,
              heads=n_heads,
              dim_head=dim // n_heads)
```
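Assuming the model follows the usual ViT-style classification interface (a batch of 224×224 RGB images in, class logits out), a forward pass would look like the following sketch; the input batch is random data purely for illustration.

```python
import torch

images = torch.randn(2, 3, 224, 224)   # a random batch of two 224x224 RGB images
with torch.no_grad():
    logits = model(images)              # expected shape: (2, 1000), i.e. (batch, num_classes)
print(logits.shape)
```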
| model | `dim` | `n_heads` | `depth` | pre-trained checkpoint |
| --- | --- | --- | --- | --- |
| CRATE-T(iny) | 384 | 6 | 12 | TODO |
| CRATE-S(mall) | 576 | 12 | 12 | download link |
| CRATE-B(ase) | 768 | 12 | 12 | TODO |
| CRATE-L(arge) | 1024 | 16 | 24 | TODO |
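A sketch of loading one of the released checkpoints into a matching model is shown below; the file name and the checkpoint's internal layout (a bare state dict vs. a `"state_dict"` entry) are assumptions, so adjust them to whatever the downloaded file actually contains.

```python
import torch
from model.crate import CRATE

# CRATE-Small configuration from the table above.
model = CRATE(image_size=224, patch_size=16, num_classes=1000,
              dim=576, depth=12, heads=12, dim_head=576 // 12)

# "crate_small.pth" is a placeholder name for the downloaded checkpoint.
ckpt = torch.load("crate_small.pth", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates minor key mismatches
model.eval()
```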
To train a CRATE model on ImageNet-1K, use `main.py`. As an example, the following command trains CRATE-Tiny:
```bash
python main.py \
  --arch CRATE_tiny \
  --batch-size 512 \
  --epochs 200 \
  --optimizer Lion \
  --lr 0.0002 \
  --weight-decay 0.05 \
  --print-freq 25 \
  --data DATA_DIR
```
Replace `DATA_DIR` with `[imagenet-folder with train and val folders]`.
To fine-tune a pretrained CRATE model on CIFAR10 (or to train one from scratch, see below), run:

```bash
python finetune.py \
  --bs 256 \
  --net CRATE_tiny \
  --opt adamW \
  --lr 5e-5 \
  --n_epochs 200 \
  --randomaug 1 \
  --data cifar10 \
  --ckpt_dir CKPT_DIR \
  --data_dir DATA_DIR
```
Replace `CKPT_DIR` with the path to the pretrained CRATE weights, and `DATA_DIR` with the path to the CIFAR10 dataset. If `CKPT_DIR` is `None`, the script trains CRATE from random initialization on CIFAR10.
CRATE models exhibit emergent segmentation in their self-attention maps solely through supervised training. We provide a Colab Jupyter notebook to visualize the emergent segmentation maps from a supervised CRATE model; the demo produces visualizations that match the segmentation figures above.
Link: crate-emergence.ipynb (in colab)
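If you want to reproduce something like this outside the notebook, one generic approach is to register forward hooks on the attention softmax modules and run an image through the model. The sketch below shows this general PyTorch technique and is not the notebook's code; the filter on `nn.Softmax` modules is a guess and may need to be adapted to how attention is actually computed in `model/crate.py` (inspect `model.named_modules()` to find the right layers).

```python
import torch
import torch.nn as nn

attention_maps = []          # filled by the hooks below

def save_attention(module, inputs, output):
    attention_maps.append(output.detach().cpu())

# Hook every softmax module; if attention uses functional softmax instead,
# hook the enclosing attention modules rather than nn.Softmax.
hooks = [m.register_forward_hook(save_attention)
         for m in model.modules() if isinstance(m, nn.Softmax)]

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))   # a random image, for illustration only

for h in hooks:
    h.remove()
print(f"captured {len(attention_maps)} attention tensors")
```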
A CRATE-autoencoding model (specifically CRATE-MAE-Base) can be defined using the following code:
```python
from model.crate_ae.crate_ae import mae_crate_base

model = mae_crate_base()
```
The other sizes in the paper are also importable in the same way. Modifying the `model/crate_ae/crate_ae.py` file will let you initialize and use your own configuration.
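Assuming the CRATE-MAE model keeps the interface of Meta FAIR's `models_mae.py` (which `model/crate_ae/crate_ae.py` is intended to drop into; see the MAE training note below), a masked-reconstruction forward pass would look roughly like the following; the `mask_ratio` value and the random batch are illustrative only.

```python
import torch
from model.crate_ae.crate_ae import mae_crate_base

model = mae_crate_base()
images = torch.randn(2, 3, 224, 224)     # a random batch, for illustration
with torch.no_grad():
    # MAE-style interface (an assumption): returns the reconstruction loss, the
    # predicted patches, and the binary mask indicating which patches were hidden.
    loss, pred, mask = model(images, mask_ratio=0.75)
print(loss.item(), pred.shape, mask.shape)
```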
| model | `dim` | `n_heads` | `depth` | pre-trained checkpoint |
| --- | --- | --- | --- | --- |
| CRATE-MAE-S(mall) | 576 | 12 | 12 | TODO |
| CRATE-MAE-B(ase) | 768 | 12 | 12 | link |
To train or fine-tune a CRATE-MAE model on ImageNet-1K, please refer to Meta FAIR's MAE training codebase. The `models_mae.py` file in that codebase can be replaced with the contents of `model/crate_ae/crate_ae.py`, and the rest of the code should run with minimal alterations.
CRATE-MAE models also exhibit emergent segmentation in their self-attention maps. We provide a Colab Jupyter notebook to visualize the emergent segmentation maps from a CRATE-MAE model; the demo produces visualizations that match the segmentation figures above.
Link: crate-mae.ipynb (in colab)
For technical details and full experimental results, please check the CRATE paper, CRATE segmentation paper, CRATE autoencoding paper, or the long-form overview paper. Please consider citing our work if you find it helpful to yours:
```bibtex
@article{yu2024white,
  title={White-Box Transformers via Sparse Rate Reduction},
  author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Haeffele, Benjamin and Ma, Yi},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}

@inproceedings{yu2024emergence,
  title={Emergence of Segmentation with Minimalistic White-Box Transformers},
  author={Yu, Yaodong and Chu, Tianzhe and Tong, Shengbang and Wu, Ziyang and Pai, Druv and Buchanan, Sam and Ma, Yi},
  booktitle={Conference on Parsimony and Learning},
  pages={72--93},
  year={2024},
  organization={PMLR}
}

@inproceedings{pai2024masked,
  title={Masked Completion via Structured Diffusion with White-Box Transformers},
  author={Pai, Druv and Buchanan, Sam and Wu, Ziyang and Yu, Yaodong and Ma, Yi},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}

@article{yu2023white,
  title={White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?},
  author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Bai, Hao and Zhai, Yuexiang and Haeffele, Benjamin D and Ma, Yi},
  journal={arXiv preprint arXiv:2311.13110},
  year={2023}
}
```