The official implementation of [CVPR2022] Decoupled Knowledge Distillation (https://arxiv.org/abs/2203.08679) and [ICCV2023] DOT: A Distillation-Oriented Trainer (https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf).
This repo is:
(1) a PyTorch library that provides classical knowledge distillation algorithms on mainstream CV benchmarks;
(2) the official implementation of the CVPR-2022 paper: Decoupled Knowledge Distillation;
(3) the official implementation of the ICCV-2023 paper: DOT: A Distillation-Oriented Trainer.
DOT results on CIFAR-100 (top-1 accuracy, %):

| Teacher <br> Student | ResNet32x4 <br> ResNet8x4 | VGG13 <br> VGG8 | ResNet32x4 <br> ShuffleNet-V2 |
|:---:|:---:|:---:|:---:|
| KD | 73.33 | 72.98 | 74.45 |
| KD+DOT | 75.12 | 73.77 | 75.55 |
DOT results on Tiny-ImageNet (top-1 accuracy, %):

| Teacher <br> Student | ResNet18 <br> MobileNet-V2 | ResNet18 <br> ShuffleNet-V2 |
|:---:|:---:|:---:|
| KD | 58.35 | 62.26 |
| KD+DOT | 64.01 | 65.75 |
DOT results on ImageNet (top-1 accuracy, %):

| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---:|:---:|:---:|
| KD | 71.03 | 70.50 |
| KD+DOT | 71.72 | 73.09 |
DKD results on CIFAR-100 (top-1 accuracy, %), homogeneous teacher-student pairs:

| Teacher <br> Student | ResNet56 <br> ResNet20 | ResNet110 <br> ResNet32 | ResNet32x4 <br> ResNet8x4 | WRN-40-2 <br> WRN-16-2 | WRN-40-2 <br> WRN-40-1 | VGG13 <br> VGG8 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| KD | 70.66 | 73.08 | 73.33 | 74.92 | 73.54 | 72.98 |
| DKD | 71.97 | 74.11 | 76.32 | 76.23 | 74.81 | 74.68 |
DKD results on CIFAR-100 (top-1 accuracy, %), heterogeneous teacher-student pairs:

| Teacher <br> Student | ResNet32x4 <br> ShuffleNet-V1 | WRN-40-2 <br> ShuffleNet-V1 | VGG13 <br> MobileNet-V2 | ResNet50 <br> MobileNet-V2 | ResNet32x4 <br> MobileNet-V2 |
|:---:|:---:|:---:|:---:|:---:|:---:|
| KD | 74.07 | 74.83 | 67.37 | 67.35 | 74.45 |
| DKD | 76.45 | 76.70 | 69.71 | 70.35 | 77.07 |
DKD results on ImageNet (top-1 accuracy, %):

| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---:|:---:|:---:|
| KD | 71.03 | 70.50 |
| DKD | 71.70 | 72.05 |
MDistiller supports the following distillation methods on CIFAR-100, ImageNet and MS-COCO:
| Method | Paper Link | CIFAR-100 | ImageNet | MS-COCO |
|:---:|:---:|:---:|:---:|:---:|
| KD | https://arxiv.org/abs/1503.02531 | ✓ | ✓ | |
| FitNet | https://arxiv.org/abs/1412.6550 | ✓ | | |
| AT | https://arxiv.org/abs/1612.03928 | ✓ | ✓ | |
| NST | https://arxiv.org/abs/1707.01219 | ✓ | | |
| PKT | https://arxiv.org/abs/1803.10837 | ✓ | | |
| KDSVD | https://arxiv.org/abs/1807.06819 | ✓ | | |
| OFD | https://arxiv.org/abs/1904.01866 | ✓ | ✓ | |
| RKD | https://arxiv.org/abs/1904.05068 | ✓ | | |
| VID | https://arxiv.org/abs/1904.05835 | ✓ | | |
| SP | https://arxiv.org/abs/1907.09682 | ✓ | | |
| CRD | https://arxiv.org/abs/1910.10699 | ✓ | ✓ | |
| ReviewKD | https://arxiv.org/abs/2104.09044 | ✓ | ✓ | ✓ |
| DKD | https://arxiv.org/abs/2203.08679 | ✓ | ✓ | ✓ |
Environments:
Install the package:

```bash
sudo pip3 install -r requirements.txt
sudo python3 setup.py develop
```
If you do not want to use wandb as the logger, set `CFG.LOG.WANDB` to `False` in `mdistiller/engine/cfg.py`.
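If you would rather toggle this at run time than edit the file, a minimal sketch is below. It assumes the config node is exposed as `CFG` in `mdistiller/engine/cfg.py` (the file referenced above) and behaves like a standard yacs `CfgNode`; check the module if the names differ.

```python
# Minimal sketch, assuming mdistiller/engine/cfg.py exposes the yacs config
# node as CFG (matching the CFG.LOG.WANDB field mentioned above).
from mdistiller.engine.cfg import CFG as cfg

cfg.LOG.WANDB = False  # fall back to plain console/file logging instead of wandb
# Equivalent yacs-style override from a key-value list:
# cfg.merge_from_list(["LOG.WANDB", False])
```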
You can evaluate the performance of our models or models trained by yourself.
Our models are released at https://github.com/megvii-research/mdistiller/releases/tag/checkpoints; please download the checkpoints to ./download_ckpts.

To test models on ImageNet, please download the dataset from https://image-net.org/ and put it in ./data/imagenet.
```bash
# evaluate teachers
python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
python3 tools/eval.py -m ResNet34 -d imagenet # ResNet34 on imagenet

# evaluate students
python3 tools/eval.py -m resnet8x4 -c download_ckpts/dkd_resnet8x4 # dkd-resnet8x4 on cifar100
python3 tools/eval.py -m MobileNetV1 -c download_ckpts/imgnet_dkd_mv1 -d imagenet # dkd-mv1 on imagenet
python3 tools/eval.py -m model_name -c output/your_exp/student_best # your checkpoints
```
For training on CIFAR-100, download cifar_teachers.tar from https://github.com/megvii-research/mdistiller/releases/tag/checkpoints and untar it to ./download_ckpts via `tar xvf cifar_teachers.tar`.
```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml

# you can also change settings at command line
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1
```
For training on ImageNet, download the dataset from https://image-net.org/ and put it in ./data/imagenet.
```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/imagenet/r34_r18/dkd.yaml
```
To develop your own distillation method:

1. Create a python file at mdistiller/distillers/ and define the distiller (a fuller sketch follows these steps):

```python
from ._base import Distiller

class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super(MyDistiller, self).__init__(student, teacher)
        self.hyper1 = cfg.MyDistiller.hyper1
        ...

    def forward_train(self, image, target, **kwargs):
        # return the output logits and a Dict of losses
        ...
    # rewrite the get_learnable_parameters function if there are more nn modules for distillation.
    # rewrite the get_extra_parameters function if you want to report the extra parameter cost.
    ...
```
2. Register the distiller in `distiller_dict` at mdistiller/distillers/__init__.py.

3. Register the corresponding hyper-parameters at mdistiller/engine/cfg.py.

4. Create a new config file and test it.
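For concreteness, here is a minimal sketch of such a distiller: plain logit matching (KL divergence) written against the interface outlined in step 1. The `MYKD.*` config keys, the loss names, and the assumption that the student/teacher forward passes return `(logits, features)` are illustrative only, not existing library settings; follow the distillers already in mdistiller/distillers/ for the authoritative interface and register your own hyper-parameters as in step 3.

```python
# A minimal sketch following the skeleton above: vanilla logit distillation.
# Assumptions (not existing library settings): the MYKD.* config entries, the
# loss names, and that student/teacher forward passes return (logits, features).
import torch
import torch.nn.functional as F

from ._base import Distiller


def kd_loss(logits_student, logits_teacher, temperature):
    # Hinton-style KD: KL divergence between temperature-softened distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures.
    log_p_student = F.log_softmax(logits_student / temperature, dim=1)
    p_teacher = F.softmax(logits_teacher / temperature, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (temperature ** 2)


class MyKD(Distiller):
    def __init__(self, student, teacher, cfg):
        super(MyKD, self).__init__(student, teacher)
        # hypothetical hyper-parameters; register them in mdistiller/engine/cfg.py (step 3)
        self.temperature = cfg.MYKD.TEMPERATURE
        self.ce_weight = cfg.MYKD.CE_WEIGHT
        self.kd_weight = cfg.MYKD.KD_WEIGHT

    def forward_train(self, image, target, **kwargs):
        logits_student, _ = self.student(image)
        with torch.no_grad():
            logits_teacher, _ = self.teacher(image)

        # return the output logits and a Dict of losses, as required by the trainer
        losses_dict = {
            "loss_ce": self.ce_weight * F.cross_entropy(logits_student, target),
            "loss_kd": self.kd_weight
            * kd_loss(logits_student, logits_teacher, self.temperature),
        }
        return logits_student, losses_dict
```

With this in place, step 2 would add an entry such as `"MYKD": MyKD` to `distiller_dict`, and a new config file under configs/ would select the distiller and set the `MYKD.*` values.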
If this repo is helpful for your research, please consider citing the papers:
```bibtex
@article{zhao2022dkd,
  title={Decoupled Knowledge Distillation},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
  journal={arXiv preprint arXiv:2203.08679},
  year={2022}
}

@article{zhao2023dot,
  title={DOT: A Distillation-Oriented Trainer},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun},
  journal={arXiv preprint arXiv:2307.08436},
  year={2023}
}
```
MDistiller is released under the MIT license. See LICENSE for details.
Thanks to CRD and ReviewKD. This library is built on the CRD codebase and the ReviewKD codebase.

Thanks to Yiyu Qiu and Yi Shi for their code contributions during their internships at MEGVII Technology.

Thanks to Xin Jin for the discussions about DKD.