PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNet-V3/V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
❗Updates after Oct 10, 2022 are available in version >= 0.9❗
timm.models.layers moved to timm.layers:
from timm.models.layers import name will still work via deprecation mapping (but please transition to timm.layers).
import timm.models.layers.module or from timm.models.layers.module import name needs to be changed now.
Helper and other non-model modules in timm.models have a _ prefix added, i.e. timm.models.helpers -> timm.models._helpers. There are temporary deprecation mapping files but those will be removed.
Models use architecture.pretrained_tag naming (ex resnet50.rsb_a1).
Using only architecture defaults to the first weights in the default_cfgs for that model architecture.
Some model names were renamed to use the tag (ex vit_base_patch16_224_in21k -> vit_base_patch16_224.augreg_in21k). There are deprecation mappings for these.
Some checkpoints were remapped to better support features_only=True; there are checkpoint_filter_fn methods in any model module that was remapped. These can be passed to timm.models.load_checkpoint(..., filter_fn=timm.models.swin_transformer_v2.checkpoint_filter_fn) to remap your existing checkpoint.
The Hugging Face Hub (https://huggingface.co/timm) hosts timm weights. Model cards include links to papers, original source, license.
features_only=True support for ViT models with flat hidden states or non-std module layouts (so far covering 'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*').
forward_intermediates() API that can be used with a feature wrapping module or directly.
model = timm.create_model('vit_base_patch16_224')
final_feat, intermediates = model.forward_intermediates(input)
output = model.forward_head(final_feat) # pooling + classifier head
print(final_feat.shape)
torch.Size([2, 197, 768])
for f in intermediates:
    print(f.shape)
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
torch.Size([2, 768, 14, 14])
print(output.shape)
torch.Size([2, 1000])
model = timm.create_model('eva02_base_patch16_clip_224', pretrained=True, img_size=512, features_only=True, out_indices=(-3, -2,))
output = model(torch.randn(2, 3, 512, 512))
for o in output:
    print(o.shape)
torch.Size([2, 768, 32, 32])
torch.Size([2, 768, 32, 32])
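The shapes printed in the two examples above follow from simple patch arithmetic. The sketch below is only an illustration: the helper name `vit_shapes` is hypothetical and not part of timm, and it assumes a square input, non-overlapping patches, and a single class token (as in `vit_base_patch16_224`).

```python
def vit_shapes(img_size: int, patch_size: int, num_prefix_tokens: int = 1):
    """Return (grid_side, num_tokens) for a square image split into patches.

    Hypothetical helper for illustration only. Assumes img_size is divisible
    by patch_size and that prefix tokens (e.g. a class token) are prepended.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid = img_size // patch_size
    return grid, grid * grid + num_prefix_tokens


grid_224, tokens_224 = vit_shapes(224, 16)
grid_512, _ = vit_shapes(512, 16)
print(grid_224, tokens_224)  # 14 197
print(grid_512)              # 32
```

A 224 input with patch size 16 gives a 14 x 14 grid and 197 tokens, matching `[2, 197, 768]` and `[2, 768, 14, 14]`; a 512 input gives the 32 x 32 maps of the second example.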
Datasets & transform refactoring
--dataset hfids:org/dataset
datasets and webdataset wrapper streaming from HF hub with recent timm ImageNet uploads to https://huggingface.co/timm
Monochrome (single-channel) input via --input-size 1 224 224 or --in-chans 1, sets PIL image conversion appropriately in dataset
Training without a validation set (--val-split '') in train script
--bce-sum (sum over class dim) and --bce-pos-weight (positive weighting) args for training as they're common BCE loss tweaks I was often hard coding
model_args config entry. model_args will be passed as kwargs through to models on creation.
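To make the two BCE tweaks behind `--bce-sum` and `--bce-pos-weight` concrete, here is a minimal pure-Python sketch for a single sample. It is not timm's implementation: the function names are hypothetical, and it assumes the positive weight scales only the positive-target term and that "sum" means adding the per-class losses instead of averaging them.

```python
import math


def _softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logits, targets, pos_weight=1.0, sum_classes=False):
    """Binary cross-entropy over one sample's per-class logits.

    Hypothetical helper: pos_weight scales the positive-target term and
    sum_classes switches from a mean to a sum over the class dimension.
    """
    losses = []
    for z, t in zip(logits, targets):
        log_p = -_softplus(-z)      # log(sigmoid(z))
        log_not_p = -_softplus(z)   # log(1 - sigmoid(z))
        losses.append(-(pos_weight * t * log_p + (1.0 - t) * log_not_p))
    total = sum(losses)
    return total if sum_classes else total / len(losses)


logits, targets = [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]
mean_loss = bce_with_logits(logits, targets)
sum_loss = bce_with_logits(logits, targets, sum_classes=True)
weighted = bce_with_logits(logits, targets, pos_weight=2.0, sum_classes=True)
print(mean_loss, sum_loss, weighted)
```

With all logits at zero every class contributes log(2), so the summed loss is three times the mean, and doubling the positive weight adds one more log(2) for the single positive class.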
vision_transformer.py typing and doc cleanup by Laureηt
quickgelu ViT variants for OpenAI, DFN, MetaCLIP weights that use it (less efficient)
convnext_xxlarge
vision_transformer.py.
vision_transformer.py, vision_transformer_hybrid.py, deit.py, and eva.py w/o breaking backward compat.
Add dynamic_img_size=True to args at model creation time to allow changing the grid size (interpolate abs and/or ROPE pos embed each forward pass).
Add dynamic_img_pad=True to allow image sizes that aren't divisible by patch size (pad bottom right to patch size each forward pass).
Changing img_size (interpolate pretrained embed weights once) on creation still works.
Changing patch_size (resize pretrained patch_embed weights once) on creation still works.
python validate.py /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dynamic_img_pad=True
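The command above validates a patch-16 model at `--img-size 255`, which is not a multiple of the patch size. The following sketch mirrors only the arithmetic of bottom/right padding described for `dynamic_img_pad=True`; the helper `pad_to_patch` is hypothetical, and the actual padding is done inside the model, not by this function.

```python
def pad_to_patch(height: int, width: int, patch_size: int):
    """Return (pad_bottom, pad_right, grid_h, grid_w).

    Hypothetical helper: computes how much bottom/right padding is needed
    so that both sides become multiples of patch_size, and the resulting
    patch grid.
    """
    pad_h = (patch_size - height % patch_size) % patch_size
    pad_w = (patch_size - width % patch_size) % patch_size
    grid_h = (height + pad_h) // patch_size
    grid_w = (width + pad_w) // patch_size
    return pad_h, pad_w, grid_h, grid_w


print(pad_to_patch(255, 255, 16))  # (1, 1, 16, 16)
print(pad_to_patch(224, 224, 16))  # (0, 0, 14, 14)
```

A 255 x 255 input needs one pixel of padding on each padded side and yields a 16 x 16 grid, so the position embedding has to be interpolated from the pretrained 14 x 14 grid, which is what `dynamic_img_size=True` is described as handling.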
--reparam arg to benchmark.py, onnx_export.py, and validate.py to trigger layer reparameterization / fusion for models with any one of reparameterize(), switch_to_deploy() or fuse()
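The idea of triggering whichever of the three methods a layer exposes can be sketched with plain Python objects. This is an assumption-laden illustration, not the code the scripts run: `reparam_module`, `FakeRepBlock`, and `PlainBlock` are hypothetical names, and real modules may return a new fused module instead of mutating in place.

```python
REPARAM_METHODS = ("reparameterize", "switch_to_deploy", "fuse")


def reparam_module(module):
    """Call the first reparameterization-style method the object exposes.

    Returns the name of the method that was called, or None if the object
    has none of them. Hypothetical helper for illustration only.
    """
    for name in REPARAM_METHODS:
        method = getattr(module, name, None)
        if callable(method):
            method()
            return name
    return None


class FakeRepBlock:
    """Stand-in for a layer with a train-time and a deploy-time form."""

    def __init__(self):
        self.deployed = False

    def switch_to_deploy(self):
        self.deployed = True


class PlainBlock:
    """Stand-in for a layer with nothing to fuse."""


block = FakeRepBlock()
print(reparam_module(block), block.deployed)  # switch_to_deploy True
print(reparam_module(PlainBlock()))           # None
```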
python validate.py /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320
selecsls* model naming regression
seresnextaa201d_32x8d.sw_in12k_ft_in1k_384 weights (and .sw_in12k pretrain) with 87.3% top-1 on ImageNet-1k, best ImageNet ResNet family model I'm aware of.
timm 0.9 released, transition from 0.8.xdev releases
get_intermediate_layers function on vit/deit models for grabbing hidden states (inspired by DINO impl). This is WIP and may change significantly... feedback welcome.
Model creation raises an error if pretrained=True and no weights exist (instead of continuing with random initialization)
use bnb prefix, ie bnbadam8bit
timm out of pre-release state
timm models uploaded to HF Hub and almost all updated to support multi-weight pretrained configs
Gradient accumulation (--grad-accum-steps), thanks Taeksang Kim
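As background for `--grad-accum-steps`, the general technique of gradient accumulation can be shown without any framework. The sketch below is not how train.py implements it; the function names are hypothetical, and it assumes equally sized micro-batches and a mean-reduced loss, in which case averaging the micro-batch gradients reproduces the full-batch gradient.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps):
    """Accumulate gradients over equally sized micro-batches.

    Each micro-batch gradient is scaled by 1 / accum_steps before being
    added, so the total matches the gradient of the full batch.
    """
    n = len(xs)
    if n % accum_steps != 0:
        raise ValueError("sketch assumes equal micro-batch sizes")
    size = n // accum_steps
    total = 0.0
    for i in range(accum_steps):
        lo, hi = i * size, (i + 1) * size
        total += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accum_steps
    return total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full = grad_mse(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, accum_steps=2)
print(full, accum)  # -22.5 -22.5
```

The optimizer step is then taken once per `accum_steps` micro-batches, which emulates a larger batch when memory is limited.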
--head-init-scale and --head-init-bias to train.py to scale classifier head and set fixed bias for fine-tune
Removed inplace_abn use, replaced use in tresnet with standard BatchNorm (modified weights accordingly).
drop_rate (classifier dropout), proj_drop_rate (block mlp / out projections), pos_drop_rate (position embedding drop), attn_drop_rate (attention dropout). Also add patch dropout (FLIP) to vit and eva models.
timm trained weights added with recipe based tags to differentiate
resnetaa50d.sw_in12k_ft_in1k - 81.7 @ 224, 82.6 @ 288
resnetaa101d.sw_in12k_ft_in1k - 83.5 @ 224, 84.1 @ 288
seresnextaa101d_32x8d.sw_in12k_ft_in1k - 86.0 @ 224, 86.5 @ 288
seresnextaa101d_32x8d.sw_in12k_ft_in1k_288 - 86.5 @ 288, 86.7 @ 320
model | top1 | top5 | img_size | param_count | gmacs | macts |
---|---|---|---|---|---|---|
convnext_xxlarge.clip_laion2b_soup_ft_in1k | 88.612 | 98.704 | 256 | 846.47 | 198.09 | 124.45 |
convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384 | 88.312 | 98.578 | 384 | 200.13 | 101.11 | 126.74 |
convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320 | 87.968 | 98.47 | 320 | 200.13 | 70.21 | 88.02 |
convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384 | 87.138 | 98.212 | 384 | 88.59 | 45.21 | 84.49 |
convnext_base.clip_laion2b_augreg_ft_in12k_in1k | 86.344 | 97.97 | 256 | 88.59 | 20.09 | 37.55 |
model | top1 | top5 | param_count | img_size |
---|---|---|---|---|
eva02_large_patch14_448.mim_m38m_ft_in22k_in1k | 90.054 | 99.042 | 305.08 | 448 |
eva02_large_patch14_448.mim_in22k_ft_in22k_in1k | 89.946 | 99.01 | 305.08 | 448 |
eva_giant_patch14_560.m30m_ft_in22k_in1k | 89.792 | 98.992 | 1014.45 | 560 |
eva02_large_patch14_448.mim_in22k_ft_in1k | 89.626 | 98.954 | 305.08 | 448 |
eva02_large_patch14_448.mim_m38m_ft_in1k | 89.57 | 98.918 | 305.08 | 448 |
eva_giant_patch14_336.m30m_ft_in22k_in1k | 89.56 | 98.956 | 1013.01 | 336 |
eva_giant_patch14_336.clip_ft_in1k | 89.466 | 98.82 | 1013.01 | 336 |
eva_large_patch14_336.in22k_ft_in22k_in1k | 89.214 | 98.854 | 304.53 | 336 |
eva_giant_patch14_224.clip_ft_in1k | 88.882 | 98.678 | 1012.56 | 224 |
eva02_base_patch14_448.mim_in22k_ft_in22k_in1k | 88.692 | 98.722 | 87.12 | 448 |
eva_large_patch14_336.in22k_ft_in1k | 88.652 | 98.722 | 304.53 | 336 |
eva_large_patch14_196.in22k_ft_in22k_in1k | 88.592 | 98.656 | 304.14 | 196 |
eva02_base_patch14_448.mim_in22k_ft_in1k | 88.23 | 98.564 | 87.12 | 448 |
eva_large_patch14_196.in22k_ft_in1k | 87.934 | 98.504 | 304.14 | 196 |
eva02_small_patch14_336.mim_in22k_ft_in1k | 85.74 | 97.614 | 22.13 | 336 |
eva02_tiny_patch14_336.mim_in22k_ft_in1k | 80.658 | 95.524 | 5.76 | 336 |
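The model names in the tables above follow the architecture.pretrained_tag convention (ex resnet50.rsb_a1). As a small sketch of how such a name decomposes, the hypothetical helper below splits on the first dot; it is not timm's own parser and does not resolve deprecated names.

```python
def split_name_tag(model_name: str):
    """Split 'architecture.pretrained_tag' into (architecture, tag).

    Hypothetical helper: the tag is an empty string when no tag is given.
    """
    architecture, _, tag = model_name.partition(".")
    return architecture, tag


print(split_name_tag("eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"))
# ('eva02_large_patch14_448', 'mim_m38m_ft_in22k_in1k')
print(split_name_tag("resnet50"))
# ('resnet50', '')
```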
regnet.py, rexnet.py, byobnet.py, resnetv2.py, swin_transformer.py, swin_transformer_v2.py, swin_transformer_v2_cr.py
(NCHW feat maps for swinv2_cr_*, and NHWC for all others) and spatial embedding outputs.
timm weights:
rexnetr_200.sw_in12k_ft_in1k - 82.6 @ 224, 83.2 @ 288
rexnetr_300.sw_in12k_ft_in1k - 84.0 @ 224, 84.5 @ 288
regnety_120.sw_in12k_ft_in1k - 85.0 @ 224, 85.4 @ 288
regnety_160.lion_in12k_ft_in1k - 85.6 @ 224, 86.0 @ 288
regnety_160.sw_in12k_ft_in1k - 85.6 @ 224, 86.0 @ 288 (compare to SWAG PT + 1k FT this is same BUT much lower res, blows SEER FT away)
convnext_xxlarge default LayerNorm eps to 1e-5 (for CLIP weights, improved stability)
convnext_large_mlp.clip_laion2b_ft_320 and convnext_large_mlp.clip_laion2b_ft_soup_320 CLIP image tower weights for features & fine-tune
safetensor checkpoint support added
vit_*, vit_relpos*, coatnet / maxxvit (to start)
features_only=True
PyTorch Image Models (timm) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with the ability to reproduce ImageNet training results.
The work of many others is present here. I've tried to make sure all source material is acknowledged via links to github, arxiv papers, etc in the README, documentation, and code docstrings. Please let me know if I missed anything.
All model architecture families include variants with pretrained weights. Some specific model variants have no weights; this is NOT a bug. Help training new or better weights is always appreciated.
Included optimizers available via create_optimizer / create_optimizer_v2 factory methods:
adabelief: an implementation of AdaBelief adapted from https://github.com/juntang-zhuang/Adabelief-Optimizer - https://arxiv.org/abs/2010.07468
adafactor: adapted from FAIRSeq impl - https://arxiv.org/abs/1804.04235
adahessian: by David Samuel - https://arxiv.org/abs/2006.00719
adamp and sgdp: by Naver ClovAI - https://arxiv.org/abs/2006.08217
adan: an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677
lamb: an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962
lars: an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888
lion: an implementation of Lion adapted from https://github.com/google/automl/tree/master/lion - https://arxiv.org/abs/2302.06675
lookahead: adapted from impl by Liam - https://arxiv.org/abs/1907.08610
madgrad: an implementation of MADGRAD adapted from https://github.com/facebookresearch/madgrad - https://arxiv.org/abs/2101.11075
nadam: an implementation of Adam w/ Nesterov momentum
nadamw: an implementation of AdamW (Adam w/ decoupled weight-decay) w/ Nesterov momentum. A simplified impl based on https://github.com/mlcommons/algorithmic-efficiency
novograd: by Masashi Kimura - https://arxiv.org/abs/1905.11286
radam: by Liyuan Liu - https://arxiv.org/abs/1908.03265
rmsprop_tf: adapted from PyTorch RMSProp by myself. Reproduces much improved Tensorflow RMSProp behaviour
sgdw: an implementation of SGD w/ decoupled weight-decay
fused<name>: optimizers by name with NVIDIA Apex installed
bits<name>: optimizers by name with BitsAndBytes installed
Several (less common) features that I often utilize in my projects are included. Many of their additions are the reason why I maintain my own set of models, instead of using others' via PIP:
Accessing/changing the classifier: get_classifier and reset_classifier
Forward pass on just the features: forward_features (see documentation)
create_model(name, features_only=True, out_indices=..., output_stride=...)
out_indices creation arg specifies which feature maps to return; these indices are 0 based and generally correspond to the C(i + 1) feature level.
output_stride creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
Feature map channel counts and reduction level (stride) can be queried after model creation via the .feature_info member
Schedulers include step, cosine w/ restarts, tanh w/ restarts, plateau
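For readers unfamiliar with "cosine w/ restarts", the shape of such a schedule can be sketched in a few lines. This is a simplified illustration, not timm's scheduler: the function name is hypothetical, the cycle length is fixed, and options such as warmup or per-cycle decay are not modelled.

```python
import math


def cosine_restart_lr(step: int, base_lr: float, cycle_steps: int, min_lr: float = 0.0) -> float:
    """Learning rate at `step` for cosine decay that restarts every `cycle_steps`.

    Hypothetical helper: within each cycle the rate decays from base_lr
    towards min_lr along half a cosine wave, then jumps back to base_lr.
    """
    pos = step % cycle_steps
    cos = 0.5 * (1.0 + math.cos(math.pi * pos / cycle_steps))
    return min_lr + (base_lr - min_lr) * cos


for step in (0, 5, 9, 10):
    print(step, cosine_restart_lr(step, base_lr=0.1, cycle_steps=10))
```

The rate starts at `base_lr`, is roughly halved mid-cycle, approaches `min_lr` at the end of the cycle, and returns to `base_lr` at step 10 when the next cycle begins.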
Model validation results can be found in the results tables.
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide by Chris Hughes is an extensive blog post covering many aspects of timm in detail.
timmdocs is an alternate set of documentation for timm. A big thanks to Aman Arora for his efforts creating timmdocs.
paperswithcode is a good resource for browsing the models within timm.
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See documentation.
One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below.
The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.
So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
Several weights included or referenced here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.
@misc{rw2019timm,
author = {Ross Wightman},
title = {PyTorch Image Models},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
doi = {10.5281/zenodo.4414861},
howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}