Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridPatch with both count and threshold filtering #6055

Merged
merged 20 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 79 additions & 36 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
optional_import,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys
from monai.utils.enums import GridPatchSort, PatchKeys, PytorchPadMode, TraceKeys, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string
Expand Down Expand Up @@ -3025,19 +3025,32 @@ class GridPatch(Transform, MultiSampleTrait):
Args:
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
offset: offset of starting position in the array, default is 0 for each dimension.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
If the required patches are more than the available patches, padding will be applied.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

Returns:
MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata
MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
with following metadata:

- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)

"""

Expand All @@ -3051,7 +3064,7 @@ def __init__(
overlap: Sequence[float] | float = 0.0,
sort_fn: str | None = None,
threshold: float | None = None,
pad_mode: str = PytorchPadMode.CONSTANT,
pad_mode: str | None = None,
**pad_kwargs,
):
self.patch_size = ensure_tuple(patch_size)
Expand All @@ -3065,24 +3078,26 @@ def __init__(

def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray):
"""
Filter the patches and their locations according to a threshold
Filter the patches and their locations according to a threshold.

Args:
image_np: a numpy.ndarray representing a stack of patches
locations: a numpy.ndarray representing the stack of location of each patch
image_np: a numpy.ndarray representing a stack of patches.
locations: a numpy.ndarray representing the stack of location of each patch.

Returns:
tuple[numpy.ndarray, numpy.ndarray]: tuple of filtered patches and locations.
"""
if self.threshold is not None:
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
image_np = image_np[idx]
locations = locations[idx]
return image_np, locations
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
return image_np[idx], locations[idx]

def filter_count(self, image_np: np.ndarray, locations: np.ndarray):
"""
Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them.

Args:
image_np: a numpy.ndarray representing a stack of patches
locations: a numpy.ndarray representing the stack of location of each patch
image_np: a numpy.ndarray representing a stack of patches.
locations: a numpy.ndarray representing the stack of location of each patch.
"""
if self.sort_fn is None:
image_np = image_np[: self.num_patches]
Expand All @@ -3100,7 +3115,17 @@ def filter_count(self, image_np: np.ndarray, locations: np.ndarray):
locations = locations[idx]
return image_np, locations

def __call__(self, array: NdarrayOrTensor):
def __call__(self, array: NdarrayOrTensor) -> MetaTensor:
"""
Extract the patches (sweeping the entire image in a row-major sliding-window manner with possible overlaps).

Args:
array: a input image as `numpy.ndarray` or `torch.Tensor`

Return:
MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata.
"""
# create the patch iterator which sweeps the image row-by-row
array_np, *_ = convert_data_type(array, np.ndarray)
patch_iterator = iter_patch(
Expand All @@ -3115,29 +3140,33 @@ def __call__(self, array: NdarrayOrTensor):
patches = list(zip(*patch_iterator))
patched_image = np.array(patches[0])
locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location
del patches # it will free up some memory if padding is used.
wyli marked this conversation as resolved.
Show resolved Hide resolved

# Filter patches
if self.num_patches:
patched_image, locations = self.filter_count(patched_image, locations)
drbeh marked this conversation as resolved.
Show resolved Hide resolved
elif self.threshold:
# Apply threshold filtering
if self.threshold is not None:
patched_image, locations = self.filter_threshold(patched_image, locations)

# Pad the patch list to have the requested number of patches
# Apply count filtering
if self.num_patches:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)
# Limit number of patches
patched_image, locations = self.filter_count(patched_image, locations)
# Pad the patch list to have the requested number of patches
if self.threshold is None:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
drbeh marked this conversation as resolved.
Show resolved Hide resolved
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)

# Convert to MetaTensor
metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta()
metadata[WSIPatchKeys.LOCATION] = locations.T
metadata[WSIPatchKeys.COUNT] = len(locations)
metadata[PatchKeys.LOCATION] = locations.T
metadata[PatchKeys.COUNT] = len(locations)
metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T
metadata["offset"] = self.offset
output = MetaTensor(x=patched_image, meta=metadata)
output.is_batch = True

Expand All @@ -3155,18 +3184,32 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait):
min_offset: the minimum range of offset to be selected randomly. Defaults to 0.
max_offset: the maximum range of offset to be selected randomly.
Defaults to image size modulo patch size.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

Returns:
MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata
MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
with following metadata:

- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)

"""

Expand Down
48 changes: 32 additions & 16 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,24 +1836,32 @@ class GridPatchd(MapTransform, MultiSampleTrait):
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
offset: starting position in the array, default is 0 for each dimension.
np.random.randint(0, patch_size, 2) creates random start between 0 and `patch_size` for a 2D image.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: amount of overlap between patches in each dimension. Default to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
allow_missing_keys: don't raise exception if key is missing.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

Returns:
a list of dictionaries, each of which contains the all the original key/value with the values for `keys`
replaced by the patches. It also add the following new keys:
dictionary, contains the all the original key/value with the values for `keys`
replaced by the patches, a MetaTensor with following metadata:

"patch_location": the starting location of the patch in the image,
"patch_size": size of the extracted patch
"num_patches": total number of patches in the image
"offset": the amount of offset for the patches in the image (starting position of upper left patch)
- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)
"""

backend = GridPatch.backend
Expand Down Expand Up @@ -1902,25 +1910,33 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait):
min_offset: the minimum range of starting position to be selected randomly. Defaults to 0.
max_offset: the maximum range of starting position to be selected randomly.
Defaults to image size modulo patch size.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
When `threshold` is set, this value is treated as the maximum number of patches.
Defaults to None, which does not limit number of the patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
lowest values (`"min"`), or in their default order (`None`). Default to None.
threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
Defaults to no filtering.
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.
Defaults to None, which means no padding will be applied.
Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
allow_missing_keys: don't raise exception if key is missing.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.

Returns:
a list of dictionaries, each of which contains the all the original key/value with the values for `keys`
replaced by the patches. It also add the following new keys:
dictionary, contains the all the original key/value with the values for `keys`
replaced by the patches, a MetaTensor with following metadata:

"patch_location": the starting location of the patch in the image,
"patch_size": size of the extracted patch
"num_patches": total number of patches in the image
"offset": the amount of offset for the patches in the image (starting position of the first patch)
- `PatchKeys.LOCATION`: the starting location of the patch in the image,
- `PatchKeys.COUNT`: total number of patches in the image,
- "spatial_shape": spatial size of the extracted patch, and
- "offset": the amount of offset for the patches in the image (starting position of the first patch)

"""

Expand Down
Loading