A collection of SOTA Image Classification Models in PyTorch
Intended for easy use and integration of SOTA image classification models into downstream tasks, and for fine-tuning on custom datasets.
Model | ImageNet-1k Top-1 Acc (%) | Params (M) | GFLOPs | Variants & Weights |
---|---|---|---|---|
MicroNet | 51.4\|59.4\|62.5 | 2\|2\|3 | 7M\|14M\|23M | M1\|M2\|M3 |
ResNet* | 71.5\|80.4\|81.5 | 12\|26\|45 | 2\|4\|8 | 18\|50\|101 |
PoolFormer | 80.3\|81.4\|82.1 | 21\|31\|56 | 4\|5\|9 | S24\|S36\|M36 |
WaveMLP | 80.9\|82.9\|83.3 | 17\|30\|44 | 2\|5\|8 | T\|S\|M |
PVTv2 | 78.7\|82.0\|83.6 | 14\|25\|63 | 2\|4\|10 | B1\|B2\|B4 |
ResT | 79.6\|81.6\|83.6 | 14\|30\|52 | 2\|4\|8 | S\|B\|L |
UniFormer | -\|82.9\|83.8 | -\|22\|50 | -\|4\|8 | -\|S\|B |
VAN | 75.4\|81.1\|82.8\|83.9 | 4\|14\|27\|45 | 1\|3\|5\|9 | T\|S\|B\|L |
ResTv2 | 82.3\|83.2\|83.7\|84.2 | 30\|41\|56\|87 | 4\|6\|8\|14 | T\|S\|B\|L |
FAN | 80.1\|83.5\|83.9\|84.3 | 7\|26\|50\|77 | 4\|7\|11\|17 | T\|S\|B\|L |
PatchConvnet | 82.1\|83.2\|83.5 | 25\|48\|99 | 4\|8\|16 | S60\|S120\|B60 |
ConvNeXt | 82.1\|83.1\|83.8 | 28\|50\|89 | 5\|9\|15 | T\|S\|B |
FocalNet | 82.3\|83.5\|83.9 | 29\|50\|89 | 5\|9\|15 | T\|S\|B |
CSWin | 82.7\|83.6\|84.2 | 23\|35\|78 | 4\|7\|15 | T\|S\|B |
NAT | 81.8\|83.2\|83.7\|84.3 | 20\|28\|51\|90 | 3\|4\|8\|14 | M\|T\|S\|B |
DaViT | 82.8\|84.2\|84.6 | 28\|50\|88 | 5\|9\|16 | T\|S\|B |
Note: ResNet* is from the "ResNet strikes back" paper.
Other requirements can be installed with `pip install -r requirements.txt`.

```bash
$ python list_models.py
```
A table with model names and variants will be shown:

```
Supported Models

Model Names  │ Model Variants
╶────────────┼──────────────────────────────────╴
ResNet       │ ['18', '34', '50', '101', '152']
MicroNet     │ ['M1', 'M2', 'M3']
ConvNeXt     │ ['T', 'S', 'B']
VAN          │ ['S', 'B', 'L']
PVTv2        │ ['B1', 'B2', 'B3', 'B4', 'B5']
ResT         │ ['S', 'B', 'L']
CSWin        │ ['T', 'S', 'B', 'L']
WaveMLP      │ ['T', 'S', 'M']
PoolFormer   │ ['S24', 'S36', 'M36']
PatchConvnet │ ['S60', 'S120', 'B60']
UniFormer    │ ['S', 'B']
FocalNet     │ ['T', 'S', 'B']
```
```bash
# Example with VAN-S
$ python infer.py --source assests/dog.jpg --model VAN --variant S --checkpoint /path/to/van_s
```
You will see an output similar to this:

```
assests\dog.jpg >>>>> Golden retriever
```
Note: The above example only works for ImageNet pre-trained models. Modify the checkpoint loading and the class names in `infer.py` for your custom needs.
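Turning a model's raw logits into a label like the one above is just softmax + argmax. A minimal pure-Python sketch of that post-processing step (the class names and logits here are made-up stand-ins, not the repo's actual ImageNet labels):

```python
import math

# Stand-in class names and logits; the real scripts get logits
# from the forward pass of the chosen model.
class_names = ["golden retriever", "tabby cat", "goldfish"]
logits = [4.1, 1.2, 0.3]

# Numerically stable softmax, then argmax to pick the label.
exps = [math.exp(x - max(logits)) for x in logits]
probs = [e / sum(exps) for e in exps]
pred = class_names[probs.index(max(probs))]
print(pred)  # golden retriever
```

Swapping `class_names` for your own label list is the "class names" change the note above refers to.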
You can use any dataset from `torchvision.datasets`. For custom datasets, `ImageFolder` can be used to create a dataset class.
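`ImageFolder` derives class labels from subdirectory names: one folder per class, with indices assigned alphabetically. A small sketch of that layout, with the actual torchvision call left as a comment since it needs real image files to run:

```python
import os
import tempfile

# ImageFolder expects one subdirectory per class:
#   root/cat/xxx.png
#   root/dog/yyy.png
root = tempfile.mkdtemp()
for cls in ["dog", "cat"]:
    os.makedirs(os.path.join(root, cls))

# Class indices are assigned to the sorted folder names,
# mirroring ImageFolder's class discovery.
classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
class_to_idx = {c: i for i, c in enumerate(classes)}
print(class_to_idx)  # {'cat': 0, 'dog': 1}

# With torchvision installed and images in place:
# from torchvision import datasets, transforms
# dataset = datasets.ImageFolder(root, transform=transforms.ToTensor())
```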
In this repo, finetuning on CIFAR-10 is provided in `finetune.py`.
```bash
$ python finetune.py --cfg configs/finetune.yaml
```
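A hypothetical `configs/finetune.yaml` might look like the following. The field names are illustrative guesses, not the repo's actual schema, so check `finetune.py` for the keys it really reads:

```yaml
MODEL:
  NAME: VAN                    # model name from list_models.py
  VARIANT: S                   # model variant
  PRETRAINED: /path/to/van_s   # ImageNet checkpoint to start from

DATASET:
  NAME: CIFAR10
  ROOT: ./data

TRAIN:
  EPOCHS: 10
  BATCH_SIZE: 64
  LR: 0.001
```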
Install the respective libraries for your desired framework:
```bash
# ONNX
$ pip install onnx onnx-simplifier onnxruntime

# CoreML
$ pip install coremltools

# OpenVINO
$ pip install onnx onnx-simplifier openvino-dev

# TFLite (Coming Soon)
$ pip install onnx onnx-simplifier openvino-dev openvino2tensorflow tflite-runtime
```
Convert:

```bash
# ONNX
$ python convert/to_onnx.py --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE

# CoreML
$ python convert/to_coreml.py --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE

# OpenVINO
$ python convert/to_openvino.py --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE --precision FP32  # or FP16
```
Inference:

```bash
# PyTorch
$ python convert/infer_pt.py --source IMG_FILE_PATH --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE --device cuda  # or cpu

# ONNX
$ python convert/infer_onnx.py --source IMG_FILE_PATH --model MODEL_PATH

# OpenVINO
$ python convert/infer_openvino.py --source IMG_FILE_PATH --model MODEL_PATH --device CPU  # or GPU
```
CPU:
Model | PyTorch | ONNX | OpenVINO | TFLite |
---|---|---|---|---|
VAN-S | 46 | 28 | - | - |
GPU:
Model | PyTorch (FP32) | TensorRT (FP32) |
---|---|---|
VAN-S | 6 | - |
Latency in milliseconds. Tested with a Ryzen 7 4800HS CPU and a GTX 1650 Ti GPU.
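Numbers like these are typically collected with a wall-clock loop: a few warm-up runs, then the median over repeated calls. A generic sketch of such a latency benchmark (the exact methodology used for the tables above is an assumption):

```python
import statistics
import time

def benchmark(fn, warmup=2, iters=10):
    """Median wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up: caches, JIT, lazy init
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        times.append((time.perf_counter() - t0) * 1000)
    return statistics.median(times)

# Example with a cheap stand-in workload; for the tables above,
# fn would wrap a forward pass of the model under test.
lat = benchmark(lambda: sum(range(10_000)))
print(f"{lat:.3f} ms")
```

Using the median rather than the mean keeps one-off stalls (page faults, frequency scaling) from skewing the result.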
Most of the code is borrowed from timm and DeiT. I would like to thank the papers' authors for open-sourcing their code and providing pre-trained models.