From 948881e51cc12ac2103c34f0ffbde0055ac35a12 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Fri, 25 Aug 2023 12:24:29 +0200 Subject: [PATCH] Fix OpenVINO inference for legacy models (#2450) * bug fix for legacy openvino models * Add tests * Specific exceptions --------- --- src/otx/algorithms/anomaly/tasks/openvino.py | 42 ++++++++++-- src/otx/cli/utils/io.py | 4 ++ .../algorithms/anomaly/tasks/test_openvino.py | 66 ++++++++++++++++++- 3 files changed, 106 insertions(+), 6 deletions(-) diff --git a/src/otx/algorithms/anomaly/tasks/openvino.py b/src/otx/algorithms/anomaly/tasks/openvino.py index cc65ef74294..7859cfbfb36 100644 --- a/src/otx/algorithms/anomaly/tasks/openvino.py +++ b/src/otx/algorithms/anomaly/tasks/openvino.py @@ -26,6 +26,7 @@ import numpy as np import openvino.runtime as ov from addict import Dict as ADDict +from anomalib.data.utils.transform import get_transforms from anomalib.deploy import OpenVINOInferencer from nncf.common.quantization.structs import QuantizationPreset from omegaconf import OmegaConf @@ -216,16 +217,47 @@ def get_metadata(self) -> Dict: """Get Meta Data.""" metadata = {} if self.task_environment.model is not None: - metadata = json.loads(self.task_environment.model.get_data("metadata").decode()) - metadata["image_threshold"] = np.array(metadata["image_threshold"], dtype=np.float32).item() - metadata["pixel_threshold"] = np.array(metadata["pixel_threshold"], dtype=np.float32).item() - metadata["min"] = np.array(metadata["min"], dtype=np.float32).item() - metadata["max"] = np.array(metadata["max"], dtype=np.float32).item() + try: + metadata = json.loads(self.task_environment.model.get_data("metadata").decode()) + self._populate_metadata(metadata) + logger.info("Metadata loaded from model v1.4.") + except (KeyError, json.decoder.JSONDecodeError): + # model is from version 1.2.x + metadata = self._populate_metadata_legacy(self.task_environment.model) + logger.info("Metadata loaded from model v1.2.x.") else: raise ValueError("Cannot access meta-data. self.task_environment.model is empty.") return metadata + def _populate_metadata_legacy(self, model: ModelEntity) -> Dict[str, Any]: + """Populates metadata for models for version 1.2.x.""" + image_threshold = np.frombuffer(model.get_data("image_threshold"), dtype=np.float32) + pixel_threshold = np.frombuffer(model.get_data("pixel_threshold"), dtype=np.float32) + min_value = np.frombuffer(model.get_data("min"), dtype=np.float32) + max_value = np.frombuffer(model.get_data("max"), dtype=np.float32) + transform = get_transforms( + config=self.config.dataset.transform_config.train, + image_size=tuple(self.config.dataset.image_size), + to_tensor=True, + ) + metadata = { + "transform": transform.to_dict(), + "image_threshold": image_threshold, + "pixel_threshold": pixel_threshold, + "min": min_value, + "max": max_value, + "task": str(self.task_type).lower().split("_")[-1], + } + return metadata + + def _populate_metadata(self, metadata: Dict[str, Any]): + """Populates metadata for models from version 1.4 onwards.""" + metadata["image_threshold"] = np.array(metadata["image_threshold"], dtype=np.float32).item() + metadata["pixel_threshold"] = np.array(metadata["pixel_threshold"], dtype=np.float32).item() + metadata["min"] = np.array(metadata["min"], dtype=np.float32).item() + metadata["max"] = np.array(metadata["max"], dtype=np.float32).item() + def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optional[str] = None): """Evaluate the performance of the model. diff --git a/src/otx/cli/utils/io.py b/src/otx/cli/utils/io.py index 73941eb1a6c..3770fb279bf 100644 --- a/src/otx/cli/utils/io.py +++ b/src/otx/cli/utils/io.py @@ -51,6 +51,10 @@ "visual_prompting_image_encoder.bin", "visual_prompting_decoder.xml", "visual_prompting_decoder.bin", + "image_threshold", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded. + "pixel_threshold", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded. + "min", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded. + "max", # NOTE: used for compatibility with with OTX 1.2.x. Remove when all Geti projects are upgraded. ) diff --git a/tests/unit/algorithms/anomaly/tasks/test_openvino.py b/tests/unit/algorithms/anomaly/tasks/test_openvino.py index 58b6b2a8450..82d1174bb97 100644 --- a/tests/unit/algorithms/anomaly/tasks/test_openvino.py +++ b/tests/unit/algorithms/anomaly/tasks/test_openvino.py @@ -3,27 +3,40 @@ # Copyright (C) 2021-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import pytest +import json from copy import deepcopy +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock, patch import numpy as np +import pytest from otx.algorithms.anomaly.tasks.openvino import OpenVINOTask from otx.algorithms.anomaly.tasks.train import TrainingTask from otx.api.entities.datasets import DatasetEntity from otx.api.entities.inference_parameters import InferenceParameters +from otx.api.entities.label import Domain, LabelEntity +from otx.api.entities.label_schema import LabelSchemaEntity from otx.api.entities.model import ModelEntity, ModelOptimizationType from otx.api.entities.model_template import TaskType from otx.api.entities.optimization_parameters import OptimizationParameters from otx.api.entities.resultset import ResultSetEntity from otx.api.entities.subset import Subset +from otx.api.entities.task_environment import TaskEnvironment from otx.api.usecases.tasks.interfaces.export_interface import ExportType from otx.api.usecases.tasks.interfaces.optimization_interface import OptimizationType +from otx.cli.utils.io import read_model class TestOpenVINOTask: """Tests methods in the OpenVINO task.""" + @pytest.fixture + def tmp_dir(self): + with TemporaryDirectory() as tmp_dir: + yield tmp_dir + def set_normalization_params(self, output_model: ModelEntity): """Sets normalization parameters for an untrained output model. @@ -77,3 +90,54 @@ def test_openvino(self, tmpdir, setup_task_environment): # deploy openvino_task.deploy(output_model) assert output_model.exportable_code is not None + + @patch.multiple(OpenVINOTask, get_config=MagicMock(), load_inferencer=MagicMock()) + @patch("otx.algorithms.anomaly.tasks.openvino.get_transforms", MagicMock()) + def test_anomaly_legacy_keys(self, mocker, tmp_dir): + """Checks whether the model is loaded correctly with legacy and current keys.""" + + tmp_dir = Path(tmp_dir) + xml_model_path = tmp_dir / "model.xml" + xml_model_path.write_text("xml_model") + bin_model_path = tmp_dir / "model.bin" + bin_model_path.write_text("bin_model") + + # Test loading legacy keys + legacy_keys = ("image_threshold", "pixel_threshold", "min", "max") + for key in legacy_keys: + (tmp_dir / key).write_bytes(np.zeros(1, dtype=np.float32).tobytes()) + + model = read_model(mocker.MagicMock(), str(xml_model_path), mocker.MagicMock()) + task_environment = TaskEnvironment( + model_template=mocker.MagicMock(), + model=model, + hyper_parameters=mocker.MagicMock(), + label_schema=LabelSchemaEntity.from_labels( + [ + LabelEntity("Anomalous", is_anomalous=True, domain=Domain.ANOMALY_SEGMENTATION), + LabelEntity("Normal", domain=Domain.ANOMALY_SEGMENTATION), + ] + ), + ) + openvino_task = OpenVINOTask(task_environment) + metadata = openvino_task.get_metadata() + for key in legacy_keys: + assert metadata[key] == np.zeros(1, dtype=np.float32) + + # cleanup legacy keys + for key in legacy_keys: + (tmp_dir / key).unlink() + + # Test loading new keys + new_metadata = { + "image_threshold": np.zeros(1, dtype=np.float32).tolist(), + "pixel_threshold": np.zeros(1, dtype=np.float32).tolist(), + "min": np.zeros(1, dtype=np.float32).tolist(), + "max": np.zeros(1, dtype=np.float32).tolist(), + } + (tmp_dir / "metadata").write_bytes(json.dumps(new_metadata).encode()) + task_environment.model = read_model(mocker.MagicMock(), str(xml_model_path), mocker.MagicMock()) + openvino_task = OpenVINOTask(task_environment) + metadata = openvino_task.get_metadata() + for key in new_metadata.keys(): + assert metadata[key] == np.zeros(1, dtype=np.float32)