From 579fe65c96dce810ab06fcdbf37c142de8ef6815 Mon Sep 17 00:00:00 2001 From: "Dr. Behrooz Hashemian" <3968947+drbeh@users.noreply.github.com> Date: Wed, 1 Mar 2023 11:44:38 -0500 Subject: [PATCH] GridPatch with both count and threshold filtering (#6055) Fixes #6049 ### Description This PR add support to `GridPatch` for both filtering by count and threshold. When filtering by threshold, the `num_patches` will be treated as maximum number of patches. UPDATE: It also removes the deprecated argument of `skimage.measure.regionprops`, 'coordinates', which was causing some test to fail. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. --------- Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 115 +++++++++++++++++-------- monai/transforms/spatial/dictionary.py | 48 +++++++---- tests/test_grid_patch.py | 20 +++-- tests/test_grid_patchd.py | 14 ++- 4 files changed, 138 insertions(+), 59 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index bef4bb2409..8e3e0ee83d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -68,7 +68,7 @@ optional_import, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys +from monai.utils.enums import GridPatchSort, PatchKeys, PytorchPadMode, TraceKeys, TransformBackends from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string @@ -3025,19 +3025,32 @@ class GridPatch(Transform, MultiSampleTrait): Args: patch_size: size of patches to generate slices for, 0 or None selects whole dimension offset: offset of starting position in the array, default is 0 for each dimension. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. - If the required patches are more than the available patches, padding will be applied. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), + with following metadata: + + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ @@ -3051,7 +3064,7 @@ def __init__( overlap: Sequence[float] | float = 0.0, sort_fn: str | None = None, threshold: float | None = None, - pad_mode: str = PytorchPadMode.CONSTANT, + pad_mode: str | None = None, **pad_kwargs, ): self.patch_size = ensure_tuple(patch_size) @@ -3065,24 +3078,26 @@ def __init__( def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray): """ - Filter the patches and their locations according to a threshold + Filter the patches and their locations according to a threshold. + Args: - image_np: a numpy.ndarray representing a stack of patches - locations: a numpy.ndarray representing the stack of location of each patch + image_np: a numpy.ndarray representing a stack of patches. + locations: a numpy.ndarray representing the stack of location of each patch. + + Returns: + tuple[numpy.ndarray, numpy.ndarray]: tuple of filtered patches and locations. """ - if self.threshold is not None: - n_dims = len(image_np.shape) - idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) - image_np = image_np[idx] - locations = locations[idx] - return image_np, locations + n_dims = len(image_np.shape) + idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) + return image_np[idx], locations[idx] def filter_count(self, image_np: np.ndarray, locations: np.ndarray): """ Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them. + Args: - image_np: a numpy.ndarray representing a stack of patches - locations: a numpy.ndarray representing the stack of location of each patch + image_np: a numpy.ndarray representing a stack of patches. + locations: a numpy.ndarray representing the stack of location of each patch. """ if self.sort_fn is None: image_np = image_np[: self.num_patches] @@ -3100,7 +3115,17 @@ def filter_count(self, image_np: np.ndarray, locations: np.ndarray): locations = locations[idx] return image_np, locations - def __call__(self, array: NdarrayOrTensor): + def __call__(self, array: NdarrayOrTensor) -> MetaTensor: + """ + Extract the patches (sweeping the entire image in a row-major sliding-window manner with possible overlaps). + + Args: + array: a input image as `numpy.ndarray` or `torch.Tensor` + + Return: + MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), + with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata. + """ # create the patch iterator which sweeps the image row-by-row array_np, *_ = convert_data_type(array, np.ndarray) patch_iterator = iter_patch( @@ -3115,29 +3140,33 @@ def __call__(self, array: NdarrayOrTensor): patches = list(zip(*patch_iterator)) patched_image = np.array(patches[0]) locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location + del patches # it will free up some memory if padding is used. - # Filter patches - if self.num_patches: - patched_image, locations = self.filter_count(patched_image, locations) - elif self.threshold: + # Apply threshold filtering + if self.threshold is not None: patched_image, locations = self.filter_threshold(patched_image, locations) - # Pad the patch list to have the requested number of patches + # Apply count filtering if self.num_patches: - padding = self.num_patches - len(patched_image) - if padding > 0: - patched_image = np.pad( - patched_image, - [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), - constant_values=self.pad_kwargs.get("constant_values", 0), - ) - locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) + # Limit number of patches + patched_image, locations = self.filter_count(patched_image, locations) + # Pad the patch list to have the requested number of patches + if self.threshold is None: + padding = self.num_patches - len(patched_image) + if padding > 0: + patched_image = np.pad( + patched_image, + [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), + constant_values=self.pad_kwargs.get("constant_values", 0), + ) + locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) # Convert to MetaTensor metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta() - metadata[WSIPatchKeys.LOCATION] = locations.T - metadata[WSIPatchKeys.COUNT] = len(locations) + metadata[PatchKeys.LOCATION] = locations.T + metadata[PatchKeys.COUNT] = len(locations) metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T + metadata["offset"] = self.offset output = MetaTensor(x=patched_image, meta=metadata) output.is_batch = True @@ -3155,18 +3184,32 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait): min_offset: the minimum range of offset to be selected randomly. Defaults to 0. max_offset: the maximum range of offset to be selected randomly. Defaults to image size modulo patch size. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata + MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), + with following metadata: + + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index cea89dc76d..4c1fe4f268 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1836,24 +1836,32 @@ class GridPatchd(MapTransform, MultiSampleTrait): patch_size: size of patches to generate slices for, 0 or None selects whole dimension offset: starting position in the array, default is 0 for each dimension. np.random.randint(0, patch_size, 2) creates random start between 0 and `patch_size` for a 2D image. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: amount of overlap between patches in each dimension. Default to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - a list of dictionaries, each of which contains the all the original key/value with the values for `keys` - replaced by the patches. It also add the following new keys: + dictionary, contains the all the original key/value with the values for `keys` + replaced by the patches, a MetaTensor with following metadata: - "patch_location": the starting location of the patch in the image, - "patch_size": size of the extracted patch - "num_patches": total number of patches in the image - "offset": the amount of offset for the patches in the image (starting position of upper left patch) + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ backend = GridPatch.backend @@ -1902,25 +1910,33 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait): min_offset: the minimum range of starting position to be selected randomly. Defaults to 0. max_offset: the maximum range of starting position to be selected randomly. Defaults to image size modulo patch size. - num_patches: number of patches to return. Defaults to None, which returns all the available patches. + num_patches: number of patches (or maximum number of patches) to return. + If the requested number of patches is greater than the number of available patches, + padding will be applied to provide exactly `num_patches` patches unless `threshold` is set. + When `threshold` is set, this value is treated as the maximum number of patches. + Defaults to None, which does not limit number of the patches. overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None. threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. - pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``. + pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. + Defaults to None, which means no padding will be applied. + Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. Returns: - a list of dictionaries, each of which contains the all the original key/value with the values for `keys` - replaced by the patches. It also add the following new keys: + dictionary, contains the all the original key/value with the values for `keys` + replaced by the patches, a MetaTensor with following metadata: - "patch_location": the starting location of the patch in the image, - "patch_size": size of the extracted patch - "num_patches": total number of patches in the image - "offset": the amount of offset for the patches in the image (starting position of the first patch) + - `PatchKeys.LOCATION`: the starting location of the patch in the image, + - `PatchKeys.COUNT`: total number of patches in the image, + - "spatial_shape": spatial size of the extracted patch, and + - "offset": the amount of offset for the patches in the image (starting position of the first patch) """ diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 05e773929e..22c2218afd 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -38,25 +38,32 @@ TEST_CASE_9 = [{"patch_size": (2, 2), "num_patches": 4, "sort_fn": "min"}, A, [A11, A12, A21, A22]] TEST_CASE_10 = [{"patch_size": (2, 2), "overlap": 0.5, "num_patches": 3}, A, [A11, A[:, :2, 1:3], A12]] TEST_CASE_11 = [ - {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255}, + {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255, "pad_mode": "constant"}, A, [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode="constant", constant_values=255)], ] TEST_CASE_12 = [ - {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2}, + {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2, "pad_mode": "constant"}, A, [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")], ] +# Only threshold filtering TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]] +TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, A, [A11, A12, A21]] +# threshold filtering with num_patches more than available patches (no effect) +TEST_CASE_15 = [{"patch_size": (2, 2), "num_patches": 3, "threshold": 50.0}, A, [A11]] +# threshold filtering with num_patches less than available patches (count filtering) +TEST_CASE_16 = [{"patch_size": (2, 2), "num_patches": 2, "threshold": 150.0}, A, [A11, A12]] -TEST_CASE_MEAT_0 = [ + +TEST_CASE_META_0 = [ {"patch_size": (2, 2)}, A, [A11, A12, A21, A22], [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], ] -TEST_CASE_MEAT_1 = [ +TEST_CASE_META_1 = [ {"patch_size": (2, 2)}, MetaTensor(x=A, meta={"path": "path/to/file"}), [A11, A12, A21, A22], @@ -84,6 +91,9 @@ TEST_CASES.append([p, *TEST_CASE_11]) TEST_CASES.append([p, *TEST_CASE_12]) TEST_CASES.append([p, *TEST_CASE_13]) + TEST_CASES.append([p, *TEST_CASE_14]) + TEST_CASES.append([p, *TEST_CASE_15]) + TEST_CASES.append([p, *TEST_CASE_16]) class TestGridPatch(unittest.TestCase): @@ -96,7 +106,7 @@ def test_grid_patch(self, in_type, input_parameters, image, expected): for output_patch, expected_patch in zip(output, expected): assert_allclose(output_patch, expected_patch, type_test=False) - @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta): set_track_meta(True) diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index a19e26a16d..5629c0e871 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -37,16 +37,23 @@ TEST_CASE_9 = [{"patch_size": (2, 2), "num_patches": 4, "sort_fn": "min"}, {"image": A}, [A11, A12, A21, A22]] TEST_CASE_10 = [{"patch_size": (2, 2), "overlap": 0.5, "num_patches": 3}, {"image": A}, [A11, A[:, :2, 1:3], A12]] TEST_CASE_11 = [ - {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255}, + {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255, "pad_mode": "constant"}, {"image": A}, [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode="constant", constant_values=255)], ] TEST_CASE_12 = [ - {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2}, + {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2, "pad_mode": "constant"}, {"image": A}, [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")], ] +# Only threshold filtering TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, {"image": A}, [A11]] +TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, {"image": A}, [A11, A12, A21]] +# threshold filtering with num_patches more than available patches (no effect) +TEST_CASE_15 = [{"patch_size": (2, 2), "threshold": 50.0, "num_patches": 3}, {"image": A}, [A11]] +# threshold filtering with num_patches less than available patches (count filtering) +TEST_CASE_16 = [{"patch_size": (2, 2), "threshold": 150.0, "num_patches": 2}, {"image": A}, [A11, A12]] + TEST_SINGLE = [] for p in TEST_NDARRAYS: @@ -64,6 +71,9 @@ TEST_SINGLE.append([p, *TEST_CASE_11]) TEST_SINGLE.append([p, *TEST_CASE_12]) TEST_SINGLE.append([p, *TEST_CASE_13]) + TEST_SINGLE.append([p, *TEST_CASE_14]) + TEST_SINGLE.append([p, *TEST_CASE_15]) + TEST_SINGLE.append([p, *TEST_CASE_16]) class TestGridPatchd(unittest.TestCase):