diff --git a/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py b/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py
index f9c50267af7..2a9a4de805b 100644
--- a/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py
+++ b/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py
@@ -8,7 +8,11 @@
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
- dict(type='RandomResizedCrop', size=224, backend='pillow'),
+ dict(
+ type='RandomResizedCrop',
+ size=224,
+ backend='pillow',
+ interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='AutoAugment', policies={{_base_.policy_imagenet}}),
dict(type='Normalize', **img_norm_cfg),
@@ -18,7 +22,11 @@
]
test_pipeline = [
dict(type='LoadImageFromFile'),
- dict(type='Resize', size=(256, -1), backend='pillow'),
+ dict(
+ type='Resize',
+ size=(256, -1),
+ backend='pillow',
+ interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
diff --git a/configs/_base_/schedules/imagenet_bs4096_AdamW.py b/configs/_base_/schedules/imagenet_bs4096_AdamW.py
index 859cf4b23aa..75b00d80430 100644
--- a/configs/_base_/schedules/imagenet_bs4096_AdamW.py
+++ b/configs/_base_/schedules/imagenet_bs4096_AdamW.py
@@ -1,18 +1,24 @@
+# specific to vit pretrain
+paramwise_cfg = dict(custom_keys={
+ '.cls_token': dict(decay_mult=0.0),
+ '.pos_embed': dict(decay_mult=0.0)
+})
+
# optimizer
-optimizer = dict(type='AdamW', lr=0.003, weight_decay=0.3)
+optimizer = dict(
+ type='AdamW',
+ lr=0.003,
+ weight_decay=0.3,
+ paramwise_cfg=paramwise_cfg,
+)
optimizer_config = dict(grad_clip=dict(max_norm=1.0))
-# specific to vit pretrain
-paramwise_cfg = dict(
- custom_keys={
- '.backbone.cls_token': dict(decay_mult=0.0),
- '.backbone.pos_embed': dict(decay_mult=0.0)
- })
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=0,
warmup='linear',
warmup_iters=10000,
- warmup_ratio=1e-4)
+ warmup_ratio=1e-4,
+)
runner = dict(type='EpochBasedRunner', max_epochs=300)
diff --git a/configs/deit/README.md b/configs/deit/README.md
new file mode 100644
index 00000000000..ba496b5734f
--- /dev/null
+++ b/configs/deit/README.md
@@ -0,0 +1,61 @@
+# Training data-efficient image transformers & distillation through attention
+
+
+
+## Abstract
+
+
+Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
+
+
+
+
+
+
+## Citation
+```{latex}
+@InProceedings{pmlr-v139-touvron21a,
+ title = {Training data-efficient image transformers & distillation through attention},
+ author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
+ booktitle = {International Conference on Machine Learning},
+ pages = {10347--10357},
+ year = {2021},
+ volume = {139},
+ month = {July}
+}
+```
+
+## Pretrained models
+
+The pre-trained models are converted from the [official repo](https://github.com/facebookresearch/deit). And the teacher of the distilled version DeiT is RegNetY-16GF.
+
+### ImageNet-1k
+
+| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
+| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
+| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) |
+| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
+| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) |
+| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
+| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) |
+
+*Models with \* are converted from other repos.*
+
+## Fine-tuned models
+
+The fine-tuned models are converted from the [official repo](https://github.com/facebookresearch/deit).
+
+### ImageNet-1k
+
+| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
+| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
+| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) |
+
+*Models with \* are converted from other repos.*
+
+```{warning}
+MMClassification doesn't support training the distilled version DeiT.
+And we provide distilled version checkpoints for inference only.
+```
diff --git a/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py b/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py
new file mode 100644
index 00000000000..c8bdfb537bd
--- /dev/null
+++ b/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py
@@ -0,0 +1,9 @@
+_base_ = './deit-base_ft-16xb32_in1k-384px.py'
+
+# model settings
+model = dict(
+ backbone=dict(type='DistilledVisionTransformer'),
+ head=dict(type='DeiTClsHead'),
+ # Change to the path of the pretrained model
+ # init_cfg=dict(type='Pretrained', checkpoint=''),
+)
diff --git a/configs/deit/deit-base-distilled_pt-16xb64_in1k.py b/configs/deit/deit-base-distilled_pt-16xb64_in1k.py
new file mode 100644
index 00000000000..671658383ae
--- /dev/null
+++ b/configs/deit/deit-base-distilled_pt-16xb64_in1k.py
@@ -0,0 +1,10 @@
+_base_ = './deit-small_pt-4xb256_in1k.py'
+
+# model settings
+model = dict(
+ backbone=dict(type='DistilledVisionTransformer', arch='deit-base'),
+ head=dict(type='DeiTClsHead', in_channels=768),
+)
+
+# data settings
+data = dict(samples_per_gpu=64, workers_per_gpu=5)
diff --git a/configs/deit/deit-base_ft-16xb32_in1k-384px.py b/configs/deit/deit-base_ft-16xb32_in1k-384px.py
new file mode 100644
index 00000000000..db444168d43
--- /dev/null
+++ b/configs/deit/deit-base_ft-16xb32_in1k-384px.py
@@ -0,0 +1,29 @@
+_base_ = [
+ '../_base_/datasets/imagenet_bs64_swin_384.py',
+ '../_base_/schedules/imagenet_bs4096_AdamW.py',
+ '../_base_/default_runtime.py'
+]
+
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='VisionTransformer',
+ arch='deit-base',
+ img_size=384,
+ patch_size=16,
+ ),
+ neck=None,
+ head=dict(
+ type='VisionTransformerClsHead',
+ num_classes=1000,
+ in_channels=768,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ ),
+ # Change to the path of the pretrained model
+ # init_cfg=dict(type='Pretrained', checkpoint=''),
+)
+
+# data settings
+data = dict(samples_per_gpu=32, workers_per_gpu=5)
diff --git a/configs/deit/deit-base_pt-16xb64_in1k.py b/configs/deit/deit-base_pt-16xb64_in1k.py
new file mode 100644
index 00000000000..818dc3f1cf2
--- /dev/null
+++ b/configs/deit/deit-base_pt-16xb64_in1k.py
@@ -0,0 +1,10 @@
+_base_ = './deit-small_pt-4xb256_in1k.py'
+
+# model settings
+model = dict(
+ backbone=dict(type='VisionTransformer', arch='deit-base'),
+ head=dict(type='VisionTransformerClsHead', in_channels=768),
+)
+
+# data settings
+data = dict(samples_per_gpu=64, workers_per_gpu=5)
diff --git a/configs/deit/deit-small-distilled_pt-4xb256_in1k.py b/configs/deit/deit-small-distilled_pt-4xb256_in1k.py
new file mode 100644
index 00000000000..3b1fac22490
--- /dev/null
+++ b/configs/deit/deit-small-distilled_pt-4xb256_in1k.py
@@ -0,0 +1,7 @@
+_base_ = './deit-small_pt-4xb256_in1k.py'
+
+# model settings
+model = dict(
+ backbone=dict(type='DistilledVisionTransformer', arch='deit-small'),
+ head=dict(type='DeiTClsHead', in_channels=384),
+)
diff --git a/configs/deit/deit-small_pt-4xb256_in1k.py b/configs/deit/deit-small_pt-4xb256_in1k.py
new file mode 100644
index 00000000000..fad7a365d90
--- /dev/null
+++ b/configs/deit/deit-small_pt-4xb256_in1k.py
@@ -0,0 +1,29 @@
+_base_ = [
+ '../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
+ '../_base_/schedules/imagenet_bs4096_AdamW.py',
+ '../_base_/default_runtime.py'
+]
+
+# model settings
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='VisionTransformer',
+ arch='deit-small',
+ img_size=224,
+ patch_size=16),
+ neck=None,
+ head=dict(
+ type='VisionTransformerClsHead',
+ num_classes=1000,
+ in_channels=384,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ ),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=.02),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+ ])
+
+# data settings
+data = dict(samples_per_gpu=256, workers_per_gpu=5)
diff --git a/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py b/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
new file mode 100644
index 00000000000..175f980445d
--- /dev/null
+++ b/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
@@ -0,0 +1,7 @@
+_base_ = './deit-small_pt-4xb256_in1k.py'
+
+# model settings
+model = dict(
+ backbone=dict(type='DistilledVisionTransformer', arch='deit-tiny'),
+ head=dict(type='DeiTClsHead', in_channels=192),
+)
diff --git a/configs/deit/deit-tiny_pt-4xb256_in1k.py b/configs/deit/deit-tiny_pt-4xb256_in1k.py
new file mode 100644
index 00000000000..43df6e13823
--- /dev/null
+++ b/configs/deit/deit-tiny_pt-4xb256_in1k.py
@@ -0,0 +1,7 @@
+_base_ = './deit-small_pt-4xb256_in1k.py'
+
+# model settings
+model = dict(
+ backbone=dict(type='VisionTransformer', arch='deit-tiny'),
+ head=dict(type='VisionTransformerClsHead', in_channels=192),
+)
diff --git a/configs/deit/metafile.yml b/configs/deit/metafile.yml
new file mode 100644
index 00000000000..9f475eaa067
--- /dev/null
+++ b/configs/deit/metafile.yml
@@ -0,0 +1,143 @@
+Collections:
+ - Name: DeiT
+ Metadata:
+ Training Data: ImageNet-1k
+ Architecture:
+ - Layer Normalization
+ - Scaled Dot-Product Attention
+ - Attention Dropout
+ - Multi-Head Attention
+ Paper:
+ URL: https://arxiv.org/abs/2012.12877
+ Title: "Training data-efficient image transformers & distillation through attention"
+ README: configs/deit/README.md
+
+Models:
+ - Name: deit-tiny_3rdparty_pt-4xb256_in1k
+ Metadata:
+ FLOPs: 1080000000
+ Parameters: 5720000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 72.13
+ Top 5 Accuracy: 91.13
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63
+ Config: configs/deit/deit-tiny_pt-4xb256_in1k.py
+ - Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k
+ Metadata:
+ FLOPs: 1080000000
+ Parameters: 5720000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 74.51
+ Top 5 Accuracy: 91.90
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
+ Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
+ - Name: deit-small_3rdparty_pt-4xb256_in1k
+ Metadata:
+ FLOPs: 4240000000
+ Parameters: 22050000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 79.83
+ Top 5 Accuracy: 94.95
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78
+ Config: configs/deit/deit-small_pt-4xb256_in1k.py
+ - Name: deit-small-distilled_3rdparty_pt-4xb256_in1k
+ Metadata:
+ FLOPs: 4240000000
+ Parameters: 22050000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 81.17
+ Top 5 Accuracy: 95.40
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
+ Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py
+ - Name: deit-base_3rdparty_pt-16xb64_in1k
+ Metadata:
+ FLOPs: 16860000000
+ Parameters: 86570000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 81.79
+ Top 5 Accuracy: 95.59
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L93
+ Config: configs/deit/deit-base_pt-16xb64_in1k.py
+ - Name: deit-base-distilled_3rdparty_pt-16xb64_in1k
+ Metadata:
+ FLOPs: 16860000000
+ Parameters: 86570000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.33
+ Top 5 Accuracy: 96.49
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138
+ Config: configs/deit/deit-base-distilled_pt-16xb64_in1k.py
+ - Name: deit-base_3rdparty_ft-16xb32_in1k-384px
+ Metadata:
+ FLOPs: 49370000000
+ Parameters: 86860000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.04
+ Top 5 Accuracy: 96.31
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L153
+ Config: configs/deit/deit-base_ft-16xb32_in1k-384px.py
+ - Name: deit-base-distilled_3rdparty_ft-16xb32_in1k-384px
+ Metadata:
+ FLOPs: 49370000000
+ Parameters: 86860000
+ In Collection: DeiT
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 85.55
+ Top 5 Accuracy: 97.35
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth
+ Converted From:
+ Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
+ Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168
+ Config: configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py
diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md
index 7f378e0d814..918f5e4b0f4 100644
--- a/docs/en/model_zoo.md
+++ b/docs/en/model_zoo.md
@@ -63,12 +63,17 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
+| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
+| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | [log]()|
+| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
+| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | [log]()|
+| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
+| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | [log]()|
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
-
Models with * are converted from other repos, others are trained by ourselves.
## CIFAR10
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index f9dbf705486..faa7927f377 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
from .conformer import Conformer
+from .deit import DistilledVisionTransformer
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
@@ -28,5 +29,5 @@
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
- 'Conformer', 'MlpMixer'
+ 'Conformer', 'MlpMixer', 'DistilledVisionTransformer'
]
diff --git a/mmcls/models/backbones/deit.py b/mmcls/models/backbones/deit.py
new file mode 100644
index 00000000000..37851798dde
--- /dev/null
+++ b/mmcls/models/backbones/deit.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn.utils.weight_init import trunc_normal_
+
+from ..builder import BACKBONES
+from .vision_transformer import VisionTransformer
+
+
+@BACKBONES.register_module()
+class DistilledVisionTransformer(VisionTransformer):
+ """Distilled Vision Transformer.
+
+ A PyTorch implement of : `Training data-efficient image transformers &
+ distillation through attention `_
+
+ Args:
+ arch (str | dict): Vision Transformer architecture
+ Default: 'b'
+ img_size (int | tuple): Input image size
+ patch_size (int | tuple): The patch size
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ final_norm (bool): Whether to add a additional layer to normalize
+ final feature map. Defaults to True.
+ output_cls_token (bool): Whether output the cls_token. If set True,
+ `with_cls_token` must be True. Defaults to True.
+ interpolate_mode (str): Select the interpolate mode for position
+ embeding vector resize. Defaults to "bicubic".
+ patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
+ encoder. Defaults to an empty dict.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+ num_extra_tokens = 2 # cls_token, dist_token
+
+ def __init__(self, *args, **kwargs):
+ super(DistilledVisionTransformer, self).__init__(*args, **kwargs)
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+ def forward(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+ patch_resolution = self.patch_embed.patches_resolution
+
+ # stole cls_tokens impl from Phil Wang, thanks
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ x = x + self.pos_embed
+ x = self.drop_after_pos(x)
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+
+ if i == len(self.layers) - 1 and self.final_norm:
+ x = self.norm1(x)
+
+ if i in self.out_indices:
+ B, _, C = x.shape
+ patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
+ patch_token = patch_token.permute(0, 3, 1, 2)
+ cls_token = x[:, 0]
+ dist_token = x[:, 1]
+ if self.output_cls_token:
+ out = [patch_token, cls_token, dist_token]
+ else:
+ out = patch_token
+ outs.append(out)
+
+ return tuple(outs)
+
+ def init_weights(self):
+ super(DistilledVisionTransformer, self).init_weights()
+
+ if not (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ trunc_normal_(self.dist_token, std=0.02)
diff --git a/mmcls/models/backbones/vision_transformer.py b/mmcls/models/backbones/vision_transformer.py
index 4da7b185636..fc0f57b5802 100644
--- a/mmcls/models/backbones/vision_transformer.py
+++ b/mmcls/models/backbones/vision_transformer.py
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from copy import deepcopy
from typing import Sequence
import numpy as np
@@ -8,6 +7,7 @@
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
+from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcls.utils import get_root_logger
@@ -104,9 +104,8 @@ def forward(self, x):
class VisionTransformer(BaseBackbone):
"""Vision Transformer.
- A PyTorch implement of : `An Image is Worth 16x16 Words:
- Transformers for Image Recognition at Scale
- `_
+ A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers
+ for Image Recognition at Scale `_
Args:
arch (str | dict): Vision Transformer architecture
@@ -155,7 +154,30 @@ class VisionTransformer(BaseBackbone):
'num_heads': 16,
'feedforward_channels': 4096
}),
+ **dict.fromkeys(
+ ['deit-t', 'deit-tiny'], {
+ 'embed_dims': 192,
+ 'num_layers': 12,
+ 'num_heads': 3,
+ 'feedforward_channels': 192 * 4
+ }),
+ **dict.fromkeys(
+ ['deit-s', 'deit-small'], {
+ 'embed_dims': 384,
+ 'num_layers': 12,
+ 'num_heads': 6,
+ 'feedforward_channels': 384 * 4
+ }),
+ **dict.fromkeys(
+ ['deit-b', 'deit-base'], {
+ 'embed_dims': 768,
+ 'num_layers': 12,
+ 'num_heads': 12,
+ 'feedforward_channels': 768 * 4
+ }),
}
+ # Some structures have multiple extra tokens, like DeiT.
+ num_extra_tokens = 1 # cls_token
def __init__(self,
arch='b',
@@ -182,7 +204,7 @@ def __init__(self,
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
- assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
@@ -208,7 +230,8 @@ def __init__(self,
# Set position embedding
self.interpolate_mode = interpolate_mode
self.pos_embed = nn.Parameter(
- torch.zeros(1, num_patches + 1, self.embed_dims))
+ torch.zeros(1, num_patches + self.num_extra_tokens,
+ self.embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
@@ -247,65 +270,49 @@ def __init__(self,
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
+ self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook)
+
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
- # Suppress default init if use pretrained model.
- # And use custom load_checkpoint function to load checkpoint.
- if (isinstance(self.init_cfg, dict)
+ super(VisionTransformer, self).init_weights()
+
+ if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
- init_cfg = deepcopy(self.init_cfg)
- init_cfg.pop('type')
- self._load_checkpoint(**init_cfg)
- else:
- super(VisionTransformer, self).init_weights()
- # Modified from ClassyVision
- nn.init.normal_(self.pos_embed, std=0.02)
-
- def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
- from mmcv.runner import (_load_checkpoint,
- _load_checkpoint_with_prefix, load_state_dict)
- from mmcv.utils import print_log
-
- logger = get_root_logger()
-
- if prefix is None:
- print_log(f'load model from: {checkpoint}', logger=logger)
- checkpoint = _load_checkpoint(checkpoint, map_location, logger)
- # get state_dict from checkpoint
- if 'state_dict' in checkpoint:
- state_dict = checkpoint['state_dict']
- else:
- state_dict = checkpoint
- else:
- print_log(
- f'load {prefix} in model from: {checkpoint}', logger=logger)
- state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
- map_location)
+ trunc_normal_(self.pos_embed, std=0.02)
- if 'pos_embed' in state_dict.keys():
- ckpt_pos_embed_shape = state_dict['pos_embed'].shape
- if self.pos_embed.shape != ckpt_pos_embed_shape:
- print_log(
- f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
- f'to {self.pos_embed.shape}.',
- logger=logger)
+ def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs):
+ name = prefix + 'pos_embed'
+ if name not in state_dict.keys():
+ return
- ckpt_pos_embed_shape = to_2tuple(
- int(np.sqrt(ckpt_pos_embed_shape[1] - 1)))
- pos_embed_shape = self.patch_embed.patches_resolution
+ ckpt_pos_embed_shape = state_dict[name].shape
+ if self.pos_embed.shape != ckpt_pos_embed_shape:
+ from mmcv.utils import print_log
+ logger = get_root_logger()
+ print_log(
+ f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
+ f'to {self.pos_embed.shape}.',
+ logger=logger)
- state_dict['pos_embed'] = self.resize_pos_embed(
- state_dict['pos_embed'], ckpt_pos_embed_shape,
- pos_embed_shape, self.interpolate_mode)
+ ckpt_pos_embed_shape = to_2tuple(
+ int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
+ pos_embed_shape = self.patch_embed.patches_resolution
- # load state_dict
- load_state_dict(self, state_dict, strict=False, logger=logger)
+ state_dict[name] = self.resize_pos_embed(state_dict[name],
+ ckpt_pos_embed_shape,
+ pos_embed_shape,
+ self.interpolate_mode,
+ self.num_extra_tokens)
@staticmethod
- def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic'):
+ def resize_pos_embed(pos_embed,
+ src_shape,
+ dst_shape,
+ mode='bicubic',
+ num_extra_tokens=1):
"""Resize pos_embed weights.
Args:
@@ -324,17 +331,17 @@ def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic'):
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
_, L, C = pos_embed.shape
src_h, src_w = src_shape
- assert L == src_h * src_w + 1
- cls_token = pos_embed[:, :1]
+ assert L == src_h * src_w + num_extra_tokens
+ extra_tokens = pos_embed[:, :num_extra_tokens]
- src_weight = pos_embed[:, 1:]
+ src_weight = pos_embed[:, num_extra_tokens:]
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
dst_weight = F.interpolate(
src_weight, size=dst_shape, align_corners=False, mode=mode)
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
- return torch.cat((cls_token, dst_weight), dim=1)
+ return torch.cat((extra_tokens, dst_weight), dim=1)
def forward(self, x):
B = x.shape[0]
diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py
index 4be4daf8722..b81106fbe96 100644
--- a/mmcls/models/heads/__init__.py
+++ b/mmcls/models/heads/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_head import ClsHead
from .conformer_head import ConformerHead
+from .deit_head import DeiTClsHead
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
@@ -9,5 +10,6 @@
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
- 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'ConformerHead'
+ 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
+ 'ConformerHead'
]
diff --git a/mmcls/models/heads/deit_head.py b/mmcls/models/heads/deit_head.py
new file mode 100644
index 00000000000..4a79e455587
--- /dev/null
+++ b/mmcls/models/heads/deit_head.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmcls.utils import get_root_logger
+from ..builder import HEADS
+from .vision_transformer_head import VisionTransformerClsHead
+
+
+@HEADS.register_module()
+class DeiTClsHead(VisionTransformerClsHead):
+
+ def __init__(self, *args, **kwargs):
+ super(DeiTClsHead, self).__init__(*args, **kwargs)
+ self.head_dist = nn.Linear(self.in_channels, self.num_classes)
+
+ def simple_test(self, x):
+ """Test without augmentation."""
+ x = x[-1]
+ assert isinstance(x, list) and len(x) == 3
+ _, cls_token, dist_token = x
+ cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
+ pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
+
+ return self.post_process(pred)
+
+ def forward_train(self, x, gt_label):
+ logger = get_root_logger()
+ logger.warning("MMClassification doesn't support to train the "
+ 'distilled version DeiT.')
+ x = x[-1]
+ assert isinstance(x, list) and len(x) == 3
+ _, cls_token, dist_token = x
+ cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
+ losses = self.loss(cls_score, gt_label)
+ return losses
diff --git a/model-index.yml b/model-index.yml
index f8b8b4d5567..a17a5e08cb4 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -14,3 +14,4 @@ Import:
- configs/t2t_vit/metafile.yml
- configs/mlp_mixer/metafile.yml
- configs/conformer/metafile.yml
+ - configs/deit/metafile.yml
diff --git a/tests/test_models/test_backbones/test_deit.py b/tests/test_models/test_backbones/test_deit.py
new file mode 100644
index 00000000000..af9efe9c0b8
--- /dev/null
+++ b/tests/test_models/test_backbones/test_deit.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.backbones import DistilledVisionTransformer
+
+
+def check_norm_state(modules, train_state):
+ """Check if norm layer is in correct train state."""
+ for mod in modules:
+ if isinstance(mod, _BatchNorm):
+ if mod.training != train_state:
+ return False
+ return True
+
+
+def test_deit_backbone():
+ cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16)
+
+ # Test structure
+ model = DistilledVisionTransformer(**cfg_ori)
+ model.init_weights()
+ model.train()
+
+ assert check_norm_state(model.modules(), True)
+ assert model.dist_token.shape == (1, 1, 768)
+ assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2, 768)
+
+ # Test forward
+ imgs = torch.rand(1, 3, 224, 224)
+ outs = model(imgs)
+ patch_token, cls_token, dist_token = outs[0]
+ assert patch_token.shape == (1, 768, 14, 14)
+ assert cls_token.shape == (1, 768)
+ assert dist_token.shape == (1, 768)
+
+ # Test multiple out_indices
+ model = DistilledVisionTransformer(
+ **cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False)
+ outs = model(imgs)
+ for out in outs:
+ assert out.shape == (1, 768, 14, 14)
diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py
index 4d2105698f5..c56d509ea06 100644
--- a/tests/test_models/test_heads.py
+++ b/tests/test_models/test_heads.py
@@ -4,9 +4,9 @@
import pytest
import torch
-from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
- MultiLabelLinearClsHead, StackedLinearClsHead,
- VisionTransformerClsHead)
+from mmcls.models.heads import (ClsHead, DeiTClsHead, LinearClsHead,
+ MultiLabelClsHead, MultiLabelLinearClsHead,
+ StackedLinearClsHead, VisionTransformerClsHead)
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
@@ -157,3 +157,44 @@ def test_vit_head():
# test assertion
with pytest.raises(ValueError):
VisionTransformerClsHead(-1, 100)
+
+
+def test_deit_head():
+ fake_features = ([
+ torch.rand(4, 7, 7, 16),
+ torch.rand(4, 100),
+ torch.rand(4, 100)
+ ], )
+ fake_gt_label = torch.randint(0, 10, (4, ))
+
+ # test deit head forward
+ head = DeiTClsHead(num_classes=10, in_channels=100)
+ losses = head.forward_train(fake_features, fake_gt_label)
+ assert not hasattr(head.layers, 'pre_logits')
+ assert not hasattr(head.layers, 'act')
+ assert losses['loss'].item() > 0
+
+ # test deit head forward with hidden layer
+ head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
+ losses = head.forward_train(fake_features, fake_gt_label)
+ assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
+ assert losses['loss'].item() > 0
+
+ # test deit head init_weights
+ head = DeiTClsHead(10, 100, hidden_dim=20)
+ head.init_weights()
+ assert abs(head.layers.pre_logits.weight).sum() > 0
+
+ # test simple_test
+ head = DeiTClsHead(10, 100, hidden_dim=20)
+ pred = head.simple_test(fake_features)
+ assert isinstance(pred, list) and len(pred) == 4
+
+ with patch('torch.onnx.is_in_onnx_export', return_value=True):
+ head = DeiTClsHead(10, 100, hidden_dim=20)
+ pred = head.simple_test(fake_features)
+ assert pred.shape == (4, 10)
+
+ # test assertion
+ with pytest.raises(ValueError):
+ DeiTClsHead(-1, 100)