Skip to content

Commit

Permalink
GridPatch with both count and threshold filtering (#6055)
Browse files Browse the repository at this point in the history
Fixes #6049 

### Description
This PR add support to `GridPatch` for both filtering by count and
threshold. When filtering by threshold, the `num_patches` will be
treated as maximum number of patches.

UPDATE: It also removes the deprecated argument of
`skimage.measure.regionprops`, 'coordinates', which was causing some
test to fail.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.

---------

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
  • Loading branch information
drbeh authored Mar 1, 2023
1 parent 68074f0 commit 579fe65
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 59 deletions.
115 changes: 79 additions & 36 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
optional_import,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys
from monai.utils.enums import GridPatchSort, PatchKeys, PytorchPadMode, TraceKeys, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string
Expand Down Expand Up @@ -3025,19 +3025,32 @@ class GridPatch(Transform, MultiSampleTrait):
Args:
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
offset: offset of starting position in the array, default is 0 for each dimension.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
If the required patches are more than the available patches, padding will be applied.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
Returns:
MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata
MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
with following metadata:
- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)
"""

Expand All @@ -3051,7 +3064,7 @@ def __init__(
overlap: Sequence[float] | float = 0.0,
sort_fn: str | None = None,
threshold: float | None = None,
pad_mode: str = PytorchPadMode.CONSTANT,
pad_mode: str | None = None,
**pad_kwargs,
):
self.patch_size = ensure_tuple(patch_size)
Expand All @@ -3065,24 +3078,26 @@ def __init__(

def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray):
"""
Filter the patches and their locations according to a threshold
Filter the patches and their locations according to a threshold.
Args:
image_np: a numpy.ndarray representing a stack of patches
locations: a numpy.ndarray representing the stack of location of each patch
image_np: a numpy.ndarray representing a stack of patches.
locations: a numpy.ndarray representing the stack of location of each patch.
Returns:
tuple[numpy.ndarray, numpy.ndarray]: tuple of filtered patches and locations.
"""
if self.threshold is not None:
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
image_np = image_np[idx]
locations = locations[idx]
return image_np, locations
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
return image_np[idx], locations[idx]

def filter_count(self, image_np: np.ndarray, locations: np.ndarray):
"""
Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them.
Args:
image_np: a numpy.ndarray representing a stack of patches
locations: a numpy.ndarray representing the stack of location of each patch
image_np: a numpy.ndarray representing a stack of patches.
locations: a numpy.ndarray representing the stack of location of each patch.
"""
if self.sort_fn is None:
image_np = image_np[: self.num_patches]
Expand All @@ -3100,7 +3115,17 @@ def filter_count(self, image_np: np.ndarray, locations: np.ndarray):
locations = locations[idx]
return image_np, locations

def __call__(self, array: NdarrayOrTensor):
def __call__(self, array: NdarrayOrTensor) -> MetaTensor:
"""
Extract the patches (sweeping the entire image in a row-major sliding-window manner with possible overlaps).
Args:
array: a input image as `numpy.ndarray` or `torch.Tensor`
Return:
MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata.
"""
# create the patch iterator which sweeps the image row-by-row
array_np, *_ = convert_data_type(array, np.ndarray)
patch_iterator = iter_patch(
Expand All @@ -3115,29 +3140,33 @@ def __call__(self, array: NdarrayOrTensor):
patches = list(zip(*patch_iterator))
patched_image = np.array(patches[0])
locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location
del patches # it will free up some memory if padding is used.

# Filter patches
if self.num_patches:
patched_image, locations = self.filter_count(patched_image, locations)
elif self.threshold:
# Apply threshold filtering
if self.threshold is not None:
patched_image, locations = self.filter_threshold(patched_image, locations)

# Pad the patch list to have the requested number of patches
# Apply count filtering
if self.num_patches:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)
# Limit number of patches
patched_image, locations = self.filter_count(patched_image, locations)
# Pad the patch list to have the requested number of patches
if self.threshold is None:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)

# Convert to MetaTensor
metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta()
metadata[WSIPatchKeys.LOCATION] = locations.T
metadata[WSIPatchKeys.COUNT] = len(locations)
metadata[PatchKeys.LOCATION] = locations.T
metadata[PatchKeys.COUNT] = len(locations)
metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T
metadata["offset"] = self.offset
output = MetaTensor(x=patched_image, meta=metadata)
output.is_batch = True

Expand All @@ -3155,18 +3184,32 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait):
min_offset: the minimum range of offset to be selected randomly. Defaults to 0.
max_offset: the maximum range of offset to be selected randomly.
Defaults to image size modulo patch size.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
Returns:
MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata
MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
with following metadata:
- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)
"""

Expand Down
48 changes: 32 additions & 16 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,24 +1836,32 @@ class GridPatchd(MapTransform, MultiSampleTrait):
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
offset: starting position in the array, default is 0 for each dimension.
np.random.randint(0, patch_size, 2) creates random start between 0 and `patch_size` for a 2D image.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: amount of overlap between patches in each dimension. Default to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
allow_missing_keys: don't raise exception if key is missing.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
Returns:
a list of dictionaries, each of which contains the all the original key/value with the values for `keys`
replaced by the patches. It also add the following new keys:
dictionary, contains the all the original key/value with the values for `keys`
replaced by the patches, a MetaTensor with following metadata:
"patch_location": the starting location of the patch in the image,
"patch_size": size of the extracted patch
"num_patches": total number of patches in the image
"offset": the amount of offset for the patches in the image (starting position of upper left patch)
- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)
"""

backend = GridPatch.backend
Expand Down Expand Up @@ -1902,25 +1910,33 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait):
min_offset: the minimum range of starting position to be selected randomly. Defaults to 0.
max_offset: the maximum range of starting position to be selected randomly.
Defaults to image size modulo patch size.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
allow_missing_keys: don't raise exception if key is missing.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
Returns:
a list of dictionaries, each of which contains the all the original key/value with the values for `keys`
replaced by the patches. It also add the following new keys:
dictionary, contains the all the original key/value with the values for `keys`
replaced by the patches, a MetaTensor with following metadata:
"patch_location": the starting location of the patch in the image,
"patch_size": size of the extracted patch
"num_patches": total number of patches in the image
"offset": the amount of offset for the patches in the image (starting position of the first patch)
- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)
"""

Expand Down
Loading

0 comments on commit 579fe65

Please sign in to comment.