diff --git a/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py b/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py index f9c50267af7..2a9a4de805b 100644 --- a/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py +++ b/configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py @@ -8,7 +8,11 @@ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), - dict(type='RandomResizedCrop', size=224, backend='pillow'), + dict( + type='RandomResizedCrop', + size=224, + backend='pillow', + interpolation='bicubic'), dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), dict(type='AutoAugment', policies={{_base_.policy_imagenet}}), dict(type='Normalize', **img_norm_cfg), @@ -18,7 +22,11 @@ ] test_pipeline = [ dict(type='LoadImageFromFile'), - dict(type='Resize', size=(256, -1), backend='pillow'), + dict( + type='Resize', + size=(256, -1), + backend='pillow', + interpolation='bicubic'), dict(type='CenterCrop', crop_size=224), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), diff --git a/configs/_base_/schedules/imagenet_bs4096_AdamW.py b/configs/_base_/schedules/imagenet_bs4096_AdamW.py index 859cf4b23aa..75b00d80430 100644 --- a/configs/_base_/schedules/imagenet_bs4096_AdamW.py +++ b/configs/_base_/schedules/imagenet_bs4096_AdamW.py @@ -1,18 +1,24 @@ +# specific to vit pretrain +paramwise_cfg = dict(custom_keys={ + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) +}) + # optimizer -optimizer = dict(type='AdamW', lr=0.003, weight_decay=0.3) +optimizer = dict( + type='AdamW', + lr=0.003, + weight_decay=0.3, + paramwise_cfg=paramwise_cfg, +) optimizer_config = dict(grad_clip=dict(max_norm=1.0)) -# specific to vit pretrain -paramwise_cfg = dict( - custom_keys={ - '.backbone.cls_token': dict(decay_mult=0.0), - '.backbone.pos_embed': dict(decay_mult=0.0) - }) # learning policy lr_config = dict( policy='CosineAnnealing', min_lr=0, warmup='linear', warmup_iters=10000, - warmup_ratio=1e-4) + warmup_ratio=1e-4, +) runner = dict(type='EpochBasedRunner', max_epochs=300) diff --git a/configs/deit/README.md b/configs/deit/README.md new file mode 100644 index 00000000000..ba496b5734f --- /dev/null +++ b/configs/deit/README.md @@ -0,0 +1,61 @@ +# Training data-efficient image transformers & distillation through attention + + + +## Abstract + + +Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models. + + +
+ +
+
+## Citation
+```{latex}
+@InProceedings{pmlr-v139-touvron21a,
+  title = {Training data-efficient image transformers & distillation through attention},
+  author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
+  booktitle = {International Conference on Machine Learning},
+  pages = {10347--10357},
+  year = {2021},
+  volume = {139},
+  month = {July}
+}
+```
+
+## Pretrained models
+
+The pre-trained models are converted from the [official repo](https://github.com/facebookresearch/deit). The teacher of the distilled DeiT models is RegNetY-16GF.
+
+### ImageNet-1k
+
+| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
+| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
+| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) |
+| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
+| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) |
+| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
+| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) |
+
+*Models with \* are converted from other repos.*
+
+## Fine-tuned models
+
+The fine-tuned models are converted from the [official repo](https://github.com/facebookresearch/deit).
+
+### ImageNet-1k
+
+| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
+| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
+| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) |
+
+*Models with \* are converted from other repos.*
+
+```{warning}
+MMClassification doesn't support training the distilled versions of DeiT;
+the distilled checkpoints are provided for inference only.
+``` diff --git a/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py b/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py new file mode 100644 index 00000000000..c8bdfb537bd --- /dev/null +++ b/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py @@ -0,0 +1,9 @@ +_base_ = './deit-base_ft-16xb32_in1k-384px.py' + +# model settings +model = dict( + backbone=dict(type='DistilledVisionTransformer'), + head=dict(type='DeiTClsHead'), + # Change to the path of the pretrained model + # init_cfg=dict(type='Pretrained', checkpoint=''), +) diff --git a/configs/deit/deit-base-distilled_pt-16xb64_in1k.py b/configs/deit/deit-base-distilled_pt-16xb64_in1k.py new file mode 100644 index 00000000000..671658383ae --- /dev/null +++ b/configs/deit/deit-base-distilled_pt-16xb64_in1k.py @@ -0,0 +1,10 @@ +_base_ = './deit-small_pt-4xb256_in1k.py' + +# model settings +model = dict( + backbone=dict(type='DistilledVisionTransformer', arch='deit-base'), + head=dict(type='DeiTClsHead', in_channels=768), +) + +# data settings +data = dict(samples_per_gpu=64, workers_per_gpu=5) diff --git a/configs/deit/deit-base_ft-16xb32_in1k-384px.py b/configs/deit/deit-base_ft-16xb32_in1k-384px.py new file mode 100644 index 00000000000..db444168d43 --- /dev/null +++ b/configs/deit/deit-base_ft-16xb32_in1k-384px.py @@ -0,0 +1,29 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs64_swin_384.py', + '../_base_/schedules/imagenet_bs4096_AdamW.py', + '../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='VisionTransformer', + arch='deit-base', + img_size=384, + patch_size=16, + ), + neck=None, + head=dict( + type='VisionTransformerClsHead', + num_classes=1000, + in_channels=768, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + ), + # Change to the path of the pretrained model + # init_cfg=dict(type='Pretrained', checkpoint=''), +) + +# data settings +data = dict(samples_per_gpu=32, workers_per_gpu=5) diff --git a/configs/deit/deit-base_pt-16xb64_in1k.py b/configs/deit/deit-base_pt-16xb64_in1k.py new file mode 100644 index 00000000000..818dc3f1cf2 --- /dev/null +++ b/configs/deit/deit-base_pt-16xb64_in1k.py @@ -0,0 +1,10 @@ +_base_ = './deit-small_pt-4xb256_in1k.py' + +# model settings +model = dict( + backbone=dict(type='VisionTransformer', arch='deit-base'), + head=dict(type='VisionTransformerClsHead', in_channels=768), +) + +# data settings +data = dict(samples_per_gpu=64, workers_per_gpu=5) diff --git a/configs/deit/deit-small-distilled_pt-4xb256_in1k.py b/configs/deit/deit-small-distilled_pt-4xb256_in1k.py new file mode 100644 index 00000000000..3b1fac22490 --- /dev/null +++ b/configs/deit/deit-small-distilled_pt-4xb256_in1k.py @@ -0,0 +1,7 @@ +_base_ = './deit-small_pt-4xb256_in1k.py' + +# model settings +model = dict( + backbone=dict(type='DistilledVisionTransformer', arch='deit-small'), + head=dict(type='DeiTClsHead', in_channels=384), +) diff --git a/configs/deit/deit-small_pt-4xb256_in1k.py b/configs/deit/deit-small_pt-4xb256_in1k.py new file mode 100644 index 00000000000..fad7a365d90 --- /dev/null +++ b/configs/deit/deit-small_pt-4xb256_in1k.py @@ -0,0 +1,29 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py', + '../_base_/schedules/imagenet_bs4096_AdamW.py', + '../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='VisionTransformer', + arch='deit-small', + img_size=224, + patch_size=16), + neck=None, + head=dict( + 
type='VisionTransformerClsHead', + num_classes=1000, + in_channels=384, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ]) + +# data settings +data = dict(samples_per_gpu=256, workers_per_gpu=5) diff --git a/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py b/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py new file mode 100644 index 00000000000..175f980445d --- /dev/null +++ b/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py @@ -0,0 +1,7 @@ +_base_ = './deit-small_pt-4xb256_in1k.py' + +# model settings +model = dict( + backbone=dict(type='DistilledVisionTransformer', arch='deit-tiny'), + head=dict(type='DeiTClsHead', in_channels=192), +) diff --git a/configs/deit/deit-tiny_pt-4xb256_in1k.py b/configs/deit/deit-tiny_pt-4xb256_in1k.py new file mode 100644 index 00000000000..43df6e13823 --- /dev/null +++ b/configs/deit/deit-tiny_pt-4xb256_in1k.py @@ -0,0 +1,7 @@ +_base_ = './deit-small_pt-4xb256_in1k.py' + +# model settings +model = dict( + backbone=dict(type='VisionTransformer', arch='deit-tiny'), + head=dict(type='VisionTransformerClsHead', in_channels=192), +) diff --git a/configs/deit/metafile.yml b/configs/deit/metafile.yml new file mode 100644 index 00000000000..9f475eaa067 --- /dev/null +++ b/configs/deit/metafile.yml @@ -0,0 +1,143 @@ +Collections: + - Name: DeiT + Metadata: + Training Data: ImageNet-1k + Architecture: + - Layer Normalization + - Scaled Dot-Product Attention + - Attention Dropout + - Multi-Head Attention + Paper: + URL: https://arxiv.org/abs/2012.12877 + Title: "Training data-efficient image transformers & distillation through attention" + README: configs/deit/README.md + +Models: + - Name: deit-tiny_3rdparty_pt-4xb256_in1k + Metadata: + FLOPs: 1080000000 + Parameters: 5720000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 72.13 + Top 5 Accuracy: 91.13 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth + Converted From: + Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63 + Config: configs/deit/deit-tiny_pt-4xb256_in1k.py + - Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k + Metadata: + FLOPs: 1080000000 + Parameters: 5720000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 74.51 + Top 5 Accuracy: 91.90 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth + Converted From: + Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108 + Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py + - Name: deit-small_3rdparty_pt-4xb256_in1k + Metadata: + FLOPs: 4240000000 + Parameters: 22050000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 79.83 + Top 5 Accuracy: 94.95 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth + Converted From: + Weights: 
https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78 + Config: configs/deit/deit-small_pt-4xb256_in1k.py + - Name: deit-small-distilled_3rdparty_pt-4xb256_in1k + Metadata: + FLOPs: 4240000000 + Parameters: 22050000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 81.17 + Top 5 Accuracy: 95.40 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth + Converted From: + Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123 + Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py + - Name: deit-base_3rdparty_pt-16xb64_in1k + Metadata: + FLOPs: 16860000000 + Parameters: 86570000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 81.79 + Top 5 Accuracy: 95.59 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth + Converted From: + Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L93 + Config: configs/deit/deit-base_pt-16xb64_in1k.py + - Name: deit-base-distilled_3rdparty_pt-16xb64_in1k + Metadata: + FLOPs: 16860000000 + Parameters: 86570000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 83.33 + Top 5 Accuracy: 96.49 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth + Converted From: + Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138 + Config: configs/deit/deit-base-distilled_pt-16xb64_in1k.py + - Name: deit-base_3rdparty_ft-16xb32_in1k-384px + Metadata: + FLOPs: 49370000000 + Parameters: 86860000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 83.04 + Top 5 Accuracy: 96.31 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth + Converted From: + Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L153 + Config: configs/deit/deit-base_ft-16xb32_in1k-384px.py + - Name: deit-base-distilled_3rdparty_ft-16xb32_in1k-384px + Metadata: + FLOPs: 49370000000 + Parameters: 86860000 + In Collection: DeiT + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 85.55 + Top 5 Accuracy: 97.35 + Task: Image Classification + Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth + Converted From: + Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth + Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168 + Config: configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py diff --git 
a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 7f378e0d814..918f5e4b0f4 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -63,12 +63,17 @@ The ResNet family models below are trained by standard data augmentations, i.e., | T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()| | Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()| | Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()| +| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()| +| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | [log]()| +| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()| +| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | [log]()| +| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()| +| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | [log]()| | Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()| | Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()| | Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()| | Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | 
[config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
 
-Models with * are converted from other repos, others are trained by ourselves.
 
 ## CIFAR10
 
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index f9dbf705486..faa7927f377 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .alexnet import AlexNet
 from .conformer import Conformer
+from .deit import DistilledVisionTransformer
 from .lenet import LeNet5
 from .mlp_mixer import MlpMixer
 from .mobilenet_v2 import MobileNetV2
@@ -28,5 +29,5 @@
     'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
     'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
     'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
-    'Conformer', 'MlpMixer'
+    'Conformer', 'MlpMixer', 'DistilledVisionTransformer'
 ]
diff --git a/mmcls/models/backbones/deit.py b/mmcls/models/backbones/deit.py
new file mode 100644
index 00000000000..37851798dde
--- /dev/null
+++ b/mmcls/models/backbones/deit.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn.utils.weight_init import trunc_normal_
+
+from ..builder import BACKBONES
+from .vision_transformer import VisionTransformer
+
+
+@BACKBONES.register_module()
+class DistilledVisionTransformer(VisionTransformer):
+    """Distilled Vision Transformer.
+
+    A PyTorch implementation of: `Training data-efficient image transformers
+    & distillation through attention <https://arxiv.org/abs/2012.12877>`_
+
+    Args:
+        arch (str | dict): Vision Transformer architecture.
+            Defaults to 'b'.
+        img_size (int | tuple): Input image size.
+        patch_size (int | tuple): The patch size.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, meaning the last stage.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        norm_cfg (dict): Config dict for the normalization layer.
+            Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            the final feature map. Defaults to True.
+        output_cls_token (bool): Whether to output the cls_token. If set to
+            True, `with_cls_token` must be True. Defaults to True.
+        interpolate_mode (str): Select the interpolation mode for position
+            embedding vector resizing. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+    num_extra_tokens = 2  # cls_token, dist_token
+
+    def __init__(self, *args, **kwargs):
+        super(DistilledVisionTransformer, self).__init__(*args, **kwargs)
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+    def forward(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+        patch_resolution = self.patch_embed.patches_resolution
+
+        # stole cls_tokens impl from Phil Wang, thanks
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+        x = x + self.pos_embed
+        x = self.drop_after_pos(x)
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+
+            if i == len(self.layers) - 1 and self.final_norm:
+                x = self.norm1(x)
+
+            if i in self.out_indices:
+                B, _, C = x.shape
+                patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
+                patch_token = patch_token.permute(0, 3, 1, 2)
+                cls_token = x[:, 0]
+                dist_token = x[:, 1]
+                if self.output_cls_token:
+                    out = [patch_token, cls_token, dist_token]
+                else:
+                    out = patch_token
+                outs.append(out)
+
+        return tuple(outs)
+
+    def init_weights(self):
+        super(DistilledVisionTransformer, self).init_weights()
+
+        if not (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            trunc_normal_(self.dist_token, std=0.02)
diff --git a/mmcls/models/backbones/vision_transformer.py b/mmcls/models/backbones/vision_transformer.py
index 4da7b185636..fc0f57b5802 100644
--- a/mmcls/models/backbones/vision_transformer.py
+++ b/mmcls/models/backbones/vision_transformer.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from copy import deepcopy
 from typing import Sequence
 
 import numpy as np
@@ -8,6 +7,7 @@
 import torch.nn.functional as F
 from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN
+from mmcv.cnn.utils.weight_init import trunc_normal_
 from mmcv.runner.base_module import BaseModule, ModuleList
 
 from mmcls.utils import get_root_logger
@@ -104,9 +104,8 @@ def forward(self, x):
 class VisionTransformer(BaseBackbone):
     """Vision Transformer.
 
-    A PyTorch implement of : `An Image is Worth 16x16 Words:
-    Transformers for Image Recognition at Scale
-    <https://arxiv.org/abs/2010.11929>`_
+    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
+    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
 
     Args:
         arch (str | dict): Vision Transformer architecture
@@ -155,7 +154,30 @@ class VisionTransformer(BaseBackbone):
                      'num_heads': 16,
                      'feedforward_channels': 4096
                  }),
+        **dict.fromkeys(
+            ['deit-t', 'deit-tiny'], {
+                'embed_dims': 192,
+                'num_layers': 12,
+                'num_heads': 3,
+                'feedforward_channels': 192 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-s', 'deit-small'], {
+                'embed_dims': 384,
+                'num_layers': 12,
+                'num_heads': 6,
+                'feedforward_channels': 384 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-b', 'deit-base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 768 * 4
+            }),
     }
+    # Some structures have multiple extra tokens, like DeiT.
+ num_extra_tokens = 1 # cls_token def __init__(self, arch='b', @@ -182,7 +204,7 @@ def __init__(self, essential_keys = { 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' } - assert isinstance(arch, dict) and set(arch) == essential_keys, \ + assert isinstance(arch, dict) and essential_keys <= set(arch), \ f'Custom arch needs a dict with keys {essential_keys}' self.arch_settings = arch @@ -208,7 +230,8 @@ def __init__(self, # Set position embedding self.interpolate_mode = interpolate_mode self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, self.embed_dims)) + torch.zeros(1, num_patches + self.num_extra_tokens, + self.embed_dims)) self.drop_after_pos = nn.Dropout(p=drop_rate) if isinstance(out_indices, int): @@ -247,65 +270,49 @@ def __init__(self, norm_cfg, self.embed_dims, postfix=1) self.add_module(self.norm1_name, norm1) + self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook) + @property def norm1(self): return getattr(self, self.norm1_name) def init_weights(self): - # Suppress default init if use pretrained model. - # And use custom load_checkpoint function to load checkpoint. - if (isinstance(self.init_cfg, dict) + super(VisionTransformer, self).init_weights() + + if not (isinstance(self.init_cfg, dict) and self.init_cfg['type'] == 'Pretrained'): - init_cfg = deepcopy(self.init_cfg) - init_cfg.pop('type') - self._load_checkpoint(**init_cfg) - else: - super(VisionTransformer, self).init_weights() - # Modified from ClassyVision - nn.init.normal_(self.pos_embed, std=0.02) - - def _load_checkpoint(self, checkpoint, prefix=None, map_location=None): - from mmcv.runner import (_load_checkpoint, - _load_checkpoint_with_prefix, load_state_dict) - from mmcv.utils import print_log - - logger = get_root_logger() - - if prefix is None: - print_log(f'load model from: {checkpoint}', logger=logger) - checkpoint = _load_checkpoint(checkpoint, map_location, logger) - # get state_dict from checkpoint - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - else: - print_log( - f'load {prefix} in model from: {checkpoint}', logger=logger) - state_dict = _load_checkpoint_with_prefix(prefix, checkpoint, - map_location) + trunc_normal_(self.pos_embed, std=0.02) - if 'pos_embed' in state_dict.keys(): - ckpt_pos_embed_shape = state_dict['pos_embed'].shape - if self.pos_embed.shape != ckpt_pos_embed_shape: - print_log( - f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' - f'to {self.pos_embed.shape}.', - logger=logger) + def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return - ckpt_pos_embed_shape = to_2tuple( - int(np.sqrt(ckpt_pos_embed_shape[1] - 1))) - pos_embed_shape = self.patch_embed.patches_resolution + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmcv.utils import print_log + logger = get_root_logger() + print_log( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.', + logger=logger) - state_dict['pos_embed'] = self.resize_pos_embed( - state_dict['pos_embed'], ckpt_pos_embed_shape, - pos_embed_shape, self.interpolate_mode) + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.patches_resolution - # load state_dict - load_state_dict(self, state_dict, strict=False, logger=logger) + state_dict[name] = 
self.resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) @staticmethod - def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic'): + def resize_pos_embed(pos_embed, + src_shape, + dst_shape, + mode='bicubic', + num_extra_tokens=1): """Resize pos_embed weights. Args: @@ -324,17 +331,17 @@ def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic'): assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]' _, L, C = pos_embed.shape src_h, src_w = src_shape - assert L == src_h * src_w + 1 - cls_token = pos_embed[:, :1] + assert L == src_h * src_w + num_extra_tokens + extra_tokens = pos_embed[:, :num_extra_tokens] - src_weight = pos_embed[:, 1:] + src_weight = pos_embed[:, num_extra_tokens:] src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) dst_weight = F.interpolate( src_weight, size=dst_shape, align_corners=False, mode=mode) dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) - return torch.cat((cls_token, dst_weight), dim=1) + return torch.cat((extra_tokens, dst_weight), dim=1) def forward(self, x): B = x.shape[0] diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py index 4be4daf8722..b81106fbe96 100644 --- a/mmcls/models/heads/__init__.py +++ b/mmcls/models/heads/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cls_head import ClsHead from .conformer_head import ConformerHead +from .deit_head import DeiTClsHead from .linear_head import LinearClsHead from .multi_label_head import MultiLabelClsHead from .multi_label_linear_head import MultiLabelLinearClsHead @@ -9,5 +10,6 @@ __all__ = [ 'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead', - 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'ConformerHead' + 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead', + 'ConformerHead' ] diff --git a/mmcls/models/heads/deit_head.py b/mmcls/models/heads/deit_head.py new file mode 100644 index 00000000000..4a79e455587 --- /dev/null +++ b/mmcls/models/heads/deit_head.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
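+# DeiT classification head: besides the usual class-token head, it owns a
+# second linear head for the distillation token and averages the two
+# predictions at test time. Training the distilled models is unsupported
+# (see the warning in ``forward_train`` below).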
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmcls.utils import get_root_logger
+from ..builder import HEADS
+from .vision_transformer_head import VisionTransformerClsHead
+
+
+@HEADS.register_module()
+class DeiTClsHead(VisionTransformerClsHead):
+
+    def __init__(self, *args, **kwargs):
+        super(DeiTClsHead, self).__init__(*args, **kwargs)
+        self.head_dist = nn.Linear(self.in_channels, self.num_classes)
+
+    def simple_test(self, x):
+        """Test without augmentation."""
+        x = x[-1]
+        assert isinstance(x, list) and len(x) == 3
+        _, cls_token, dist_token = x
+        # Average the predictions of the cls token and the dist token.
+        cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
+        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
+
+        return self.post_process(pred)
+
+    def forward_train(self, x, gt_label):
+        logger = get_root_logger()
+        logger.warning("MMClassification doesn't support training the "
+                       'distilled version of DeiT.')
+        x = x[-1]
+        assert isinstance(x, list) and len(x) == 3
+        _, cls_token, dist_token = x
+        cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
+        losses = self.loss(cls_score, gt_label)
+        return losses
diff --git a/model-index.yml b/model-index.yml
index f8b8b4d5567..a17a5e08cb4 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -14,3 +14,4 @@ Import:
   - configs/t2t_vit/metafile.yml
   - configs/mlp_mixer/metafile.yml
   - configs/conformer/metafile.yml
+  - configs/deit/metafile.yml
diff --git a/tests/test_models/test_backbones/test_deit.py b/tests/test_models/test_backbones/test_deit.py
new file mode 100644
index 00000000000..af9efe9c0b8
--- /dev/null
+++ b/tests/test_models/test_backbones/test_deit.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.backbones import DistilledVisionTransformer
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_deit_backbone():
+    cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16)
+
+    # Test structure
+    model = DistilledVisionTransformer(**cfg_ori)
+    model.init_weights()
+    model.train()
+
+    assert check_norm_state(model.modules(), True)
+    assert model.dist_token.shape == (1, 1, 768)
+    assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2,
+                                     768)
+
+    # Test forward
+    imgs = torch.rand(1, 3, 224, 224)
+    outs = model(imgs)
+    patch_token, cls_token, dist_token = outs[0]
+    assert patch_token.shape == (1, 768, 14, 14)
+    assert cls_token.shape == (1, 768)
+    assert dist_token.shape == (1, 768)
+
+    # Test multiple out_indices
+    model = DistilledVisionTransformer(
+        **cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False)
+    outs = model(imgs)
+    for out in outs:
+        assert out.shape == (1, 768, 14, 14)
diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py
index 4d2105698f5..c56d509ea06 100644
--- a/tests/test_models/test_heads.py
+++ b/tests/test_models/test_heads.py
@@ -4,9 +4,9 @@
 import pytest
 import torch
 
-from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
-                                MultiLabelLinearClsHead, StackedLinearClsHead,
-                                VisionTransformerClsHead)
+from mmcls.models.heads import (ClsHead, DeiTClsHead, LinearClsHead,
+                                MultiLabelClsHead, MultiLabelLinearClsHead,
+                                StackedLinearClsHead, VisionTransformerClsHead)
 
 
 @pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
@@ -157,3 +157,44 @@ def test_vit_head():
     # test assertion
     with pytest.raises(ValueError):
         VisionTransformerClsHead(-1, 100)
+
+
+def test_deit_head():
+    fake_features = ([
+        torch.rand(4, 7, 7, 16),
+        torch.rand(4, 100),
+        torch.rand(4, 100)
+    ], )
+    fake_gt_label = torch.randint(0, 10, (4, ))
+
+    # test deit head forward
+    head = DeiTClsHead(num_classes=10, in_channels=100)
+    losses = head.forward_train(fake_features, fake_gt_label)
+    assert not hasattr(head.layers, 'pre_logits')
+    assert not hasattr(head.layers, 'act')
+    assert losses['loss'].item() > 0
+
+    # test deit head forward with hidden layer
+    head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
+    losses = head.forward_train(fake_features, fake_gt_label)
+    assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
+    assert losses['loss'].item() > 0
+
+    # test deit head init_weights
+    head = DeiTClsHead(10, 100, hidden_dim=20)
+    head.init_weights()
+    assert abs(head.layers.pre_logits.weight).sum() > 0
+
+    # test simple_test
+    head = DeiTClsHead(10, 100, hidden_dim=20)
+    pred = head.simple_test(fake_features)
+    assert isinstance(pred, list) and len(pred) == 4
+
+    with patch('torch.onnx.is_in_onnx_export', return_value=True):
+        head = DeiTClsHead(10, 100, hidden_dim=20)
+        pred = head.simple_test(fake_features)
+        assert pred.shape == (4, 10)
+
+    # test assertion
+    with pytest.raises(ValueError):
+        DeiTClsHead(-1, 100)
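
To see how the new backbone and head fit together, here is a minimal sketch mirroring the shapes exercised in the tests above. The 1000-class head follows the configs in this PR; everything else uses the defaults, so treat it as an illustration rather than a canonical usage example.

```python
import torch

from mmcls.models.backbones import DistilledVisionTransformer
from mmcls.models.heads import DeiTClsHead

# The backbone emits (patch_token, cls_token, dist_token) per out_index.
backbone = DistilledVisionTransformer(arch='deit-b', img_size=224, patch_size=16)
backbone.init_weights()
head = DeiTClsHead(num_classes=1000, in_channels=768)

imgs = torch.rand(2, 3, 224, 224)
feats = backbone(imgs)
patch_token, cls_token, dist_token = feats[-1]
assert patch_token.shape == (2, 768, 14, 14)   # 224 / 16 = 14 patches per side
assert cls_token.shape == (2, 768) and dist_token.shape == (2, 768)

# At test time the head averages the cls-token and dist-token logits.
preds = head.simple_test(feats)  # list of per-image class-probability arrays
```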