Fix a bug in the previous behavior when importing nested datasets in the project #1243

Merged · 3 commits · Jan 16, 2024
Changes from 2 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   (<https://github.com/openvinotoolkit/datumaro/pull/1232>)
 - Add unit test for item rename
   (<https://github.com/openvinotoolkit/datumaro/pull/1237>)
+- Fix a bug in the previous behavior when importing nested datasets in the project
+  (<https://github.com/openvinotoolkit/datumaro/pull/1243>)
 
 ## 16/11/2023 - Release 1.5.1
 ### Enhancements
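For context, the failure mode this PR addresses shows up when a project directory contains more than one nested dataset. A minimal sketch of the scenario (the directory names and layout here are hypothetical, not taken from the PR):

```python
from datumaro.components.dataset import Dataset

# Assumed layout for illustration:
#   project/
#     dataset_1/   <- nested YOLO dataset (obj.data, ...)
#     dataset_2/   <- nested YOLO dataset (obj.data, ...)
#
# Before this fix, nested sources that shared a subset name overwrote each
# other during the merge, so items from one nested dataset were silently lost.
merged = Dataset.import_from("project", format="yolo")
print(len(merged))  # expected: total item count across both nested datasets
```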
2 changes: 1 addition & 1 deletion src/datumaro/cli/commands/convert.py
@@ -135,7 +135,7 @@
     log.info(f"Source dataset format detected as {fmt}")
 
     if fmt == args.output_format:
-        log.error("The source data format and the output data format is same as {fmt}.")
+        log.error(f"The source data format and the output data format is same as {fmt}.")
         return 3
 
     source = osp.abspath(args.source)
48 changes: 31 additions & 17 deletions src/datumaro/components/merge/extractor_merger.py
@@ -2,10 +2,17 @@
 #
 # SPDX-License-Identifier: MIT
 
-from typing import Optional, Sequence, TypeVar
+from collections import defaultdict
+from typing import Dict, Iterator, List, Optional, Sequence, TypeVar
 
 from datumaro.components.contexts.importer import _ImportFail
-from datumaro.components.dataset_base import DatasetBase, SubsetBase
+from datumaro.components.dataset_base import (
+    CategoriesInfo,
+    DatasetBase,
+    DatasetInfo,
+    DatasetItem,
+    SubsetBase,
+)
 
 T = TypeVar("T")
 
@@ -36,30 +43,37 @@
         self._categories = check_identicalness([s.categories() for s in sources])
         self._media_type = check_identicalness([s.media_type() for s in sources])
         self._is_stream = check_identicalness([s.is_stream for s in sources])
-        self._subsets = {s.subset: s for s in sources}
+        self._subsets: Dict[str, List[SubsetBase]] = defaultdict(list)
+        for source in sources:
+            self._subsets[source.subset] += [source]
 
-    def infos(self):
+    def infos(self) -> DatasetInfo:
         return self._infos
 
-    def categories(self):
+    def categories(self) -> CategoriesInfo:
         return self._categories
 
-    def __iter__(self):
-        for subset in self._subsets.values():
-            yield from subset
+    def __iter__(self) -> Iterator[DatasetItem]:
+        for sources in self._subsets.values():
+            for source in sources:
+                yield from source
 
-    def __len__(self):
-        return sum(len(subset) for subset in self._subsets.values())
+    def __len__(self) -> int:
+        return sum(len(source) for sources in self._subsets.values() for source in sources)
 
-    def get(self, id: str, subset: Optional[str] = None):
+    def get(self, id: str, subset: Optional[str] = None) -> Optional[DatasetItem]:
+        if subset is not None and (sources := self._subsets.get(subset, [])):
+            for source in sources:
+                if item := source.get(id, subset):
+                    return item
+
         if subset is None:
-            for s in self._subsets.values():
-                item = s.get(id)
-                if item is not None:
-                    return item
+            for sources in self._subsets.values():
+                for source in sources:
+                    if item := source.get(id=id, subset=source.subset):
+                        return item
 
-        s = self._subsets[subset]
-        return s.get(id, subset)
+        return None
 
     @property
     def is_stream(self) -> bool:
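The key structural change is that `_subsets` now maps each subset name to a list of sources rather than a single source, so two nested datasets that both define a `train` subset no longer overwrite each other. A standalone sketch of the grouping and lookup logic, using a toy class rather than Datumaro's real `SubsetBase`:

```python
from collections import defaultdict
from typing import Dict, List, Optional


class ToySource:
    """Toy stand-in for a SubsetBase: one subset name, a few items."""

    def __init__(self, subset: str, items: Dict[str, str]):
        self.subset = subset
        self._items = items

    def get(self, id: str, subset: str) -> Optional[str]:
        return self._items.get(id) if subset == self.subset else None


# Two nested datasets both expose a "train" subset.
sources = [
    ToySource("train", {"a": "item-a from dataset_1"}),
    ToySource("train", {"b": "item-b from dataset_2"}),
]

# Old behavior: {s.subset: s for s in sources} kept only the last "train" source.
# New behavior: group sources per subset, so both survive the merge.
subsets: Dict[str, List[ToySource]] = defaultdict(list)
for source in sources:
    subsets[source.subset] += [source]


def get(id: str, subset: str) -> Optional[str]:
    # Lookup now walks every source of the subset until one has the item.
    for source in subsets.get(subset, []):
        if item := source.get(id, subset):
            return item
    return None


assert get("a", "train") == "item-a from dataset_1"
assert get("b", "train") == "item-b from dataset_2"
```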
@@ -85,6 +85,6 @@ def __init__(self, sources: Sequence[_CocoBase]):
             grouped_by_subset[s.subset] += [s]
 
         self._subsets = {
-            subset: COCOTaskMergedBase(sources, subset)
+            subset: [COCOTaskMergedBase(sources, subset)]
             for subset, sources in grouped_by_subset.items()
         }
58 changes: 28 additions & 30 deletions src/datumaro/plugins/data_formats/yolo/importer.py
@@ -22,22 +22,23 @@ def detect(cls, context: FormatDetectionContext) -> None:
         context.require_file("obj.data")
 
     @classmethod
-    def find_sources(cls, path: str) -> List[Dict]:
-        found = cls._find_sources_recursive(path, ".data", YoloFormatType.yolo_strict.name)
-        if len(found) == 0:
-            return []
+    def find_sources(cls, path: str) -> List[Dict[str, Any]]:
+        sources = cls._find_sources_recursive(path, ".data", YoloFormatType.yolo_strict.name)
 
-        config_path = found[0]["url"]
-        config = YoloPath._parse_config(config_path)
-        subsets = [k for k in config if k not in YoloPath.RESERVED_CONFIG_KEYS]
-        return [
-            {
-                "url": config_path,
-                "format": YoloFormatType.yolo_strict.name,
-                "options": {"subset": subset},
-            }
-            for subset in subsets
-        ]
+        def _extract_subset_wise_sources(source) -> List[Dict[str, Any]]:
+            config_path = source["url"]
+            config = YoloPath._parse_config(config_path)
+            subsets = [k for k in config if k not in YoloPath.RESERVED_CONFIG_KEYS]
+            return [
+                {
+                    "url": config_path,
+                    "format": YoloFormatType.yolo_strict.name,
+                    "options": {"subset": subset},
+                }
+                for subset in subsets
+            ]
+
+        return sum([_extract_subset_wise_sources(source) for source in sources], [])
 
 
 class _YoloLooseImporter(Importer):
@@ -126,7 +127,7 @@ def _filter_ann_file(fpath: str):
     def find_sources(cls, path: str) -> List[Dict[str, Any]]:
         # Check obj.names first
         filename, ext = osp.splitext(cls.META_FILE)
-        sources = cls._find_sources_recursive(
+        obj_names_files = cls._find_sources_recursive(
            path,
            ext=ext,
            extractor_name="",
@@ -135,20 +136,20 @@ def find_sources(cls, path: str) -> List[Dict[str, Any]]:
            max_depth=1,
            recursive=False,
        )
-        if len(sources) == 0:
+        if len(obj_names_files) == 0:
             return []
 
-        # TODO: From Python >= 3.8, we can use
-        # "if (sources := cls._find_strict(path)): return sources"
-        sources = cls._find_loose(path, "[Aa]nnotations")
-        if sources:
-            return sources
+        sources = []
 
-        sources = cls._find_loose(path, "[Ll]abels")
-        if sources:
-            return sources
+        for obj_names_file in obj_names_files:
+            base_path = osp.dirname(obj_names_file["url"])
+            if found := cls._find_loose(base_path, "[Aa]nnotations"):
+                sources += found
 
-        return []
+            if found := cls._find_loose(path, "[Ll]abels"):
+                sources += found
+
+        return sources
 
     @property
     def can_stream(self) -> bool:
@@ -179,10 +180,7 @@ def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
     @classmethod
     def find_sources(cls, path: str) -> List[Dict[str, Any]]:
         for importer_cls in cls.SUB_IMPORTERS.values():
-            # TODO: From Python >= 3.8, we can use
-            # "if (sources := importer_cls.find_sources(path)): return sources"
-            sources = importer_cls.find_sources(path)
-            if sources:
+            if sources := importer_cls.find_sources(path):
                 return sources
 
         return []
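With this rewrite, every `obj.data` file found under the path contributes its own subset-wise source entries, instead of only the first match being used. A sketch of the shape of the result for two nested strict-YOLO datasets (the paths and subset names are illustrative, not taken from the PR):

```python
# Hypothetical return value of the strict importer's find_sources("project"):
found_sources = [
    {"url": "project/dataset_1/obj.data", "format": "yolo_strict", "options": {"subset": "train"}},
    {"url": "project/dataset_1/obj.data", "format": "yolo_strict", "options": {"subset": "valid"}},
    {"url": "project/dataset_2/obj.data", "format": "yolo_strict", "options": {"subset": "train"}},
    {"url": "project/dataset_2/obj.data", "format": "yolo_strict", "options": {"subset": "valid"}},
]
```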
55 changes: 55 additions & 0 deletions tests/integration/cli/test_yolo_format.py
@@ -1,7 +1,12 @@
 # Copyright (C) 2023-2024 Intel Corporation
 #
 # SPDX-License-Identifier: MIT
 
+import os.path as osp
+from unittest import TestCase
+
 import numpy as np
+import pytest
 
 import datumaro.plugins.data_formats.voc.format as VOC
+from datumaro.components.annotation import AnnotationType, Bbox
@@ -202,3 +207,53 @@ def test_can_delete_labels_from_yolo_dataset(self):
 
         parsed_dataset = Dataset.import_from(export_dir, format="yolo")
         compare_datasets(self, target_dataset, parsed_dataset)
+
+
+# TODO(vinnamki): Migrate above test cases to the pytest framework below
+class YoloIntegrationScenariosTest:
+    @pytest.fixture(params=["annotations", "labels", "strict"])
+    def fxt_yolo_dir(self, request) -> str:
+        return get_test_asset_path("yolo_dataset", request.param)
+
+    @mark_requirement(Requirements.DATUM_BUG_1204)
+    def test_can_import_nested_datasets_in_project(self, fxt_yolo_dir, test_dir, helper_tc):
+        run(helper_tc, "project", "create", "-o", test_dir)
+
+        num_total_items = 0
+        # Import twice
+        for i in range(1, 3):
+            run(
+                helper_tc,
+                "project",
+                "import",
+                "-p",
+                test_dir,
+                "-n",
+                f"dataset_{i}",
+                "-f",
+                "yolo",
+                fxt_yolo_dir,
+            )
+
+            # Reindex to prevent overlapping IDs of dataset items.
+            run(
+                helper_tc,
+                "transform",
+                "-p",
+                test_dir,
+                "-t",
+                "reindex",
+                f"dataset_{i}",
+                "--",
+                "--start",
+                f"{i * 10}",  # Reindex starting from 10 or 20
+            )
+
+            # Add the number of dataset items for each dataset in the project
+            num_total_items += len(
+                Dataset.import_from(osp.join(test_dir, f"dataset_{i}"), format="yolo")
+            )
+
+        # Import a dataset from the project
+        imported = Dataset.import_from(test_dir, format="yolo")
+        assert len(imported) == num_total_items
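The reindex step is what keeps the final assertion meaningful: both imports come from the same asset, so without it the two copies would carry identical item IDs. A rough no-CLI equivalent of the same invariant, as a sketch (it assumes `Dataset.transform("reindex", start=...)` and YOLO export with `save_media=True` behave as in current Datumaro, and that the asset is small enough that the ID ranges starting at 10 and 20 do not collide):

```python
import os.path as osp

from datumaro.components.dataset import Dataset


def check_nested_import(asset_dir: str, work_dir: str) -> None:
    total = 0
    for i in (1, 2):
        ds = Dataset.import_from(asset_dir, format="yolo")
        ds.transform("reindex", start=i * 10)  # keep the two copies' item IDs disjoint
        ds.export(osp.join(work_dir, f"dataset_{i}"), format="yolo", save_media=True)
        total += len(ds)

    # The parent directory now holds two nested YOLO datasets;
    # importing it must see the union of their items.
    merged = Dataset.import_from(work_dir, format="yolo")
    assert len(merged) == total
```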
3 changes: 3 additions & 0 deletions tests/requirements.py
@@ -64,6 +64,9 @@ class Requirements:
     DATUM_BUG_1204 = (
         "Statistics raise an error when there is a label annotation not in the category"
     )
+    DATUM_BUG_1214 = (
+        "Dataset.import_from() can import nested datasets that exist in the given path."
+    )
 
 
 class SkipMessages: