From 6880bb29278b6b5ac738c8359264521e78f54697 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Thu, 7 Nov 2024 22:41:17 +0900 Subject: [PATCH 1/8] fix linter --- src/otx/core/data/dataset/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index c6ef710771..fa35e3a431 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -142,7 +142,11 @@ def __getitem__(self, index: int) -> T_OTXDataEntity: msg = f"Reach the maximum refetch number ({self.max_refetch})" raise RuntimeError(msg) - def _get_img_data_and_shape(self, img: Image, roi: dict[str, Any] | None) -> tuple[np.ndarray, tuple[int, int]]: + def _get_img_data_and_shape( + self, + img: Image, + roi: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, tuple[int, int]]: key = img.path if isinstance(img, ImageFromFile) else id(img) if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None: From a2afe948bc0e5105fc5c39a72fb13e7d9d800f19 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Thu, 7 Nov 2024 22:45:34 +0900 Subject: [PATCH 2/8] return recipe back --- src/otx/recipe/_base_/data/classification.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/otx/recipe/_base_/data/classification.yaml b/src/otx/recipe/_base_/data/classification.yaml index 944de074fe..e8ee41bf15 100644 --- a/src/otx/recipe/_base_/data/classification.yaml +++ b/src/otx/recipe/_base_/data/classification.yaml @@ -12,7 +12,7 @@ train_subset: subset_name: train transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 0 + num_workers: 2 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop @@ -37,7 +37,7 @@ val_subset: subset_name: val transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 0 + num_workers: 2 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize @@ -59,7 +59,7 @@ test_subset: subset_name: test transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 0 + num_workers: 2 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize From 1acc2f0755b655b3e52f629b26333dc4da762771 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Thu, 7 Nov 2024 23:01:27 +0900 Subject: [PATCH 3/8] added roi extraction for multi cllass classification datasett --- src/otx/core/data/dataset/base.py | 17 +++++++++++-- src/otx/core/data/dataset/classification.py | 24 +++++++++---------- .../recipe/_base_/data/classification.yaml | 6 ++--- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index a98f7c6083..c6ef710771 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -8,7 +8,7 @@ from abc import abstractmethod from collections.abc import Iterable from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Generic, Iterator, List, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, List, Union import cv2 import numpy as np @@ -92,6 +92,7 @@ def __init__( self.image_color_channel = image_color_channel self.stack_images = stack_images self.to_tv_image = to_tv_image + if self.dm_subset.categories(): self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label]) else: @@ -141,7 +142,7 @@ def __getitem__(self, index: int) -> T_OTXDataEntity: msg = f"Reach the maximum refetch number ({self.max_refetch})" raise RuntimeError(msg) - def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]: + def _get_img_data_and_shape(self, img: Image, roi: dict[str, Any] | None) -> tuple[np.ndarray, tuple[int, int]]: key = img.path if isinstance(img, ImageFromFile) else id(img) if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None: @@ -158,6 +159,18 @@ def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, in msg = "Cannot get image data" raise RuntimeError(msg) + if roi: + # extract ROI from image + shape = roi["shape"] + h, w = img_data.shape[:2] + x1, y1, x2, y2 = ( + np.trunc(shape["x1"] * w), + np.trunc(shape["y1"] * h), + np.ceil(shape["x2"] * w), + np.ceil(shape["y2"] * h), + ) + img_data = img_data[int(y1) : int(y2), int(x1) : int(x2)] + img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8)) return img_data, img_data.shape[:2] diff --git a/src/otx/core/data/dataset/classification.py b/src/otx/core/data/dataset/classification.py index c5048dd798..8ba08c9d86 100644 --- a/src/otx/core/data/dataset/classification.py +++ b/src/otx/core/data/dataset/classification.py @@ -32,18 +32,18 @@ class OTXMulticlassClsDataset(OTXDataset[MulticlassClsDataEntity]): def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) - img_data, img_shape = self._get_img_data_and_shape(img) + roi = item.attributes.get("roi", None) + img_data, img_shape = self._get_img_data_and_shape(img, roi) + if roi: + # extract labels from ROI + labels_ids = [ + label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION" + ] + label_anns = [self.label_info.label_names.index(label_id) for label_id in labels_ids] + else: + # extract labels from annotations + label_anns = [ann.label for ann in item.annotations if isinstance(ann, Label)] - label_anns = [] - for ann in item.annotations: - if isinstance(ann, Label): - label_anns.append(ann) - else: - # If the annotation is not Label, it should be converted to Label. - # For Chained Task: Detection (Bbox) -> Classification (Label) - label = Label(label=ann.label) - if label not in label_anns: - label_anns.append(label) if len(label_anns) > 1: msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}" raise ValueError(msg) @@ -56,7 +56,7 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None: ori_shape=img_shape, image_color_channel=self.image_color_channel, ), - labels=torch.as_tensor([ann.label for ann in label_anns]), + labels=torch.as_tensor(label_anns), ) return self._apply_transforms(entity) diff --git a/src/otx/recipe/_base_/data/classification.yaml b/src/otx/recipe/_base_/data/classification.yaml index e8ee41bf15..944de074fe 100644 --- a/src/otx/recipe/_base_/data/classification.yaml +++ b/src/otx/recipe/_base_/data/classification.yaml @@ -12,7 +12,7 @@ train_subset: subset_name: train transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 + num_workers: 0 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop @@ -37,7 +37,7 @@ val_subset: subset_name: val transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 + num_workers: 0 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize @@ -59,7 +59,7 @@ test_subset: subset_name: test transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 2 + num_workers: 0 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize From accf03935bf2a033f8bac8265e51cf10a57f9509 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Fri, 8 Nov 2024 01:13:18 +0900 Subject: [PATCH 4/8] fix linter --- src/otx/algo/utils/xai_utils.py | 2 +- src/otx/core/data/dataset/base.py | 4 ++-- src/otx/core/data/dataset/segmentation.py | 4 ++-- src/otx/core/data/dataset/tile.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/otx/algo/utils/xai_utils.py b/src/otx/algo/utils/xai_utils.py index 434d2612cf..210d6aad0d 100644 --- a/src/otx/algo/utils/xai_utils.py +++ b/src/otx/algo/utils/xai_utils.py @@ -225,7 +225,7 @@ def _get_image_data_name( subset = datamodule.subsets[subset_name] item = subset.dm_subset[img_id] img = item.media_as(Image) - img_data, _ = subset._get_img_data_and_shape(img) # noqa: SLF001 + img_data, _, _ = subset._get_img_data_and_shape(img) # noqa: SLF001 image_save_name = "".join([char if char.isalnum() else "_" for char in item.id]) return img_data, image_save_name diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index 101e8a8e45..93a84a087e 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -146,7 +146,7 @@ def _get_img_data_and_shape( self, img: Image, roi: dict[str, Any] | None = None, - ) -> tuple[np.ndarray, tuple[int, int]]: + ) -> tuple[np.ndarray, tuple[int, int], dict[str, Any] | None]: key = img.path if isinstance(img, ImageFromFile) else id(img) roi_meta = None @@ -176,7 +176,7 @@ def _get_img_data_and_shape( int(np.ceil(shape["x2"] * w)), int(np.ceil(shape["y2"] * h)), ) - img_data = img_data[y1 : y2, x1 : x2] + img_data = img_data[y1:y2, x1:x2] roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)} img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta) diff --git a/src/otx/core/data/dataset/segmentation.py b/src/otx/core/data/dataset/segmentation.py index be9f4dd419..a690dde42a 100644 --- a/src/otx/core/data/dataset/segmentation.py +++ b/src/otx/core/data/dataset/segmentation.py @@ -205,9 +205,9 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None: roi = item.attributes.get("roi", None) img_data, img_shape, roi_meta = self._get_img_data_and_shape(img, roi) if item.annotations: - ori_shape = roi_meta["orig_image_shape"] if roi else img_shape + ori_shape = roi_meta["orig_image_shape"] if roi_meta else img_shape extracted_mask = _extract_class_mask(item=item, img_shape=ori_shape, ignore_index=self.ignore_index) - if roi: + if roi_meta: extracted_mask = extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]] masks = tv_tensors.Mask(extracted_mask[None]) diff --git a/src/otx/core/data/dataset/tile.py b/src/otx/core/data/dataset/tile.py index 63a990380d..a729ddc186 100644 --- a/src/otx/core/data/dataset/tile.py +++ b/src/otx/core/data/dataset/tile.py @@ -461,7 +461,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o """ item = self.dm_subset[index] img = item.media_as(Image) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], [] From 9fb6c73eaf4215466bbe888be7214e074dfa9300 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Fri, 8 Nov 2024 01:33:06 +0900 Subject: [PATCH 5/8] add same logic to semantic seg --- src/otx/core/data/dataset/anomaly.py | 2 +- src/otx/core/data/dataset/base.py | 32 +++++++++++-------- src/otx/core/data/dataset/classification.py | 6 ++-- src/otx/core/data/dataset/detection.py | 2 +- .../data/dataset/instance_segmentation.py | 2 +- .../core/data/dataset/keypoint_detection.py | 2 +- src/otx/core/data/dataset/segmentation.py | 9 ++++-- src/otx/core/data/dataset/tile.py | 2 +- src/otx/core/data/dataset/visual_prompting.py | 4 +-- .../recipe/_base_/data/classification.yaml | 6 ++-- 10 files changed, 38 insertions(+), 29 deletions(-) diff --git a/src/otx/core/data/dataset/anomaly.py b/src/otx/core/data/dataset/anomaly.py index ec9b59ce49..0f855f5b3d 100644 --- a/src/otx/core/data/dataset/anomaly.py +++ b/src/otx/core/data/dataset/anomaly.py @@ -79,7 +79,7 @@ def _get_item_impl( datumaro_item = self.dm_subset[index] img = datumaro_item.media_as(Image) # returns image in RGB format if self.image_color_channel is RGB - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) label = self._get_label(datumaro_item) diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index c6ef710771..189af3eddc 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -142,11 +142,14 @@ def __getitem__(self, index: int) -> T_OTXDataEntity: msg = f"Reach the maximum refetch number ({self.max_refetch})" raise RuntimeError(msg) - def _get_img_data_and_shape(self, img: Image, roi: dict[str, Any] | None) -> tuple[np.ndarray, tuple[int, int]]: + def _get_img_data_and_shape(self, img: Image, roi: dict[str, Any] | None = None) -> tuple[np.ndarray, tuple[int, int]]: key = img.path if isinstance(img, ImageFromFile) else id(img) + roi_meta = None - if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None: - return img_data, img_data.shape[:2] + # check if the image is already in the cache + img_data, roi_meta = self.mem_cache_handler.get(key=key) + if img_data is not None: + return img_data, img_data.shape[:2], roi_meta with image_decode_context(): img_data = ( @@ -164,18 +167,19 @@ def _get_img_data_and_shape(self, img: Image, roi: dict[str, Any] | None) -> tup shape = roi["shape"] h, w = img_data.shape[:2] x1, y1, x2, y2 = ( - np.trunc(shape["x1"] * w), - np.trunc(shape["y1"] * h), - np.ceil(shape["x2"] * w), - np.ceil(shape["y2"] * h), + int(np.trunc(shape["x1"] * w)), + int(np.trunc(shape["y1"] * h)), + int(np.ceil(shape["x2"] * w)), + int(np.ceil(shape["y2"] * h)), ) - img_data = img_data[int(y1) : int(y2), int(x1) : int(x2)] + img_data = img_data[y1 : y2, x1 : x2] + roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)} - img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8)) + img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta) - return img_data, img_data.shape[:2] + return img_data, img_data.shape[:2], roi_meta - def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray: + def _cache_img(self, key: str | int, img_data: np.ndarray, meta: dict[str, Any] | None = None) -> np.ndarray: """Cache an image after resizing. If there is available space in the memory pool, the input image is cached. @@ -195,14 +199,14 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray: return img_data if self.mem_cache_img_max_size is None: - self.mem_cache_handler.put(key=key, data=img_data, meta=None) + self.mem_cache_handler.put(key=key, data=img_data, meta=meta) return img_data height, width = img_data.shape[:2] max_height, max_width = self.mem_cache_img_max_size if height <= max_height and width <= max_width: - self.mem_cache_handler.put(key=key, data=img_data, meta=None) + self.mem_cache_handler.put(key=key, data=img_data, meta=meta) return img_data # Preserve the image size ratio and fit to max_height or max_width @@ -219,7 +223,7 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray: self.mem_cache_handler.put( key=key, data=resized_img, - meta=None, + meta=meta, ) return resized_img diff --git a/src/otx/core/data/dataset/classification.py b/src/otx/core/data/dataset/classification.py index 8ba08c9d86..8f4f5ffc24 100644 --- a/src/otx/core/data/dataset/classification.py +++ b/src/otx/core/data/dataset/classification.py @@ -33,7 +33,7 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) roi = item.attributes.get("roi", None) - img_data, img_shape = self._get_img_data_and_shape(img, roi) + img_data, img_shape, _ = self._get_img_data_and_shape(img, roi) if roi: # extract labels from ROI labels_ids = [ @@ -78,7 +78,7 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) label_anns = [] for ann in item.annotations: @@ -195,7 +195,7 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) label_anns = [] for ann in item.annotations: diff --git a/src/otx/core/data/dataset/detection.py b/src/otx/core/data/dataset/detection.py index 8094638b45..6783fce720 100644 --- a/src/otx/core/data/dataset/detection.py +++ b/src/otx/core/data/dataset/detection.py @@ -26,7 +26,7 @@ def _get_item_impl(self, index: int) -> DetDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] diff --git a/src/otx/core/data/dataset/instance_segmentation.py b/src/otx/core/data/dataset/instance_segmentation.py index 0a3abaeb87..2457e12934 100644 --- a/src/otx/core/data/dataset/instance_segmentation.py +++ b/src/otx/core/data/dataset/instance_segmentation.py @@ -40,7 +40,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], [] diff --git a/src/otx/core/data/dataset/keypoint_detection.py b/src/otx/core/data/dataset/keypoint_detection.py index f0e0d30c37..c74b77c931 100644 --- a/src/otx/core/data/dataset/keypoint_detection.py +++ b/src/otx/core/data/dataset/keypoint_detection.py @@ -86,7 +86,7 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] bboxes = ( diff --git a/src/otx/core/data/dataset/segmentation.py b/src/otx/core/data/dataset/segmentation.py index 53975456b6..be9f4dd419 100644 --- a/src/otx/core/data/dataset/segmentation.py +++ b/src/otx/core/data/dataset/segmentation.py @@ -202,9 +202,14 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] - img_data, img_shape = self._get_img_data_and_shape(img) + roi = item.attributes.get("roi", None) + img_data, img_shape, roi_meta = self._get_img_data_and_shape(img, roi) if item.annotations: - extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index) + ori_shape = roi_meta["orig_image_shape"] if roi else img_shape + extracted_mask = _extract_class_mask(item=item, img_shape=ori_shape, ignore_index=self.ignore_index) + if roi: + extracted_mask = extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]] + masks = tv_tensors.Mask(extracted_mask[None]) else: # semi-supervised learning, unlabeled dataset diff --git a/src/otx/core/data/dataset/tile.py b/src/otx/core/data/dataset/tile.py index a39ea5aa90..63a990380d 100644 --- a/src/otx/core/data/dataset/tile.py +++ b/src/otx/core/data/dataset/tile.py @@ -370,7 +370,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr """ item = self.dm_subset[index] img = item.media_as(Image) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 0047e9350f..8f2ccb620d 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -79,7 +79,7 @@ def __init__( def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: item = self.dm_subset[index] img = item.media_as(dmImage) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_bboxes, gt_points = [], [] gt_masks = defaultdict(list) @@ -214,7 +214,7 @@ def __init__( def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None: item = self.dm_subset[index] img = item.media_as(dmImage) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_prompts: list[tvBoundingBoxes | Points] = [] gt_masks: list[tvMask] = [] diff --git a/src/otx/recipe/_base_/data/classification.yaml b/src/otx/recipe/_base_/data/classification.yaml index 944de074fe..e8ee41bf15 100644 --- a/src/otx/recipe/_base_/data/classification.yaml +++ b/src/otx/recipe/_base_/data/classification.yaml @@ -12,7 +12,7 @@ train_subset: subset_name: train transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 0 + num_workers: 2 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop @@ -37,7 +37,7 @@ val_subset: subset_name: val transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 0 + num_workers: 2 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize @@ -59,7 +59,7 @@ test_subset: subset_name: test transform_lib_type: TORCHVISION batch_size: 64 - num_workers: 0 + num_workers: 2 to_tv_image: false transforms: - class_path: otx.core.data.transform_libs.torchvision.Resize From 97255e3b5ccb5581b15bac882eba749d94da2c6f Mon Sep 17 00:00:00 2001 From: kprokofi Date: Fri, 8 Nov 2024 18:35:55 +0900 Subject: [PATCH 6/8] added test for OTXDataset --- src/otx/core/data/dataset/base.py | 2 +- tests/unit/core/data/dataset/test_base.py | 104 ++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 tests/unit/core/data/dataset/test_base.py diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index 101e8a8e45..e26c1c25fc 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -176,7 +176,7 @@ def _get_img_data_and_shape( int(np.ceil(shape["x2"] * w)), int(np.ceil(shape["y2"] * h)), ) - img_data = img_data[y1 : y2, x1 : x2] + img_data = img_data[y1:y2, x1:x2] roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)} img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta) diff --git a/tests/unit/core/data/dataset/test_base.py b/tests/unit/core/data/dataset/test_base.py new file mode 100644 index 0000000000..47afdf9cf5 --- /dev/null +++ b/tests/unit/core/data/dataset/test_base.py @@ -0,0 +1,104 @@ +from unittest import mock + +import numpy as np +import pytest +from datumaro.components.media import Image +from otx.core.data.dataset.base import OTXDataset + + +class TestOTXDataset: + @pytest.fixture() + def mock_image(self) -> Image: + img = mock.Mock(spec=Image) + img.data = np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8) + img.path = "test_path" + return img + + @pytest.fixture() + def mock_mem_cache_handler(self): + mem_cache_handler = mock.MagicMock() + mem_cache_handler.frozen = False + return mem_cache_handler + + @pytest.fixture() + def otx_dataset(self, mock_mem_cache_handler): + class MockOTXDataset(OTXDataset): + def _get_item_impl(self, idx: int) -> None: + return None + + @property + def collate_fn(self) -> None: + return None + + dm_subset = mock.Mock() + dm_subset.categories = mock.MagicMock() + dm_subset.categories.return_value = None + + return MockOTXDataset( + dm_subset=dm_subset, + transforms=None, + mem_cache_handler=mock_mem_cache_handler, + mem_cache_img_max_size=None, + ) + + def test_get_img_data_and_shape_no_cache(self, otx_dataset, mock_image, mock_mem_cache_handler): + mock_mem_cache_handler.get.return_value = (None, None) + img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image) + assert img_data.shape == (10, 10, 3) + assert img_shape == (10, 10) + assert roi_meta is None + + def test_get_img_data_and_shape_with_cache(self, otx_dataset, mock_image, mock_mem_cache_handler): + mock_mem_cache_handler.get.return_value = (np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8), None) + img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image) + assert img_data.shape == (10, 10, 3) + assert img_shape == (10, 10) + assert roi_meta is None + + def test_get_img_data_and_shape_with_roi(self, otx_dataset, mock_image, mock_mem_cache_handler): + roi = {"shape": {"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}} + mock_mem_cache_handler.get.return_value = (None, None) + img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image, roi) + assert img_data.shape == (8, 8, 3) + assert img_shape == (8, 8) + assert roi_meta == {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)} + + def test_cache_img_no_resize(self, otx_dataset): + img_data = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert np.array_equal(cached_img, img_data) + otx_dataset.mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None) + + def test_cache_img_with_resize(self, otx_dataset, mock_mem_cache_handler): + otx_dataset.mem_cache_img_max_size = (100, 100) + img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert cached_img.shape == (100, 100, 3) + mock_mem_cache_handler.put.assert_called_once() + assert mock_mem_cache_handler.put.call_args[1]["data"].shape == (100, 100, 3) + + def test_cache_img_no_max_size(self, otx_dataset, mock_mem_cache_handler): + otx_dataset.mem_cache_img_max_size = None + img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert np.array_equal(cached_img, img_data) + mock_mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None) + + def test_cache_img_frozen_handler(self, otx_dataset, mock_mem_cache_handler): + mock_mem_cache_handler.frozen = True + img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert np.array_equal(cached_img, img_data) + mock_mem_cache_handler.put.assert_not_called() From f703e596013858fbfc9cc587f25a9b4579f50f9b Mon Sep 17 00:00:00 2001 From: kprokofi Date: Fri, 8 Nov 2024 20:28:54 +0900 Subject: [PATCH 7/8] add clip and raise an error when coordinates are invalid. --- src/otx/core/data/dataset/base.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index 93a84a087e..b57bb72581 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -147,6 +147,19 @@ def _get_img_data_and_shape( img: Image, roi: dict[str, Any] | None = None, ) -> tuple[np.ndarray, tuple[int, int], dict[str, Any] | None]: + """Get image data and shape. + + This method is used to get image data and shape from Datumaro image object. + If ROI is provided, the image data is extracted from the ROI. + + Args: + img (Image): Image object from Datumaro. + roi (dict[str, Any] | None, Optional): Region of interest. + Represented by dict with coordinates and some meta information. + + Returns: + The image data, shape, and ROI meta information + """ key = img.path if isinstance(img, ImageFromFile) else id(img) roi_meta = None @@ -171,11 +184,14 @@ def _get_img_data_and_shape( shape = roi["shape"] h, w = img_data.shape[:2] x1, y1, x2, y2 = ( - int(np.trunc(shape["x1"] * w)), - int(np.trunc(shape["y1"] * h)), - int(np.ceil(shape["x2"] * w)), - int(np.ceil(shape["y2"] * h)), + int(np.clip(np.trunc(shape["x1"] * w), 0, w)), + int(np.clip(np.trunc(shape["y1"] * h), 0, h)), + int(np.clip(np.ceil(shape["x2"] * w), 0, w)), + int(np.clip(np.ceil(shape["y2"] * h), 0, h)), ) + if x1 >= x2 or y1 >= y2: + msg = f"Invalid ROI coordinates: {x1}, {y1}, {x2}, {y2}" + raise ValueError(msg) img_data = img_data[y1:y2, x1:x2] roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)} From 5563ba9038fe56e5e5633857410fbce4d51de800 Mon Sep 17 00:00:00 2001 From: kprokofi Date: Fri, 8 Nov 2024 21:27:40 +0900 Subject: [PATCH 8/8] rewrite value error --- src/otx/core/data/dataset/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index b57bb72581..21e8d349a9 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -189,9 +189,10 @@ def _get_img_data_and_shape( int(np.clip(np.ceil(shape["x2"] * w), 0, w)), int(np.clip(np.ceil(shape["y2"] * h), 0, h)), ) - if x1 >= x2 or y1 >= y2: - msg = f"Invalid ROI coordinates: {x1}, {y1}, {x2}, {y2}" + if (x2 - x1) * (y2 - y1) <= 0: + msg = f"ROI has zero or negative area. ROI coordinates: {x1}, {y1}, {x2}, {y2}" raise ValueError(msg) + img_data = img_data[y1:y2, x1:x2] roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)}