From 36beba13ae85f8b016f42e7adac605cc029a2785 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 17 Nov 2023 22:14:17 +0800 Subject: [PATCH] Add cache option in `GridPatchDataset` (#7180) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of #6904 ### Description - Fix inefficient patching in `PatchDataset` - Add cache option in `GridPatchDataset` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Juan Pablo de la Cruz Gutiérrez --- monai/data/grid_dataset.py | 218 ++++++++++++++++++++++++++++++------ tests/test_grid_dataset.py | 55 +++++++-- tests/test_patch_dataset.py | 15 ++- 3 files changed, 242 insertions(+), 46 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 06954e9f11..9079032e6f 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -11,18 +11,30 @@ from __future__ import annotations -from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence +import sys +import warnings +from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy +from multiprocessing.managers 
import ListProxy +from multiprocessing.pool import ThreadPool +from typing import TYPE_CHECKING import numpy as np +import torch from monai.config import KeysCollection from monai.config.type_definitions import NdarrayTensor -from monai.data.dataset import Dataset from monai.data.iterable_dataset import IterableDataset -from monai.data.utils import iter_patch -from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple, first +from monai.data.utils import iter_patch, pickle_hashing +from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous +from monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import + +if TYPE_CHECKING: + from tqdm import tqdm + + has_tqdm = True +else: + tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"] @@ -184,6 +196,25 @@ class GridPatchDataset(IterableDataset): see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`. transform: a callable data transform operates on the patches. with_coordinates: whether to yield the coordinates of each patch, default to `True`. + cache: whether to use cache mechanism, default to `False`. + see also: :py:class:`monai.data.CacheDataset`. + cache_num: number of items to be cached. Default is `sys.maxsize`. + will take the minimum of (cache_num, data_length x cache_rate, data_length). + cache_rate: percentage of cached data in total, default is 1.0 (cache all). + will take the minimum of (cache_num, data_length x cache_rate, data_length). + num_workers: the number of worker threads if computing cache in the initialization. + If num_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is specified, 1 will be used instead. + progress: whether to display a progress bar. 
+ copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. + hash_func: a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. """ @@ -193,27 +224,148 @@ def __init__( patch_iter: Callable, transform: Callable | None = None, with_coordinates: bool = True, + cache: bool = False, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_workers: int | None = 1, + progress: bool = True, + copy_cache: bool = True, + as_contiguous: bool = True, + hash_func: Callable[..., bytes] = pickle_hashing, ) -> None: super().__init__(data=data, transform=None) + if transform is not None and not isinstance(transform, Compose): + transform = Compose(transform) self.patch_iter = patch_iter self.patch_transform = transform self.with_coordinates = with_coordinates + self.set_num = cache_num + self.set_rate = cache_rate + self.progress = progress + self.copy_cache = copy_cache + self.as_contiguous = as_contiguous + self.hash_func = hash_func + self.num_workers = num_workers + if self.num_workers is not None: + self.num_workers = max(int(self.num_workers), 1) + self._cache: list | ListProxy = [] + self._cache_other: list | ListProxy = [] + self.cache = cache + self.first_random: int | None = None + if self.patch_transform is not None: + self.first_random = self.patch_transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) - def __iter__(self): - for image in super().__iter__(): - for patch, *others in 
self.patch_iter(image): - out_patch = patch - if self.patch_transform is not None: - out_patch = apply_transform(self.patch_transform, patch, map_items=False) - if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords - yield out_patch, others[0] - else: - yield out_patch + if self.cache: + if isinstance(data, Iterator): + raise TypeError("Data can not be iterator when cache is True") + self.set_data(data) # type: ignore + + def set_data(self, data: Sequence) -> None: + """ + Set the input data and run deterministic transforms to generate cache content. + + Note: should call this func after an entire epoch and must set `persistent_workers=False` + in PyTorch DataLoader, because it needs to create new worker processes based on new + generated cache content. + + """ + self.data = data + + # only compute cache for the unique items of dataset, and record the last index for duplicated items + mapping = {self.hash_func(v): i for i, v in enumerate(self.data)} + self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping)) + self._hash_keys = list(mapping)[: self.cache_num] + indices = list(mapping.values())[: self.cache_num] + self._cache, self._cache_other = zip(*self._fill_cache(indices)) # type: ignore + + def _fill_cache(self, indices=None) -> list: + """ + Compute and fill the cache content from data source. + + Args: + indices: target indices in the `self.data` source to compute cache. + if None, use the first `cache_num` items. 
+ + """ + if self.cache_num <= 0: + return [] + if indices is None: + indices = list(range(self.cache_num)) + if self.progress and not has_tqdm: + warnings.warn("tqdm is not installed, will not show the caching progress bar.") + + pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v) + with ThreadPool(self.num_workers) as p: + return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset")) + + def _load_cache_item(self, idx: int): + """ + Args: + idx: the index of the input data sequence. + """ + item = self.data[idx] # type: ignore + patch_cache, other_cache = [], [] + for patch, *others in self.patch_iter(item): + if self.first_random is not None: + patch = self.patch_transform(patch, end=self.first_random, threading=True) # type: ignore + + if self.as_contiguous: + patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format) + if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + other_cache.append(others[0]) + patch_cache.append(patch) + return patch_cache, other_cache + + def _generate_patches(self, src, **apply_args): + """ + yield patches optionally post-processed by transform. + Args: + src: an iterable of image patches. + apply_args: other args for `self.patch_transform`. 
+ + """ + for patch, *others in src: + out_patch = patch + if self.patch_transform is not None: + out_patch = self.patch_transform(patch, **apply_args) + if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others[0] + else: + yield out_patch -class PatchDataset(Dataset): + def __iter__(self): + if self.cache: + cache_index = None + for image in super().__iter__(): + key = self.hash_func(image) + if key in self._hash_keys: + # if existing in cache, try to get the index in cache + cache_index = self._hash_keys.index(key) + if cache_index is None: + # no cache for this index, execute all the transforms directly + yield from self._generate_patches(self.patch_iter(image)) + else: + if self._cache is None: + raise RuntimeError( + "Cache buffer is not initialized, please call `set_data()` before epoch begins." + ) + data = self._cache[cache_index] # type: ignore + other = self._cache_other[cache_index] # type: ignore + + # load data from cache and execute from the first random transform + data = deepcopy(data) if self.copy_cache else data + yield from self._generate_patches(zip(data, other), start=self.first_random) + else: + for image in super().__iter__(): + yield from self._generate_patches(self.patch_iter(image)) + + +class PatchDataset(IterableDataset): """ - returns a patch from an image dataset. + Yields patches from data read from an image dataset. The patches are generated by a user-specified callable `patch_func`, and are optionally post-processed by `transform`. For example, to generate random patch samples from an image dataset: @@ -263,26 +415,26 @@ def __init__( samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements. transform: transform applied to each patch. 
""" - super().__init__(data=data, transform=transform) + super().__init__(data=data, transform=None) self.patch_func = patch_func if samples_per_image <= 0: raise ValueError("sampler_per_image must be a positive integer.") self.samples_per_image = int(samples_per_image) + self.patch_transform = transform def __len__(self) -> int: - return len(self.data) * self.samples_per_image - - def _transform(self, index: int): - image_id = int(index / self.samples_per_image) - image = self.data[image_id] - patches = self.patch_func(image) - if len(patches) != self.samples_per_image: - raise RuntimeWarning( - f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}." - ) - patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1) - patch = patches[patch_id] - if self.transform is not None: - patch = apply_transform(self.transform, patch, map_items=False) - return patch + return len(self.data) * self.samples_per_image # type: ignore + + def __iter__(self): + for image in super().__iter__(): + patches = self.patch_func(image) + if len(patches) != self.samples_per_image: + raise RuntimeWarning( + f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}." 
+ ) + for patch in patches: + out_patch = patch + if self.patch_transform is not None: + out_patch = apply_transform(self.patch_transform, patch, map_items=False) + yield out_patch diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index ba33547260..d937a5e266 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -108,11 +108,10 @@ def test_shape(self): self.assertEqual(sorted(output), sorted(expected)) def test_loading_array(self): - set_determinism(seed=1234) # test sequence input data with images images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] # image level - patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) + patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234) patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity) # use the grid patch dataset @@ -120,7 +119,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]), + np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -129,9 +128,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array( - [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]] - ), + np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]), rtol=1e-3, ) np.testing.assert_allclose( @@ -164,7 +161,7 @@ def test_loading_dict(self): self.assertListEqual(item[0]["metadata"], ["test string", "test string"]) 
np.testing.assert_allclose( item[0]["image"], - np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]), + np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -173,15 +170,53 @@ def test_loading_dict(self): np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2)) np.testing.assert_allclose( item[0]["image"], - np.array( - [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]] - ), + np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]), rtol=1e-3, ) np.testing.assert_allclose( item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5 ) + def test_set_data(self): + from monai.transforms import Compose, Lambda, RandLambda + + images = [np.arange(2, 18, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] + + transform = Compose( + [Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False + ) + patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) + dataset = GridPatchDataset( + data=images, + patch_iter=patch_iter, + transform=transform, + cache=True, + cache_rate=1.0, + copy_cache=not sys.platform == "linux", + ) + + num_workers = 2 if sys.platform == "linux" else 0 + for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) + # simulate another epoch, the cache content should not be modified + for item in DataLoader(dataset, batch_size=2, 
shuffle=False, num_workers=num_workers): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) + + # update the datalist and fill the cache content + data_list2 = [np.arange(1, 17, dtype=float).reshape(1, 4, 4)] + dataset.set_data(data=data_list2) + # rerun with updated cache content + for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose( + item[0], np.array([[[[91, 101], [131, 141]]], [[[111, 121], [151, 161]]]]), rtol=1e-4 + ) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 7d66bdccbb..eb705f0c61 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -37,7 +37,10 @@ def test_shape(self): n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) - expected = ["vwx", "yzh", "ell", "owo", "rld"] + if n_workers == 0: + expected = ["vwx", "yzh", "ell", "owo", "rld"] + else: + expected = ["vwx", "hel", "yzw", "lo", "orl", "d"] self.assertEqual(output, expected) def test_loading_array(self): @@ -61,7 +64,7 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]] + [[[4.970372, 5.970372, 6.970372], [8.970372, 9.970372, 10.970372], [12.970372, 13.970372, 14.970372]]] ), rtol=1e-5, ) @@ -71,7 +74,13 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[0.234308, 1.234308, 
2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]] + [ + [ + [5.028125, 6.028125, 7.028125], + [9.028125, 10.028125, 11.028125], + [13.028125, 14.028125, 15.028125], + ] + ] ), rtol=1e-5, )