Skip to content

Commit

Permalink
add and update test codes regarding video dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Yi, Jihyeon committed May 8, 2024
1 parent 4f74c60 commit 29754ee
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 77 deletions.
65 changes: 56 additions & 9 deletions src/datumaro/components/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def media(self) -> Optional[Type[MediaElement]]:
class MediaElement(Generic[AnyData]):
_type = MediaType.MEDIA_ELEMENT

def __init__(self, crypter: Crypter = NULL_CRYPTER) -> None:
def __init__(self, crypter: Crypter = NULL_CRYPTER, *args, **kwargs) -> None:
self._crypter = crypter

def as_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -488,6 +488,26 @@ def video(self) -> Video:
def path(self) -> str:
return self._video.path

def from_self(self, **kwargs):
    """Create a new instance of this class from the current one, with overrides.

    The constructor arguments start as a deep copy of ``self.as_dict()``.
    If ``path`` is among the keyword arguments, the ``video`` argument is
    replaced by ``self.video.from_self(**kwargs)`` and ``path`` is then
    removed so that it is not forwarded to this class's constructor.
    Every remaining keyword argument overrides the copied value.

    Args:
        **kwargs: Constructor arguments to override; ``path`` is handled
            specially as described above.

    Returns:
        A new instance of ``self.__class__``.
    """
    params = deepcopy(self.as_dict())

    if "path" in kwargs:
        # The video receives the full set of overrides, including ``path``.
        params["video"] = self.video.from_self(**kwargs)
        del kwargs["path"]

    params.update(kwargs)
    return self.__class__(**params)

def __getstate__(self):
    """Return a copy of the instance state without the ``_data`` entry.

    Only the picklable parts of the state are returned; the instance's own
    ``__dict__`` is left untouched.

    Raises:
        KeyError: If the instance has no ``_data`` attribute.
    """
    snapshot = dict(self.__dict__)
    # ``_data`` is excluded from the pickled state.
    snapshot.pop("_data")
    return snapshot

def __setstate__(self, state):
    """Restore the instance from ``state`` and rebuild the ``_data`` loader.

    Args:
        state: Mapping of attribute names to values that is merged into the
            instance's ``__dict__``.
    """
    vars(self).update(state)
    # ``_data`` is not part of the restored state, so re-create it here as a
    # callable that fetches this frame's data from the video when invoked.
    self._data = lambda: self._video.get_frame_data(self._index)


class _VideoFrameIterator(Iterator[VideoFrame]):
"""
Expand Down Expand Up @@ -527,6 +547,11 @@ def _decode(self, cap) -> Iterator[VideoFrame]:

if self._video._frame_count is None:
self._video._frame_count = self._pos + 1
if self._video._end_frame and self._video._end_frame >= self._video._frame_count:
raise ValueError(
f"The end_frame value({self._video._end_frame}) of the video "
f"must be less than the frame count({self._video._frame_count})."
)

def _make_frame(self, index) -> VideoFrame:
    """Build the ``VideoFrame`` for ``index`` of the video being iterated."""
    frame = VideoFrame(self._video, index=index)
    return frame
Expand Down Expand Up @@ -575,14 +600,22 @@ class Video(MediaElement, Iterable[VideoFrame]):
"""

def __init__(
self, path: str, *, step: int = 1, start_frame: int = 0, end_frame: Optional[int] = None
self,
path: str,
step: int = 1,
start_frame: int = 0,
end_frame: Optional[int] = None,
*args,
**kwargs,
) -> None:
super().__init__()
super().__init__(*args, **kwargs)
self._path = path

assert 0 <= start_frame
if end_frame:
assert start_frame <= end_frame
# we can't know the video length here,
# so we cannot validate if the end_frame is valid.
assert 0 < step
self._step = step
self._start_frame = start_frame
Expand Down Expand Up @@ -727,12 +760,26 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, __class__):
return False

return (
self.path == other.path
and self._start_frame == other._start_frame
and self._step == other._step
and self._end_frame == other._end_frame
)
if (
self._start_frame != other._start_frame
or self._step != other._step
or self._end_frame != other._end_frame
):
return False

# The video path can vary if a dataset is copied.
# So, we need to check if the video data is the same instead of checking paths.
if self._end_frame is None:
# Decoding is not necessary to get frame pointers
# However, it can be inaccurate
end_frame = self._get_end_frame()
for index in range(self._start_frame, end_frame + 1, self._step):
yield VideoFrame(video=self, index=index)
for frame_self, frame_other in zip(self, other):
if frame_self != frame_other:
return False

return True

def __hash__(self):
# Required for caching
Expand Down
3 changes: 2 additions & 1 deletion src/datumaro/plugins/data_formats/datumaro_binary/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def _read_items(self) -> None:
media_path_prefix = {
MediaType.IMAGE: osp.join(self._images_dir, self._subset),
MediaType.POINT_CLOUD: osp.join(self._pcd_dir, self._subset),
MediaType.VIDEO_FRAME: self._video_dir,
MediaType.VIDEO: osp.join(self._video_dir, self._subset),
MediaType.VIDEO_FRAME: osp.join(self._video_dir, self._subset),
}

if self._num_workers > 0:
Expand Down
52 changes: 51 additions & 1 deletion tests/unit/data_formats/datumaro/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
RleMask,
)
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image, PointCloud
from datumaro.components.media import Image, MediaElement, PointCloud, Video, VideoFrame
from datumaro.components.project import Dataset
from datumaro.plugins.data_formats.datumaro.format import DatumaroPath
from datumaro.util.mask_tools import generate_colormap

from tests.utils.video import make_sample_video


@pytest.fixture
def fxt_test_datumaro_format_dataset():
Expand Down Expand Up @@ -199,6 +201,54 @@ def fxt_test_datumaro_format_dataset():
)


@pytest.fixture
def fxt_test_datumaro_format_video_dataset(test_dir) -> Dataset:
    """Build a dataset that mixes ``VideoFrame`` and ``Video`` media items.

    A sample video is written to ``<test_dir>/video.avi`` through
    ``make_sample_video`` (``frame_size=(4, 6)``, ``frames=4``). The dataset
    holds two frame items ("f0", "f1") that share one ``Video`` object and
    two video items ("v0", "v1") that each open the same path with their own
    frame range. Items are spread over the "train" and "test" subsets.
    """
    video_path = osp.join(test_dir, "video.avi")
    make_sample_video(video_path, frame_size=(4, 6), frames=4)
    video = Video(video_path)

    # Items whose media is a single frame of the shared video.
    frame_items = [
        DatasetItem(
            "f0",
            subset="train",
            media=VideoFrame(video, 0),
            annotations=[
                Bbox(1, 1, 1, 1, label=0, object_id=0),
                Bbox(2, 2, 2, 2, label=1, object_id=1),
            ],
        ),
        DatasetItem(
            "f1",
            subset="test",
            media=VideoFrame(video, 0),
            annotations=[
                Bbox(0, 0, 2, 2, label=1, object_id=1),
                Bbox(3, 3, 1, 1, label=0, object_id=0),
            ],
        ),
    ]

    # Items whose media is a frame range of the video file.
    clip_items = [
        DatasetItem(
            "v0",
            subset="train",
            media=Video(video_path, step=1, start_frame=0, end_frame=1),
            annotations=[
                Label(0),
            ],
        ),
        DatasetItem(
            "v1",
            subset="test",
            media=Video(video_path, step=1, start_frame=2, end_frame=2),
            annotations=[
                Bbox(1, 1, 3, 3, label=1, object_id=1),
            ],
        ),
    ]

    return Dataset.from_iterable(
        iterable=frame_items + clip_items,
        media_type=MediaElement,
        categories=["a", "b"],
    )


@pytest.fixture
def fxt_wrong_version_dir(fxt_test_datumaro_format_dataset, test_dir):
dest_dir = osp.join(test_dir, "wrong_version")
Expand Down
22 changes: 13 additions & 9 deletions tests/unit/data_formats/datumaro/test_datumaro_binary_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,42 +39,41 @@ class DatumaroBinaryFormatTest(TestBase):
ann_ext = DatumaroBinaryPath.ANNOTATION_EXT

@pytest.mark.parametrize(
["fxt_dataset", "compare", "require_media", "fxt_import_kwargs", "fxt_export_kwargs"],
"fxt_dataset",
("fxt_test_datumaro_format_dataset", "fxt_test_datumaro_format_video_dataset"),
)
@pytest.mark.parametrize(
["compare", "require_media", "fxt_import_kwargs", "fxt_export_kwargs"],
[
pytest.param(
"fxt_test_datumaro_format_dataset",
compare_datasets_strict,
True,
{},
{},
id="test_no_encryption",
),
pytest.param(
"fxt_test_datumaro_format_dataset",
compare_datasets_strict,
True,
{"encryption_key": ENCRYPTION_KEY},
{"encryption_key": ENCRYPTION_KEY},
id="test_with_encryption",
),
pytest.param(
"fxt_test_datumaro_format_dataset",
compare_datasets_strict,
True,
{"encryption_key": ENCRYPTION_KEY},
{"encryption_key": ENCRYPTION_KEY, "no_media_encryption": True},
id="test_no_media_encryption",
),
pytest.param(
"fxt_test_datumaro_format_dataset",
compare_datasets_strict,
True,
{"encryption_key": ENCRYPTION_KEY},
{"encryption_key": ENCRYPTION_KEY, "max_blob_size": 1}, # 1 byte
id="test_multi_blobs",
),
pytest.param(
"fxt_test_datumaro_format_dataset",
compare_datasets_strict,
True,
{"encryption_key": ENCRYPTION_KEY, "num_workers": 2},
Expand Down Expand Up @@ -167,10 +166,15 @@ def _get_ann_mapper(ann: Annotation) -> AnnotationMapper:
def test_common_mapper(self, mapper: Mapper, expected: Any):
self._test(mapper, expected)

def test_annotations_mapper(self, fxt_test_datumaro_format_dataset):
"""Test all annotations in fxt_test_datumaro_format_dataset"""
@pytest.mark.parametrize(
"fxt_dataset",
("fxt_test_datumaro_format_dataset", "fxt_test_datumaro_format_video_dataset"),
)
def test_annotations_mapper(self, fxt_dataset, request):
"""Test all annotations in fxt_dataset"""
mapper = DatasetItemMapper
for item in fxt_test_datumaro_format_dataset:
fxt_dataset = request.getfixturevalue(fxt_dataset)
for item in fxt_dataset:
for ann in item.annotations:
mapper = self._get_ann_mapper(ann)
self._test(mapper, ann)
Expand Down
21 changes: 19 additions & 2 deletions tests/unit/data_formats/datumaro/test_datumaro_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ def _test_save_and_load(
False,
id="test_can_save_and_load_with_no_save_media",
),
pytest.param(
"fxt_test_datumaro_format_video_dataset",
compare_datasets,
True,
id="test_can_save_and_load_video_dataset",
),
pytest.param(
"fxt_test_datumaro_format_video_dataset",
None,
False,
id="test_can_save_and_load_video_dataset_with_no_save_media",
),
pytest.param(
"fxt_relative_paths",
compare_datasets,
Expand Down Expand Up @@ -176,8 +188,13 @@ def test_source_target_pair(
)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_detect(self, fxt_test_datumaro_format_dataset, test_dir):
self.exporter.convert(fxt_test_datumaro_format_dataset, save_dir=test_dir)
@pytest.mark.parametrize(
"fxt_dataset",
("fxt_test_datumaro_format_dataset", "fxt_test_datumaro_format_video_dataset"),
)
def test_can_detect(self, fxt_dataset, test_dir, request):
fxt_dataset = request.getfixturevalue(fxt_dataset)
self.exporter.convert(fxt_dataset, save_dir=test_dir)

detected_formats = Environment().detect_dataset(test_dir)
assert [self.importer.NAME] == detected_formats
Expand Down
Loading

0 comments on commit 29754ee

Please sign in to comment.