Skip to content

Commit

Permalink
Split transform (Project-MONAI#4153)
Browse files Browse the repository at this point in the history
* Redesign whole slide image reading (Project-MONAI#4107)

* Redesign BaseWSIReader,  WSIReader, CuCIMWSIReader

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add unittests for WSIReader

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add image mode for output validation

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update docs

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update references to new WSIReader

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Remove legacy WSIReader

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update unittests

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update docs

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* sort imports

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Clean up imports

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update docstrings

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update docs and docstrings

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix a typo

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Remove redundant checking

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update read and other methods

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update wsireader to support multi image and update docstrings

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Make workaround for CuImage objects

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add unittests for multi image reading

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update a note about cucim

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update type hints and docstrings

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Implement Split transform

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add unittests

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update formatting

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Implement SplitDict

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add unittests for SplitDict

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add docs

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Remove images from docs

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Address all comments

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add example and size check

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Update docs

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Revert references to new wsireader

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>

* Add missing comma

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
  • Loading branch information
drbeh authored and Can-Zhao committed May 10, 2022
1 parent dbe073a commit 3ea72e6
Show file tree
Hide file tree
Showing 6 changed files with 331 additions and 0 deletions.
14 changes: 14 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,13 @@ Spatial
:members:
:special-members: __call__

`GridSplit`
"""""""""""
.. autoclass:: GridSplit
:members:
:special-members: __call__


Smooth Field
^^^^^^^^^^^^

Expand Down Expand Up @@ -1506,6 +1513,13 @@ Spatial (Dict)
:members:
:special-members: __call__

`GridSplitd`
""""""""""""
.. autoclass:: GridSplitd
:members:
:special-members: __call__


`RandRotate90d`
"""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
AffineGrid,
Flip,
GridDistortion,
GridSplit,
Orientation,
Rand2DElastic,
Rand3DElastic,
Expand Down Expand Up @@ -342,6 +343,9 @@
GridDistortiond,
GridDistortionD,
GridDistortionDict,
GridSplitd,
GridSplitD,
GridSplitDict,
Orientationd,
OrientationD,
OrientationDict,
Expand Down
90 changes: 90 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np
import torch
from numpy.lib.stride_tricks import as_strided

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
Expand Down Expand Up @@ -66,6 +67,7 @@
"Orientation",
"Flip",
"GridDistortion",
"GridSplit",
"Resize",
"Rotate",
"Zoom",
Expand Down Expand Up @@ -2475,3 +2477,91 @@ def __call__(
if not self._do_transform:
return img
return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode)


class GridSplit(Transform):
"""
Split the image into patches based on the provided grid in 2D.
Args:
grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2)
size: a tuple or an integer that defines the output patch sizes.
If it's an integer, the value will be repeated for each dimension.
The default is None, where the patch size will be inferred from the grid shape.
Example:
Given an image (torch.Tensor or numpy.ndarray) with size of (3, 10, 10) and a grid of (2, 2),
it will return a Tensor or array with the size of (4, 3, 5, 5).
Here, if the `size` is provided, the returned shape will be (4, 3, size, size)
Note: This transform currently support only image with two spatial dimensions.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None):
# Grid size
self.grid = grid

# Patch size
self.size = None if size is None else ensure_tuple_rep(size, len(self.grid))

def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor:
if self.grid == (1, 1) and self.size is None:
if isinstance(image, torch.Tensor):
return torch.stack([image])
elif isinstance(image, np.ndarray):
return np.stack([image]) # type: ignore
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")

size, steps = self._get_params(image.shape[1:])
patches: NdarrayOrTensor
if isinstance(image, torch.Tensor):
patches = (
image.unfold(1, size[0], steps[0])
.unfold(2, size[1], steps[1])
.flatten(1, 2)
.transpose(0, 1)
.contiguous()
)
elif isinstance(image, np.ndarray):
x_step, y_step = steps
c_stride, x_stride, y_stride = image.strides
n_channels = image.shape[0]
patches = as_strided(
image,
shape=(*self.grid, n_channels, size[0], size[1]),
strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),
writeable=False,
)
# flatten the first two dimensions
patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:])
# make it a contiguous array
patches = np.ascontiguousarray(patches)
else:
raise ValueError(f"Input type [{type(image)}] is not supported.")

return patches

def _get_params(self, image_size: Union[Sequence[int], np.ndarray]):
"""
Calculate the size and step required for splitting the image
Args:
The size of the input image
"""
if self.size is not None:
# Set the split size to the given default size
if any(self.size[i] > image_size[i] for i in range(len(self.grid))):
raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})")
split_size = self.size
else:
# infer each sub-image size from the image size and the grid
split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid)))

steps = tuple(
(image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i]
for i in range(len(self.grid))
)

return split_size, steps
39 changes: 39 additions & 0 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
AffineGrid,
Flip,
GridDistortion,
GridSplit,
Orientation,
Rand2DElastic,
Rand3DElastic,
Expand Down Expand Up @@ -129,6 +130,9 @@
"ZoomDict",
"RandZoomD",
"RandZoomDict",
"GridSplitd",
"GridSplitD",
"GridSplitDict",
]

GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
Expand Down Expand Up @@ -2149,6 +2153,40 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d


class GridSplitd(MapTransform):
"""
Split the image into patches based on the provided grid in 2D.
Args:
keys: keys of the corresponding items to be transformed.
grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2)
size: a tuple or an integer that defines the output patch sizes.
If it's an integer, the value will be repeated for each dimension.
The default is None, where the patch size will be inferred from the grid shape.
allow_missing_keys: don't raise exception if key is missing.
Note: This transform currently support only image with two spatial dimensions.
"""

backend = GridSplit.backend

def __init__(
self,
keys: KeysCollection,
grid: Tuple[int, int] = (2, 2),
size: Optional[Union[int, Tuple[int, int]]] = None,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
self.splitter = GridSplit(grid=grid, size=size)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.splitter(d[key])
return d


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
Expand All @@ -2169,3 +2207,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
RandRotateD = RandRotateDict = RandRotated
ZoomD = ZoomDict = Zoomd
RandZoomD = RandZoomDict = RandZoomd
GridSplitD = GridSplitDict = GridSplitd
84 changes: 84 additions & 0 deletions tests/test_grid_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import GridSplit
from tests.utils import TEST_NDARRAYS, assert_allclose

A11 = torch.randn(3, 2, 2)
A12 = torch.randn(3, 2, 2)
A21 = torch.randn(3, 2, 2)
A22 = torch.randn(3, 2, 2)

A1 = torch.cat([A11, A12], 2)
A2 = torch.cat([A21, A22], 2)
A = torch.cat([A1, A2], 1)

TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])]
TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])]
TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])]
TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])]
TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])]
TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])]
TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])]
TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])]
TEST_CASE_8 = [
{"grid": (2, 2), "size": 2},
torch.arange(12).reshape(1, 3, 4).to(torch.float32),
torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32),
]

TEST_SINGLE = []
for p in TEST_NDARRAYS:
TEST_SINGLE.append([p, *TEST_CASE_0])
TEST_SINGLE.append([p, *TEST_CASE_1])
TEST_SINGLE.append([p, *TEST_CASE_2])
TEST_SINGLE.append([p, *TEST_CASE_3])
TEST_SINGLE.append([p, *TEST_CASE_4])
TEST_SINGLE.append([p, *TEST_CASE_5])
TEST_SINGLE.append([p, *TEST_CASE_6])
TEST_SINGLE.append([p, *TEST_CASE_7])
TEST_SINGLE.append([p, *TEST_CASE_8])

TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]]
TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5]
TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]]

TEST_MULTIPLE = []
for p in TEST_NDARRAYS:
TEST_MULTIPLE.append([p, *TEST_CASE_MC_0])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_1])
TEST_MULTIPLE.append([p, *TEST_CASE_MC_2])


class TestGridSplit(unittest.TestCase):
@parameterized.expand(TEST_SINGLE)
def test_split_patch_single_call(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
splitter = GridSplit(**input_parameters)
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)

@parameterized.expand(TEST_MULTIPLE)
def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):
splitter = GridSplit(**input_parameters)
for image, expected in zip(img_list, expected_list):
input_image = in_type(image)
output = splitter(input_image)
assert_allclose(output, expected, type_test=False)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 3ea72e6

Please sign in to comment.