sithu31296/sota-backbones

SOTA Image Classification Models in PyTorch

Intended to make SOTA image classification models easy to use and integrate into downstream tasks, and to fine-tune on custom datasets.

Features

  • Applicable to the following tasks:
    • Fine-tuning on custom classification datasets.
    • Use as a backbone in downstream tasks such as object detection, semantic segmentation, and pose estimation.
  • Almost no dependencies for model usage.
  • 10+ high-precision, high-efficiency SOTA models.
  • Regularly updated with new models.
  • PyTorch, ONNX, CoreML, TFLite, OpenVINO Inference and Export.
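
To illustrate what "used as a backbone" usually means in practice, here is a minimal sketch of a classifier-style network that returns the feature map of every stage instead of class logits. `ToyBackbone` and its channel sizes are hypothetical stand-ins and not a model or API from this repo; a real backbone would be one of the supported models with its classification head removed.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical 4-stage backbone; stands in for any hierarchical model."""

    def __init__(self, channels=(16, 32, 64, 128)):
        super().__init__()
        stages, in_ch = [], 3
        for i, out_ch in enumerate(channels):
            stride = 4 if i == 0 else 2  # cumulative strides: 4, 8, 16, 32
            stages.append(nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))
            in_ch = out_ch
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)  # keep every stage output for the downstream head
        return features


model = ToyBackbone().eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
```

A detection or segmentation head (for example, a feature pyramid) would then consume this list of multi-scale feature maps.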

Supported Models

2021

2022

  • CSWin (CVPR 2022) (Microsoft)
  • PVTv2 (CVMJ 2022) (whai362)
  • UniFormer (ICLR 2022) (SenseTime X-Lab)
  • PoolFormer (CVPR 2022) (Sea AI Lab)
  • ConvNeXt (CVPR 2022) (Meta Research)
  • VAN (ArXiv 2022) (BNRist)
  • FocalNet (ArXiv 2022) (Microsoft)
  • WaveMLP (CVPR 2022) (HUAWEI Noah's Ark Lab)
  • DaViT (ArXiv 2022) (dingmyu)
  • NAT (ArXiv 2022) (SHI Lab)
  • FAN (ArXiv 2022) (NVlabs)
  • ResTv2 (ArXiv 2022) (wofmanaf)

Losses

Benchmarks

Model         ImageNet-1k Top-1 Acc (%)  Params (M)   GFLOPs      Variants & Weights
MicroNet      51.4|59.4|62.5             2|2|3        7M|14M|23M  M1|M2|M3
ResNet*       71.5|80.4|81.5             12|26|45     2|4|8       18|50|101
PoolFormer    80.3|81.4|82.1             21|31|56     4|5|9       S24|S36|M36
WaveMLP       80.9|82.9|83.3             17|30|44     2|5|8       T|S|M
PVTv2         78.7|82.0|83.6             14|25|63     2|4|10      B1|B2|B4
ResT          79.6|81.6|83.6             14|30|52     2|4|8       S|B|L
UniFormer     NA|82.9|83.8               -|22|50      -|4|8       -|S|B
VAN           75.4|81.1|82.8|83.9        4|14|27|45   1|3|5|9     T|S|B|L
ResTv2        82.3|83.2|83.7|84.2        30|41|56|87  4|6|8|14    T|S|B|L
FAN           80.1|83.5|83.9|84.3        7|26|50|77   4|7|11|17   T|S|B|L
PatchConvnet  82.1|83.2|83.5             25|48|99     4|8|16      S60|S120|B60
ConvNeXt      82.1|83.1|83.8             28|50|89     5|9|15      T|S|B
FocalNet      82.3|83.5|83.9             29|50|89     5|9|15      T|S|B
CSWin         82.7|83.6|84.2             23|35|78     4|7|15      T|S|B
NAT           81.8|83.2|83.7|84.3        20|28|51|90  3|4|8|14    M|T|S|B
DaViT         82.8|84.2|84.6             28|50|88     5|9|16      T|S|B

Note: ResNet* is from the "ResNet strikes back" paper.

Table Notes
  • Only models trained on ImageNet-1k at an image size of 224x224, without additional tricks such as token labeling or self-supervised learning, are included.
  • Model weights are from the respective official repositories.
  • Large models (more than 100M parameters) are not included.

Usage

Requirements

  • torch >= 1.11
  • torchvision >= 0.12

Other requirements can be installed with pip install -r requirements.txt.

Show Supported Models

$ python list_models.py

A table with model names and variants will be shown:

                Supported Models
               
  Model Names  │ Model Variants
╶──────────────┼──────────────────────────────────╴
  ResNet       │ ['18', '34', '50', '101', '152']
  MicroNet     │ ['M1', 'M2', 'M3']
  ConvNeXt     │ ['T', 'S', 'B']
  VAN          │ ['S', 'B', 'L']
  PVTv2        │ ['B1', 'B2', 'B3', 'B4', 'B5']
  ResT         │ ['S', 'B', 'L']
  CSWin        │ ['T', 'S', 'B', 'L']
  WaveMLP      │ ['T', 'S', 'M']
  PoolFormer   │ ['S24', 'S36', 'M36']
  PatchConvnet │ ['S60', 'S120', 'B60']
  UniFormer    │ ['S', 'B']
  FocalNet     │ ['T', 'S', 'B']

Inference

# Example with VAN-S
$ python infer.py --source assests/dog.jpg --model VAN --variant S --checkpoint /path/to/van_s

You will see an output similar to this:

assests\dog.jpg >>>>> Golden retriever

Note: The above code is only for ImageNet pre-trained models. Modify the model's checkpoint loading and class names in infer.py for your custom needs.
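
One common way to adapt checkpoint loading to a custom number of classes is to drop the classifier weights and load the rest non-strictly. The sketch below is only an illustration of that idea, not the code in infer.py: `TinyClassifier`, `load_pretrained_except_head`, and the `head.` key prefix are assumptions made for this example, so check the actual key names in the checkpoint you use.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a backbone with a linear classification head."""

    def __init__(self, num_classes):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.stem(x).mean(dim=(2, 3))  # global average pooling
        return self.head(x)


def load_pretrained_except_head(model, state_dict, head_prefix="head."):
    """Load pretrained weights but skip the classifier so a new head can be trained."""
    # Assumption: classifier parameters are stored under keys starting with head_prefix.
    filtered = {k: v for k, v in state_dict.items() if not k.startswith(head_prefix)}
    # strict=False tolerates the head parameters that are now missing.
    result = model.load_state_dict(filtered, strict=False)
    return result.missing_keys, result.unexpected_keys


pretrained = TinyClassifier(num_classes=1000)  # plays the role of ImageNet weights
custom = TinyClassifier(num_classes=10)        # model for a 10-class dataset
missing, unexpected = load_pretrained_except_head(custom, pretrained.state_dict())
print("left at random init:", missing)
```

The keys reported as missing are the ones left at their random initialization, which is a quick way to confirm that only the head was skipped.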

Finetune

You can use any dataset from torchvision.datasets. For custom datasets, ImageFolder can be used to create a dataset class.

In this repo, finetuning on CIFAR-10 is provided in finetune.py.

!! What is not available yet:

  • Distributed training
  • Mixup and Cutmix augmentation

$ python finetune.py --cfg configs/finetune.yaml

Convert to other Frameworks

Install the libraries required for your desired framework:

# ONNX
$ pip install onnx onnx-simplifier onnxruntime
# CoreML
$ pip install coremltools
# OpenVINO
$ pip install onnx onnx-simplifier openvino-dev 
# TFLite (Coming Soon)
$ pip install onnx onnx-simplifier openvino-dev openvino2tensorflow tflite-runtime

Convert:

# ONNX
$ python convert/to_onnx.py --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE

# CoreML
$ python convert/to_coreml.py --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE

# OpenVINO
$ python convert/to_openvino.py --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE --precision FP32  # or FP16

Inference:

# PyTorch
$ python convert/infer_pt.py --source IMG_FILE_PATH --model MODEL_NAME --variant MODEL_VARIANT --num_classes NUM_CLASSES --checkpoint /path/to/weights --size IMAGE_SIZE --device cuda  # or cpu

# ONNX
$ python convert/infer_onnx.py --source IMG_FILE_PATH --model MODEL_PATH

# OpenVINO
$ python convert/infer_openvino.py --source IMG_FILE_PATH --model MODEL_PATH --device CPU  # or GPU

Framework Comparison

CPU:

Model PyTorch ONNX OpenVINO TFLite
VAN-S 46 28 - -

GPU:

Model PyTorch (FP32) TensorRT (FP32)
VAN-S 6 -

Latency in milliseconds. Tested with a Ryzen 7 4800HS and a GTX 1650 Ti.
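
Latency numbers of this kind are typically obtained by timing repeated forward passes after a few warm-up calls. The following is a minimal sketch for any callable; `measure_latency_ms` and its parameters are hypothetical and not part of this repo, and it is not necessarily how the figures above were measured.

```python
import statistics
import time


def measure_latency_ms(fn, warmup=10, runs=50):
    """Return (mean, standard deviation) of the wall-clock time of fn() in ms."""
    # Warm-up calls let caches, lazy initialization, and autotuning settle.
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    # statistics.stdev needs at least two samples, so keep runs >= 2.
    return statistics.mean(timings), statistics.stdev(timings)


# Stand-in workload; replace the lambda with a model forward pass.
mean_ms, std_ms = measure_latency_ms(lambda: sum(range(10_000)))
print(f"{mean_ms:.3f} ms +/- {std_ms:.3f} ms")
```

When timing a PyTorch model on a GPU, call `torch.cuda.synchronize()` inside `fn` after the forward pass, because CUDA kernels are launched asynchronously and the clock would otherwise stop too early.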

Acknowledgements

Most of the code is borrowed from timm and DeiT. I would like to thank the papers' authors for open-sourcing their code and providing pre-trained models.