From 9cbf6a64ba3f670d7ba69e0fef0ed2ce9cea3960 Mon Sep 17 00:00:00 2001 From: Galina Date: Fri, 10 Nov 2023 15:19:54 +0200 Subject: [PATCH] Add ViT feature vector hook --- .../classification/adapters/mmcls/task.py | 22 ++++------- .../mmcv/hooks/recording_forward_hook.py | 12 +++++- .../cli/classification/test_classification.py | 14 ------- .../cli/classification/test_classification.py | 38 ------------------- 4 files changed, 19 insertions(+), 67 deletions(-) diff --git a/src/otx/algorithms/classification/adapters/mmcls/task.py b/src/otx/algorithms/classification/adapters/mmcls/task.py index 55b7ea0dd3a..e42e4d5f628 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/task.py +++ b/src/otx/algorithms/classification/adapters/mmcls/task.py @@ -19,7 +19,6 @@ from mmcv.runner import wrap_fp16_model from mmcv.utils import Config, ConfigDict -from otx.algorithms import TRANSFORMER_BACKBONES from otx.algorithms.classification.adapters.mmcls.utils.exporter import ( ClassificationExporter, ) @@ -31,6 +30,7 @@ EigenCamHook, FeatureVectorHook, ReciproCAMHook, + ViTFeatureVectorHook, ViTReciproCAMHook, ) from otx.algorithms.common.adapters.mmcv.utils import ( @@ -225,7 +225,6 @@ def _infer_model( ) ) - dump_features = True dump_saliency_map = not inference_parameters.is_evaluation if inference_parameters else True self._init_task() @@ -274,16 +273,16 @@ def hook(module, inp, outp): forward_explainer_hook: Union[nullcontext, BaseRecordingForwardHook] if model_type == "VisionTransformer": forward_explainer_hook = ViTReciproCAMHook(feature_model) - elif ( - not dump_saliency_map or model_type in TRANSFORMER_BACKBONES - ): # TODO: remove latter "or" condition after resolving Issue#2098 + elif not dump_saliency_map: forward_explainer_hook = nullcontext() else: forward_explainer_hook = ReciproCAMHook(feature_model) - if ( - not dump_features or model_type in TRANSFORMER_BACKBONES - ): # TODO: remove latter "or" condition after resolving Issue#2098 - feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext() + + feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook] + if model_type == "VisionTransformer": + feature_vector_hook = ViTFeatureVectorHook(feature_model) + elif not dump_saliency_map: + feature_vector_hook = nullcontext() else: feature_vector_hook = FeatureVectorHook(feature_model) @@ -533,11 +532,6 @@ def _export_model(self, precision: ModelPrecision, export_format: ExportType, du export_options["precision"] = str(precision) export_options["type"] = str(export_format) - # [TODO] Enable dump_features for ViT backbones - model_type = cfg.model.backbone.type.split(".")[-1] # mmcls.VisionTransformer => VisionTransformer - if model_type in TRANSFORMER_BACKBONES: - dump_features = False - export_options["deploy_cfg"]["dump_features"] = dump_features if dump_features: output_names = export_options["deploy_cfg"]["ir_config"]["output_names"] diff --git a/src/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py b/src/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py index 062cc230367..f71440f0fef 100644 --- a/src/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py +++ b/src/otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py @@ -16,7 +16,7 @@ from __future__ import annotations from abc import ABC -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -172,6 +172,16 @@ def func(feature_map: Union[torch.Tensor, Sequence[torch.Tensor]], fpn_idx: int return feature_vector +class ViTFeatureVectorHook(BaseRecordingForwardHook): + """FeatureVectorHook for transformer-based classifiers.""" + + @staticmethod + def func(features: Tuple[List[torch.Tensor]], fpn_idx: int = -1) -> torch.Tensor: + """Generate the feature vector for transformer-based classifiers by returning the cls token.""" + _, cls_token = features[0] + return cls_token + + class ReciproCAMHook(BaseRecordingForwardHook): """Implementation of recipro-cam for class-wise saliency map. diff --git a/tests/e2e/cli/classification/test_classification.py b/tests/e2e/cli/classification/test_classification.py index 3252c596bcd..9f075925119 100644 --- a/tests/e2e/cli/classification/test_classification.py +++ b/tests/e2e/cli/classification/test_classification.py @@ -137,8 +137,6 @@ def test_otx_resume(self, template, tmp_dir_path): @pytest.mark.parametrize("template", templates, ids=templates_ids) @pytest.mark.parametrize("dump_features", [True, False]) def test_otx_export(self, template, tmp_dir_path, dump_features): - if template.name == "DeiT-Tiny" and dump_features: - pytest.skip(reason="Issue#2098 ViT template does not support dump_features.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_export_testing(template, tmp_dir_path, dump_features) @@ -160,8 +158,6 @@ def test_otx_eval(self, template, tmp_dir_path): @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_testing(template, tmp_dir_path, otx_dir, args) @@ -169,8 +165,6 @@ def test_otx_explain(self, template, tmp_dir_path): @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args) @@ -383,8 +377,6 @@ def test_otx_resume(self, template, tmp_dir_path): @pytest.mark.parametrize("template", templates, ids=templates_ids) @pytest.mark.parametrize("dump_features", [True, False]) def test_otx_export(self, template, tmp_dir_path, dump_features): - if template.name == "DeiT-Tiny" and dump_features: - pytest.skip(reason="Issue#2098 ViT template does not support dump_features.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_export_testing(template, tmp_dir_path, dump_features) @@ -399,8 +391,6 @@ def test_otx_eval(self, template, tmp_dir_path): @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_explain_testing(template, tmp_dir_path, otx_dir, args_m) @@ -546,8 +536,6 @@ def test_otx_resume(self, template, tmp_dir_path): @pytest.mark.parametrize("template", templates, ids=templates_ids) @pytest.mark.parametrize("dump_features", [True, False]) def test_otx_export(self, template, tmp_dir_path, dump_features): - if template.name == "DeiT-Tiny" and dump_features: - pytest.skip(reason="Issue#2098 ViT template does not support dump_features.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_export_testing(template, tmp_dir_path, dump_features) @@ -562,8 +550,6 @@ def test_otx_eval(self, template, tmp_dir_path): @pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS") @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_explain_testing(template, tmp_dir_path, otx_dir, args_h) diff --git a/tests/integration/cli/classification/test_classification.py b/tests/integration/cli/classification/test_classification.py index 9e927d8bfbb..186613de79f 100644 --- a/tests/integration/cli/classification/test_classification.py +++ b/tests/integration/cli/classification/test_classification.py @@ -124,8 +124,6 @@ def test_otx_resume(self, template, tmp_dir_path): @pytest.mark.parametrize("template", templates, ids=templates_ids) @pytest.mark.parametrize("dump_features", [True, False]) def test_otx_export(self, template, tmp_dir_path, dump_features): - if template.name == "DeiT-Tiny" and dump_features: - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_export_testing(template, tmp_dir_path, dump_features, check_ir_meta=True) @@ -150,48 +148,36 @@ def test_otx_eval(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain_all_classes(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_testing_all_classes(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain_process_saliency_maps(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_testing_process_saliency_maps(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain_all_classes_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_all_classes_openvino_testing(template, tmp_dir_path, otx_dir, args) @e2e_pytest_component @pytest.mark.parametrize("template", templates, ids=templates_ids) def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_class_cls" otx_explain_process_saliency_maps_openvino_testing(template, tmp_dir_path, otx_dir, args) @@ -365,48 +351,36 @@ def test_otx_eval(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_explain_testing(template, tmp_dir_path, otx_dir, args_m) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_all_classes(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_explain_testing_all_classes(template, tmp_dir_path, otx_dir, args_m) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_process_saliency_maps(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_explain_testing_process_saliency_maps(template, tmp_dir_path, otx_dir, args_m) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args_m) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_all_classes_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_explain_all_classes_openvino_testing(template, tmp_dir_path, otx_dir, args_m) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "multi_label_cls" otx_explain_process_saliency_maps_openvino_testing(template, tmp_dir_path, otx_dir, args_m) @@ -502,48 +476,36 @@ def test_otx_eval(self, template, tmp_dir_path): @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_explain_testing(template, tmp_dir_path, otx_dir, args_h) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_all_classes(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_explain_testing_all_classes(template, tmp_dir_path, otx_dir, args_h) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_process_saliency_maps(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_explain_testing_process_saliency_maps(template, tmp_dir_path, otx_dir, args_h) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_explain_openvino_testing(template, tmp_dir_path, otx_dir, args_h) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_all_classes_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_explain_all_classes_openvino_testing(template, tmp_dir_path, otx_dir, args_h) @e2e_pytest_component @pytest.mark.parametrize("template", default_templates, ids=default_templates_ids) def test_otx_explain_process_saliency_maps_openvino(self, template, tmp_dir_path): - if template.name == "DeiT-Tiny": - pytest.skip(reason="Issue#2098 ViT inference does not work by FeatureVectorHook.") tmp_dir_path = tmp_dir_path / "h_label_cls" otx_explain_process_saliency_maps_openvino_testing(template, tmp_dir_path, otx_dir, args_h)