Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mmcls.VisionTransformer backbone support #1908

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions otx/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

TRANSFORMER_BACKBONES = ["VisionTransformer", "T2T_ViT", "Conformer"]
2 changes: 1 addition & 1 deletion otx/algorithms/classification/configs/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ learning_parameters:
stable. A larger batch size has higher memory requirements.
editable: true
header: Batch size
max_value: 512
max_value: 2048
harimkang marked this conversation as resolved.
Show resolved Hide resolved
min_value: 1
type: INTEGER
ui_rules:
Expand Down
2 changes: 1 addition & 1 deletion otx/algorithms/common/configs/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class BaseLearningParameters(ParameterGroup):
batch_size = configurable_integer(
default_value=5,
min_value=1,
max_value=512,
max_value=2048,
header="Batch size",
description="The number of training samples seen in each iteration of training. Increasing thisvalue "
"improves training time and may make the training more stable. A larger batch size has higher "
Expand Down
11 changes: 9 additions & 2 deletions otx/cli/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from mmcv.utils import Registry, build_from_cfg
from torch import nn

from otx.algorithms import TRANSFORMER_BACKBONES
from otx.api.entities.model_template import TaskType
from otx.cli.utils.importing import (
get_backbone_list,
Expand Down Expand Up @@ -101,8 +102,8 @@ def update_backbone_args(backbone_config: dict, registry: Registry, backend: str

def update_channels(model_config: MPAConfig, out_channels: Any):
"""Update in_channel of head or neck."""
if hasattr(model_config.model, "neck"):
if model_config.model.neck.type == "GlobalAveragePooling":
if hasattr(model_config.model, "neck") and model_config.model.neck:
if model_config.model.neck.get("type", None) == "GlobalAveragePooling":
model_config.model.neck.pop("in_channels", None)
else:
print(f"\tUpdate model.neck.in_channels: {out_channels}")
Expand Down Expand Up @@ -212,6 +213,12 @@ def merge_backbone(
out_channels = -1
if hasattr(model_config.model, "head"):
model_config.model.head.in_channels = -1
# TODO: This is a hard coded part of the Transformer backbone and needs to be refactored.
if backend == "mmcls" and backbone_class in TRANSFORMER_BACKBONES:
if hasattr(model_config.model, "neck"):
model_config.model.neck = None
if hasattr(model_config.model, "head"):
model_config.model.head["type"] = "VisionTransformerClsHead"
sungchul2 marked this conversation as resolved.
Show resolved Hide resolved
else:
# Need to update in/out channel configuration here
out_channels = get_backbone_out_channels(backbone)
Expand Down
6 changes: 3 additions & 3 deletions otx/cli/builder/supported_backbone/mmcls.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"options": {
"arch": ["tiny", "small", "base"]
},
"available": []
"available": ["CLASSIFICATION"]
},
"mmcls.ConvMixer": {
"required": ["arch"],
Expand Down Expand Up @@ -287,7 +287,7 @@
"mmcls.T2T_ViT": {
"required": [],
"options": {},
"available": []
"available": ["CLASSIFICATION"]
},
"mmcls.TIMMBackbone": {
"required": ["model_name"],
Expand Down Expand Up @@ -341,7 +341,7 @@
"deit-base"
]
},
"available": []
"available": ["CLASSIFICATION"]
}
}
}
5 changes: 5 additions & 0 deletions otx/mpa/cls/inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mmcls.datasets import build_dataset as mmcls_build_dataset
from mmcv import Config, ConfigDict

from otx.algorithms import TRANSFORMER_BACKBONES
from otx.algorithms.common.adapters.mmcv.utils import (
build_data_parallel,
build_dataloader,
Expand Down Expand Up @@ -53,6 +54,10 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
model_builder = kwargs.get("model_builder", None)
dump_features = kwargs.get("dump_features", False)
dump_saliency_map = kwargs.get("dump_saliency_map", False)
# TODO: It looks like we need to modify that code in an appropriate way.
if model_cfg.model.head.get("type", None) == "VisionTransformerClsHead":
harimkang marked this conversation as resolved.
Show resolved Hide resolved
dump_features = False
dump_saliency_map = False
eval = kwargs.get("eval", False)
outputs = self.infer(
cfg,
Expand Down
8 changes: 8 additions & 0 deletions otx/mpa/cls/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from mmcv import ConfigDict, build_from_cfg

from otx.algorithms import TRANSFORMER_BACKBONES
from otx.algorithms.classification.adapters.mmcls.utils.builder import build_classifier
from otx.mpa.stage import Stage
from otx.mpa.utils.config_utils import recursively_update_cfg, update_or_add_custom_hook
Expand Down Expand Up @@ -89,6 +90,13 @@ def configure_in_channel(cfg):
output = layer(torch.rand([1] + list(input_shape)))
if isinstance(output, (tuple, list)):
output = output[-1]

if layer.__class__.__name__ in TRANSFORMER_BACKBONES and isinstance(output, (tuple, list)):
# mmcls.VisionTransformer outputs Tuple[List[...]] and the last index of List is the final logit.
_, output = output
if cfg.model.head.type != "VisionTransformerClsHead":
raise ValueError(f"{layer.__class__.__name__ } needs VisionTransformerClsHead as head")

in_channels = output.shape[1]
if cfg.model.get("neck") is not None:
if cfg.model.neck.get("in_channels") is not None:
Expand Down
1 change: 1 addition & 0 deletions otx/mpa/modules/hooks/recording_forward_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch

from otx import MMCLS_AVAILABLE
from otx.algorithms import TRANSFORMER_BACKBONES

if MMCLS_AVAILABLE:
from mmcls.models.necks.gap import GlobalAveragePooling
Expand Down