Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridPatch with both count and threshold filtering #6055

Merged
merged 20 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3139,8 +3139,10 @@ class GridPatch(Transform, MultiSampleTrait):
Args:
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
offset: offset of starting position in the array, default is 0 for each dimension.
num_patches: number of patches to return. Defaults to None, which returns all the available patches.
If the required patches are more than the available patches, padding will be applied.
num_patches: number of patches (or maximum number of patches) to return.
If the requested number of patches is greater than the number of available patches,
padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
Defaults to None, which returns all the available patches.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
Expand Down Expand Up @@ -3184,11 +3186,10 @@ def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray):
image_np: a numpy.ndarray representing a stack of patches
locations: a numpy.ndarray representing the stack of location of each patch
"""
if self.threshold is not None:
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
image_np = image_np[idx]
locations = locations[idx]
n_dims = len(image_np.shape)
idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1)
image_np = image_np[idx]
locations = locations[idx]
return image_np, locations

def filter_count(self, image_np: np.ndarray, locations: np.ndarray):
Expand Down Expand Up @@ -3230,22 +3231,22 @@ def __call__(self, array: NdarrayOrTensor):
patched_image = np.array(patches[0])
locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location

# Filter patches
if self.num_patches:
patched_image, locations = self.filter_count(patched_image, locations)
drbeh marked this conversation as resolved.
Show resolved Hide resolved
elif self.threshold:
if self.threshold is not None:
patched_image, locations = self.filter_threshold(patched_image, locations)

# Pad the patch list to have the requested number of patches
if self.num_patches:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)
# Limit number of patches
patched_image, locations = self.filter_count(patched_image, locations)
# Pad the patch list to have the requested number of patches
if self.threshold is None:
padding = self.num_patches - len(patched_image)
if padding > 0:
patched_image = np.pad(
drbeh marked this conversation as resolved.
Show resolved Hide resolved
patched_image,
[[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
constant_values=self.pad_kwargs.get("constant_values", 0),
)
locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)

# Convert to MetaTensor
metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta()
Expand Down
15 changes: 12 additions & 3 deletions tests/test_grid_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,22 @@
A,
[np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")],
]
# Only threshold filtering
TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]]
TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, A, [A11, A12, A21]]
# threshold filtering with num_patches more than available patches (no effect)
TEST_CASE_15 = [{"patch_size": (2, 2), "threshold": 50.0, "num_patches": 3}, A, [A11]]
# threshold filtering with num_patches less than available patches (count filtering)
TEST_CASE_16 = [{"patch_size": (2, 2), "threshold": 150.0, "num_patches": 2}, A, [A11, A12]]

TEST_CASE_MEAT_0 = [
TEST_CASE_META_0 = [
{"patch_size": (2, 2)},
A,
[A11, A12, A21, A22],
[{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}],
]

TEST_CASE_MEAT_1 = [
TEST_CASE_META_1 = [
{"patch_size": (2, 2)},
MetaTensor(x=A, meta={"path": "path/to/file"}),
[A11, A12, A21, A22],
Expand Down Expand Up @@ -84,6 +90,9 @@
TEST_CASES.append([p, *TEST_CASE_11])
TEST_CASES.append([p, *TEST_CASE_12])
TEST_CASES.append([p, *TEST_CASE_13])
TEST_CASES.append([p, *TEST_CASE_14])
TEST_CASES.append([p, *TEST_CASE_15])
TEST_CASES.append([p, *TEST_CASE_16])


class TestGridPatch(unittest.TestCase):
Expand All @@ -96,7 +105,7 @@ def test_grid_patch(self, in_type, input_parameters, image, expected):
for output_patch, expected_patch in zip(output, expected):
assert_allclose(output_patch, expected_patch, type_test=False)

@parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1])
@parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1])
@SkipIfBeforePyTorchVersion((1, 9, 1))
def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta):
set_track_meta(True)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_grid_patchd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@
{"image": A},
[np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")],
]
# Only threshold filtering
TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, {"image": A}, [A11]]
TEST_CASE_14 = [{"patch_size": (2, 2), "threshold": 150.0}, {"image": A}, [A11, A12, A21]]
# threshold filtering with num_patches more than available patches (no effect)
TEST_CASE_15 = [{"patch_size": (2, 2), "threshold": 50.0, "num_patches": 3}, {"image": A}, [A11]]
# threshold filtering with num_patches less than available patches (count filtering)
TEST_CASE_16 = [{"patch_size": (2, 2), "threshold": 150.0, "num_patches": 2}, {"image": A}, [A11, A12]]

TEST_SINGLE = []
for p in TEST_NDARRAYS:
Expand All @@ -64,6 +70,9 @@
TEST_SINGLE.append([p, *TEST_CASE_11])
TEST_SINGLE.append([p, *TEST_CASE_12])
TEST_SINGLE.append([p, *TEST_CASE_13])
TEST_SINGLE.append([p, *TEST_CASE_14])
TEST_SINGLE.append([p, *TEST_CASE_15])
TEST_SINGLE.append([p, *TEST_CASE_16])


class TestGridPatchd(unittest.TestCase):
Expand Down