[ICLR'23 Spotlight🔥] The first successful BERT/MAE-style pretraining on any convolutional network; PyTorch implementation of "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"
This is the official implementation of the ICLR paper Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling, which can pretrain any CNN (e.g., ResNet) in a BERT-style self-supervised manner. We've tried our best to keep the codebase clean, short, easy to read, state-of-the-art, and reliant on only minimal dependencies.
📹Recorded Video, Poster, and Slides
[Synced] [DeepAI] [TheGradient] [Bytedance] [CVers] [QbitAI(量子位)] [BAAI(智源)] [机器之心机动组] [极市平台] [ReadPaper notes]
Check pretrain/viz_reconstruction.ipynb to visualize the reconstructions of SparK-pretrained models.
We also provide pretrain/viz_spconv.ipynb that shows the "mask pattern vanishing" issue of dense conv layers.
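The "mask pattern vanishing" effect can be illustrated with a toy example that is independent of the notebook. The sketch below is an assumption-laden illustration, not the repo's code: it uses a hypothetical all-ones 3x3 kernel (`box_conv3x3`) on a tiny 8x8 binary map in pure Python, and emulates a sparse convolution simply by re-applying the mask after each dense convolution.

```python
def box_conv3x3(x):
    """Dense 3x3 convolution with an all-ones kernel and zero padding."""
    h, w = len(x), len(x[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            s = 0.0
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    ii, jj = i + di, j + dj
                    if 0 <= ii < h and 0 <= jj < w:
                        s += x[ii][jj]
            out[i][j] = s
    return out


def masked_ratio(x):
    """Fraction of positions whose value is exactly zero."""
    flat = [v for row in x for v in row]
    return sum(1 for v in flat if v == 0) / len(flat)


# Toy 8x8 map: only the top-left 2x2 block is visible (1 = visible, 0 = masked).
mask = [[1.0 if (i < 2 and j < 2) else 0.0 for j in range(8)] for i in range(8)]

dense, sparse = mask, mask
ratios_dense = [masked_ratio(mask)]
ratios_sparse = [masked_ratio(mask)]
for _ in range(3):
    # Dense conv: visible values leak into masked positions.
    dense = box_conv3x3(dense)
    # Emulated sparse conv (assumption): zero out masked positions after each layer.
    conv_out = box_conv3x3(sparse)
    sparse = [[v * m for v, m in zip(row, mrow)] for row, mrow in zip(conv_out, mask)]
    ratios_dense.append(masked_ratio(dense))
    ratios_sparse.append(masked_ratio(sparse))

print("dense :", ratios_dense)
print("sparse:", ratios_sparse)
```

In this toy setting the masked ratio shrinks with every dense layer, while the masked variant keeps the original pattern at every depth.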
Note: for network definitions, we directly use timm.models.ResNet and the official ConvNeXt.
reso.: the image resolution; acc@1: ImageNet-1K fine-tuned top-1 accuracy
arch. | reso. | acc@1 | #params | flops | weights (self-supervised, without SparK's decoder) |
---|---|---|---|---|---|
ResNet50 | 224 | 80.6 | 26M | 4.1G | resnet50_1kpretrained_timm_style.pth |
ResNet101 | 224 | 82.2 | 45M | 7.9G | resnet101_1kpretrained_timm_style.pth |
ResNet152 | 224 | 82.7 | 60M | 11.6G | resnet152_1kpretrained_timm_style.pth |
ResNet200 | 224 | 83.1 | 65M | 15.1G | resnet200_1kpretrained_timm_style.pth |
ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | convnextS_1kpretrained_official_style.pth |
ConvNeXt-B | 224 | 84.8 | 89M | 15.4G | convnextB_1kpretrained_official_style.pth |
ConvNeXt-L | 224 | 85.4 | 198M | 34.4G | convnextL_1kpretrained_official_style.pth |
ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | convnextL_384_1kpretrained_official_style.pth |
arch. | reso. | acc@1 | #params | flops | weights (self-supervised, with SparK's decoder) |
---|---|---|---|---|---|
ResNet50 | 224 | 80.6 | 26M | 4.1G | res50_withdecoder_1kpretrained_spark_style.pth |
ResNet101 | 224 | 82.2 | 45M | 7.9G | res101_withdecoder_1kpretrained_spark_style.pth |
ResNet152 | 224 | 82.7 | 60M | 11.6G | res152_withdecoder_1kpretrained_spark_style.pth |
ResNet200 | 224 | 83.1 | 65M | 15.1G | res200_withdecoder_1kpretrained_spark_style.pth |
ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | cnxS224_withdecoder_1kpretrained_spark_style.pth |
ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | cnxL384_withdecoder_1kpretrained_spark_style.pth |
We highly recommend using torch==1.10.0, torchvision==0.11.1, and timm==0.5.4 for reproduction.
Check INSTALL.md to install all pip dependencies.
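To compare an environment against the recommended versions, a small check like the following may help. It is only a sketch: `RECOMMENDED` and `check_versions` are hypothetical names that are not part of this repo, and it relies solely on the standard-library `importlib.metadata` (Python 3.8+).

```python
from importlib import metadata

# Versions recommended above for reproduction.
RECOMMENDED = {"torch": "1.10.0", "torchvision": "0.11.1", "timm": "0.5.4"}


def check_versions(recommended=RECOMMENDED):
    """Map each package name to (installed_version_or_None, wanted_version, matches)."""
    report = {}
    for name, wanted in recommended.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local build suffixes such as "+cu113" when comparing.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[name] = (installed, wanted, matches)
    return report


if __name__ == "__main__":
    for name, (installed, wanted, ok) in check_versions().items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: installed={installed}, recommended={wanted} [{status}]")
```

A mismatch here does not necessarily mean the code will fail; it only flags a difference from the versions listed above.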
# Download our weights `resnet50_1kpretrained_timm_style.pth` first.
import torch, timm
res50 = timm.create_model('resnet50')
state = torch.load('resnet50_1kpretrained_timm_style.pth', map_location='cpu')
# Just in case the model weights are actually saved in state['module'].
res50.load_state_dict(state.get('module', state), strict=False)
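Because `strict=False` silently skips keys that do not match, it can be worth checking which keys were actually loaded. The sketch below shows the idea on toy data in pure Python; `unwrap_state` and `diff_keys` are hypothetical helpers, and the key names are made up for illustration rather than taken from the released checkpoints.

```python
def unwrap_state(state):
    """Return the weights dict, whether or not it is nested under 'module'."""
    return state.get("module", state)


def diff_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected) key lists, sorted for stable output."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)      # expected by the model, absent from the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)   # present in the checkpoint, unused by the model
    return missing, unexpected


# Toy stand-ins for real state dicts (values would normally be tensors).
model_keys = ["conv1.weight", "fc.weight", "fc.bias"]
ckpt = {"module": {"conv1.weight": 0, "extra.weight": 0}}

missing, unexpected = diff_keys(model_keys, unwrap_state(ckpt))
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, `model.state_dict().keys()` can serve as `model_keys`. PyTorch's `load_state_dict` also returns an object with `missing_keys` and `unexpected_keys` fields, which gives the same information directly.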
Pretraining
Finetuning
We referred to these useful codebases:
This project is under the MIT license. See LICENSE for more details.
If you find this project useful, please consider giving us a star ⭐ or citing us in your work 📖:
@Article{tian2023designing,
author = {Keyu Tian and Yi Jiang and Qishuai Diao and Chen Lin and Liwei Wang and Zehuan Yuan},
title = {Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling},
journal = {arXiv:2301.03580},
year = {2023},
}