From d41b2cb47d59037c5600896a280156616096cd50 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Mon, 15 Jan 2024 14:39:28 +0900 Subject: [PATCH 1/3] Fix a bug in the previous behavior when importing nested datasets in the project Signed-off-by: Kim, Vinnam --- src/datumaro/cli/commands/convert.py | 2 +- .../components/merge/extractor_merger.py | 48 +++++++++------ .../data_formats/coco/extractor_merger.py | 2 +- .../plugins/data_formats/yolo/importer.py | 58 +++++++++---------- tests/integration/cli/test_yolo_format.py | 51 ++++++++++++++++ tests/requirements.py | 3 + 6 files changed, 115 insertions(+), 49 deletions(-) diff --git a/src/datumaro/cli/commands/convert.py b/src/datumaro/cli/commands/convert.py index 828e77144e..b23a12d0db 100644 --- a/src/datumaro/cli/commands/convert.py +++ b/src/datumaro/cli/commands/convert.py @@ -135,7 +135,7 @@ def convert_command(args): log.info(f"Source dataset format detected as {fmt}") if fmt == args.output_format: - log.error("The source data format and the output data format is same as {fmt}.") + log.error(f"The source data format and the output data format is same as {fmt}.") return 3 source = osp.abspath(args.source) diff --git a/src/datumaro/components/merge/extractor_merger.py b/src/datumaro/components/merge/extractor_merger.py index 725b39a2f0..f87dd18f0d 100644 --- a/src/datumaro/components/merge/extractor_merger.py +++ b/src/datumaro/components/merge/extractor_merger.py @@ -2,10 +2,17 @@ # # SPDX-License-Identifier: MIT -from typing import Optional, Sequence, TypeVar +from collections import defaultdict +from typing import Dict, Iterator, List, Optional, Sequence, TypeVar from datumaro.components.contexts.importer import _ImportFail -from datumaro.components.dataset_base import DatasetBase, SubsetBase +from datumaro.components.dataset_base import ( + CategoriesInfo, + DatasetBase, + DatasetInfo, + DatasetItem, + SubsetBase, +) T = TypeVar("T") @@ -36,30 +43,37 @@ def __init__( self._categories = 
check_identicalness([s.categories() for s in sources]) self._media_type = check_identicalness([s.media_type() for s in sources]) self._is_stream = check_identicalness([s.is_stream for s in sources]) - self._subsets = {s.subset: s for s in sources} - def infos(self): + self._subsets: Dict[str, List[SubsetBase]] = defaultdict(list) + for source in sources: + self._subsets[source.subset] += [source] + + def infos(self) -> DatasetInfo: return self._infos - def categories(self): + def categories(self) -> CategoriesInfo: return self._categories - def __iter__(self): - for subset in self._subsets.values(): - yield from subset + def __iter__(self) -> Iterator[DatasetItem]: + for sources in self._subsets.values(): + for source in sources: + yield from source + + def __len__(self) -> int: + return sum(len(source) for sources in self._subsets.values() for source in sources) - def __len__(self): - return sum(len(subset) for subset in self._subsets.values()) + def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: + if subset is not None and (sources := self._subsets.get(subset, [])): + for source in sources: + if item := source.get(id, subset): + return item - def get(self, id: str, subset: Optional[str] = None): - if subset is None: - for s in self._subsets.values(): - item = s.get(id) - if item is not None: + for sources in self._subsets.values(): + for source in sources: + if item := source.get(id=id, subset=source.subset): return item - s = self._subsets[subset] - return s.get(id, subset) + return None @property def is_stream(self) -> bool: diff --git a/src/datumaro/plugins/data_formats/coco/extractor_merger.py b/src/datumaro/plugins/data_formats/coco/extractor_merger.py index bdbec9e95e..e2422d724e 100644 --- a/src/datumaro/plugins/data_formats/coco/extractor_merger.py +++ b/src/datumaro/plugins/data_formats/coco/extractor_merger.py @@ -85,6 +85,6 @@ def __init__(self, sources: Sequence[_CocoBase]): grouped_by_subset[s.subset] += [s] self._subsets 
= { - subset: COCOTaskMergedBase(sources, subset) + subset: [COCOTaskMergedBase(sources, subset)] for subset, sources in grouped_by_subset.items() } diff --git a/src/datumaro/plugins/data_formats/yolo/importer.py b/src/datumaro/plugins/data_formats/yolo/importer.py index 4f605412cc..83f510672f 100644 --- a/src/datumaro/plugins/data_formats/yolo/importer.py +++ b/src/datumaro/plugins/data_formats/yolo/importer.py @@ -22,22 +22,23 @@ def detect(cls, context: FormatDetectionContext) -> None: context.require_file("obj.data") @classmethod - def find_sources(cls, path: str) -> List[Dict]: - found = cls._find_sources_recursive(path, ".data", YoloFormatType.yolo_strict.name) - if len(found) == 0: - return [] + def find_sources(cls, path: str) -> List[Dict[str, Any]]: + sources = cls._find_sources_recursive(path, ".data", YoloFormatType.yolo_strict.name) - config_path = found[0]["url"] - config = YoloPath._parse_config(config_path) - subsets = [k for k in config if k not in YoloPath.RESERVED_CONFIG_KEYS] - return [ - { - "url": config_path, - "format": YoloFormatType.yolo_strict.name, - "options": {"subset": subset}, - } - for subset in subsets - ] + def _extract_subset_wise_sources(source) -> List[Dict[str, Any]]: + config_path = source["url"] + config = YoloPath._parse_config(config_path) + subsets = [k for k in config if k not in YoloPath.RESERVED_CONFIG_KEYS] + return [ + { + "url": config_path, + "format": YoloFormatType.yolo_strict.name, + "options": {"subset": subset}, + } + for subset in subsets + ] + + return sum([_extract_subset_wise_sources(source) for source in sources], []) class _YoloLooseImporter(Importer): @@ -126,7 +127,7 @@ def _filter_ann_file(fpath: str): def find_sources(cls, path: str) -> List[Dict[str, Any]]: # Check obj.names first filename, ext = osp.splitext(cls.META_FILE) - sources = cls._find_sources_recursive( + obj_names_files = cls._find_sources_recursive( path, ext=ext, extractor_name="", @@ -135,20 +136,20 @@ def find_sources(cls, path: str) 
-> List[Dict[str, Any]]: max_depth=1, recursive=False, ) - if len(sources) == 0: + if len(obj_names_files) == 0: return [] - # TODO: From Python >= 3.8, we can use - # "if (sources := cls._find_strict(path)): return sources" - sources = cls._find_loose(path, "[Aa]nnotations") - if sources: - return sources + sources = [] - sources = cls._find_loose(path, "[Ll]abels") - if sources: - return sources + for obj_names_file in obj_names_files: + base_path = osp.dirname(obj_names_file["url"]) + if found := cls._find_loose(base_path, "[Aa]nnotations"): + sources += found - return [] + if found := cls._find_loose(path, "[Ll]abels"): + sources += found + + return sources @property def can_stream(self) -> bool: @@ -179,10 +180,7 @@ def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence: @classmethod def find_sources(cls, path: str) -> List[Dict[str, Any]]: for importer_cls in cls.SUB_IMPORTERS.values(): - # TODO: From Python >= 3.8, we can use - # "if (sources := importer_cls.find_sources(path)): return sources" - sources = importer_cls.find_sources(path) - if sources: + if sources := importer_cls.find_sources(path): return sources return [] diff --git a/tests/integration/cli/test_yolo_format.py b/tests/integration/cli/test_yolo_format.py index f815ab67ed..eaed9c07a6 100644 --- a/tests/integration/cli/test_yolo_format.py +++ b/tests/integration/cli/test_yolo_format.py @@ -2,6 +2,7 @@ from unittest import TestCase import numpy as np +import pytest import datumaro.plugins.data_formats.voc.format as VOC from datumaro.components.annotation import AnnotationType, Bbox @@ -202,3 +203,53 @@ def test_can_delete_labels_from_yolo_dataset(self): parsed_dataset = Dataset.import_from(export_dir, format="yolo") compare_datasets(self, target_dataset, parsed_dataset) + + +# TODO(vinnamki): Migrate above test cases to the pytest framework below +class YoloIntegrationScenariosTest: + @pytest.fixture(params=["annotations", "labels", "strict"]) + def fxt_yolo_dir(self, 
request) -> str:
+        return get_test_asset_path("yolo_dataset", request.param)
+
+    @mark_requirement(Requirements.DATUM_BUG_1204)
+    def test_can_import_nested_datasets_in_project(self, fxt_yolo_dir, test_dir, helper_tc):
+        run(helper_tc, "project", "create", "-o", test_dir)
+
+        num_total_items = 0
+        # Import twice
+        for i in range(1, 3):
+            run(
+                helper_tc,
+                "project",
+                "import",
+                "-p",
+                test_dir,
+                "-n",
+                f"dataset_{i}",
+                "-f",
+                "yolo",
+                fxt_yolo_dir,
+            )
+
+            # Reindex to prevent overlapping IDs of dataset items.
+            run(
+                helper_tc,
+                "transform",
+                "-p",
+                test_dir,
+                "-t",
+                "reindex",
+                f"dataset_{i}",
+                "--",
+                "--start",
+                f"{i * 10}",  # Reindex starting from 10 or 20
+            )
+
+            # Add the number of dataset items for each dataset in the project
+            num_total_items += len(
+                Dataset.import_from(osp.join(test_dir, f"dataset_{i}"), format="yolo")
+            )
+
+        # Import a dataset from the project
+        imported = Dataset.import_from(test_dir, format="yolo")
+        assert len(imported) == num_total_items
diff --git a/tests/requirements.py b/tests/requirements.py
index bf2a160c27..d12813a212 100644
--- a/tests/requirements.py
+++ b/tests/requirements.py
@@ -64,6 +64,9 @@ class Requirements:
     DATUM_BUG_1204 = (
         "Statistics raise an error when there is a label annotation not in the category"
    )
+    DATUM_BUG_1214 = (
+        "Dataset.import_from() can import nested datasets that exist in the given path."
+ ) class SkipMessages: From bf9e6b7786fa58d51c3fa86c501f7f549008b43c Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Mon, 15 Jan 2024 14:48:27 +0900 Subject: [PATCH 2/3] Update CHANGELOG.md Signed-off-by: Kim, Vinnam --- CHANGELOG.md | 2 ++ tests/integration/cli/test_yolo_format.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4578e4eeb0..89d923d4fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Add unit test for item rename () +- Fix a bug in the previous behavior when importing nested datasets in the project + () ## 16/11/2023 - Release 1.5.1 ### Enhancements diff --git a/tests/integration/cli/test_yolo_format.py b/tests/integration/cli/test_yolo_format.py index eaed9c07a6..247f8d3d06 100644 --- a/tests/integration/cli/test_yolo_format.py +++ b/tests/integration/cli/test_yolo_format.py @@ -1,3 +1,7 @@ +# Copyright (C) 2023-2024 Intel Corporation +# +# SPDX-License-Identifier: MIT + import os.path as osp from unittest import TestCase From e0d266b75c384afef62dd08a61d8cbac7a0af8d6 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Mon, 15 Jan 2024 17:24:00 +0900 Subject: [PATCH 3/3] Fix incorrect req num Signed-off-by: Kim, Vinnam --- tests/integration/cli/test_yolo_format.py | 2 +- tests/unit/operations/test_statistics.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/cli/test_yolo_format.py b/tests/integration/cli/test_yolo_format.py index 247f8d3d06..8bd631b92c 100644 --- a/tests/integration/cli/test_yolo_format.py +++ b/tests/integration/cli/test_yolo_format.py @@ -215,7 +215,7 @@ class YoloIntegrationScenariosTest: def fxt_yolo_dir(self, request) -> str: return get_test_asset_path("yolo_dataset", request.param) - @mark_requirement(Requirements.DATUM_BUG_1204) + @mark_requirement(Requirements.DATUM_BUG_1214) def test_can_import_nested_datasets_in_project(self, fxt_yolo_dir, 
test_dir, helper_tc): run(helper_tc, "project", "create", "-o", test_dir) diff --git a/tests/unit/operations/test_statistics.py b/tests/unit/operations/test_statistics.py index a3a488615b..557a124632 100644 --- a/tests/unit/operations/test_statistics.py +++ b/tests/unit/operations/test_statistics.py @@ -334,7 +334,7 @@ def test_stats_with_empty_dataset(self): actual = compute_ann_statistics(dataset) assert actual == expected - @mark_requirement(Requirements.DATUM_BUG_1204) + @mark_requirement(Requirements.DATUM_BUG_1214) def test_stats_with_invalid_label(self): label_names = ["label_%s" % i for i in range(3)] dataset = Dataset.from_iterable(