From f0c0af89367bbafbefd4954724ef9d45901455de Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 10 Sep 2021 12:23:17 +0100 Subject: [PATCH 01/12] AffineGrid, RandAffineGrid, Resample Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/layers/spatial_transforms.py | 9 +- monai/transforms/spatial/array.py | 119 +++----- monai/transforms/spatial/dictionary.py | 9 +- tests/test_affine_grid.py | 160 +++++----- tests/test_rand_affine_grid.py | 322 ++++++++++---------- tests/test_resampler.py | 181 +++++++---- 6 files changed, 444 insertions(+), 356 deletions(-) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 511c24fcb0..6b5acb166a 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -46,7 +46,9 @@ def backward(ctx, grad): return None, grads[0], None, None, None -def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True): +def grid_pull( + input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True +) -> torch.Tensor: """ Sample an image with respect to a deformation field. @@ -112,8 +114,9 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in ensure_tuple(interpolation) ] - - return _GridPull.apply(input, grid, interpolation, bound, extrapolate) + out: torch.Tensor + out = _GridPull.apply(input, grid, interpolation, bound, extrapolate) # type: ignore + return out class _GridPush(torch.autograd.Function): diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 816e9d58f2..5e218b9ce7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -47,7 +47,7 @@ ) from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") @@ -995,39 +995,35 @@ class AffineGrid(Transform): pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. - as_tensor_output: whether to output tensor instead of numpy array, defaults to True. - device: device to store the output grid data. affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. 
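        Example (an illustrative sketch; the parameter values are arbitrary, and the
        shapes follow from the homogeneous grid convention described above)::

            from monai.transforms import AffineGrid

            affine_grid = AffineGrid(rotate_params=(0.5,), scale_params=(2.0, 2.0))
            grid, affine = affine_grid(spatial_size=(4, 4))
            # grid has shape (3, 4, 4): (spatial dims + 1, *spatial_size)
            # affine has shape (3, 3): square, spatial dims + 1 on each side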
""" + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, device: Optional[torch.device] = None, - affine: Optional[Union[np.ndarray, torch.Tensor]] = None, + affine: Optional[NdarrayOrTensor] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params self.translate_params = translate_params self.scale_params = scale_params - - self.as_tensor_output = as_tensor_output self.device = device - self.affine = affine def __call__( self, spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: + grid: Optional[NdarrayOrTensor] = None, + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Args: spatial_size: output grid size. @@ -1043,7 +1039,7 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - affine: Union[torch.Tensor, np.ndarray] + affine: NdarrayOrTensor if self.affine is None: spatial_dims = len(grid.shape) - 1 affine = np.eye(spatial_dims + 1) @@ -1058,17 +1054,13 @@ def __call__( else: affine = self.affine - if isinstance(affine, np.ndarray): - affine = torch.as_tensor(np.ascontiguousarray(affine)) + if self.device is not None: + grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device) + grid, *_ = convert_data_type(grid, dtype=torch.float32) + affine, *_ = convert_to_dst_type(affine, grid) - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - affine = affine.to(self.device) - grid = grid.to(self.device) - grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) - if grid is None or not isinstance(grid, torch.Tensor): - raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine + grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) + return grid, affine class RandAffineGrid(Randomizable, Transform): @@ -1083,7 +1075,6 @@ def __init__( shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, - as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1110,8 +1101,6 @@ def __init__( scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1.0). - as_tensor_output: whether to output tensor instead of numpy array. - defaults to True. device: device to store the output grid data. 
See also: @@ -1130,9 +1119,8 @@ def __init__( self.translate_params: Optional[List[float]] = None self.scale_params: Optional[List[float]] = None - self.as_tensor_output = as_tensor_output self.device = device - self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None + self.affine: Optional[NdarrayOrTensor] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1154,8 +1142,8 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + grid: Optional[NdarrayOrTensor] = None, + ) -> NdarrayOrTensor: """ Args: spatial_size: output grid size. @@ -1170,13 +1158,12 @@ def __call__( shear_params=self.shear_params, translate_params=self.translate_params, scale_params=self.scale_params, - as_tensor_output=self.as_tensor_output, device=self.device, ) grid, self.affine = affine_grid(spatial_size, grid) return grid - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + def get_transformation_matrix(self) -> Optional[NdarrayOrTensor]: """Get the most recently applied transformation matrix""" return self.affine @@ -1232,11 +1219,13 @@ def __call__(self, spatial_size: Sequence[int]): class Resample(Transform): + + backend = [TransformBackends.TORCH] + def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - as_tensor_output: bool = False, device: Optional[torch.device] = None, ) -> None: """ @@ -1250,21 +1239,19 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: whether to return a torch tensor. Defaults to False. device: device on which the tensor will be allocated. """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.as_tensor_output = as_tensor_output self.device = device def __call__( self, - img: Union[np.ndarray, torch.Tensor], - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + img: NdarrayOrTensor, + grid: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> torch.Tensor: """ Args: img: shape must be (num_channels, H, W[, D]). @@ -1276,18 +1263,14 @@ def __call__( Padding mode for outside grid values. Defaults to ``self.padding_mode``. 
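        Example (a minimal sketch mirroring the unit tests in this patch: an
        identity grid from ``create_grid`` reproduces the input, while a larger
        grid zero-pads it; after this change the result is always a
        ``torch.Tensor``)::

            import numpy as np
            from monai.transforms import Resample
            from monai.transforms.utils import create_grid

            img = np.arange(4).reshape((1, 2, 2)).astype(np.float32)
            resampler = Resample(padding_mode="zeros")
            out = resampler(img=img, grid=create_grid((4, 4)))  # shape (1, 4, 4)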
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ - - if not isinstance(img, torch.Tensor): - img = torch.as_tensor(np.ascontiguousarray(img)) if grid is None: - raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - img = img.to(self.device) - grid = grid.to(self.device) + raise ValueError("Unknown grid.") + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, device=self.device, dtype=torch.float32) # type: ignore + grid, *_ = convert_to_dst_type(grid, img_t) if USE_COMPILED: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] += (dim - 1.0) / 2.0 grid = grid[:-1] / grid[-1:] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) @@ -1302,29 +1285,28 @@ def __call__( bound = 1 _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value out = grid_pull( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), + img_t.unsqueeze(0), + grid.unsqueeze(0), bound=bound, extrapolate=True, interpolation=1 if _interp_mode == "bilinear" else _interp_mode, )[0] else: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] = 2.0 * grid[i] / (dim - 1.0) grid = grid[:-1] / grid[-1:] - index_ordering: List[int] = list(range(img.ndimension() - 2, -1, -1)) + index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) grid = grid[index_ordering] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) out = torch.nn.functional.grid_sample( - img.unsqueeze(0).float(), - grid.unsqueeze(0).float(), + img_t.unsqueeze(0), + grid.unsqueeze(0), mode=self.mode.value if mode is None else GridSampleMode(mode).value, padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value, align_corners=True, )[0] - if self.as_tensor_output: - return torch.as_tensor(out) - return np.asarray(out.cpu().numpy()) + + return out class Affine(Transform): @@ -1343,7 +1325,6 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, image_only: bool = False, ) -> None: @@ -1379,8 +1360,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. image_only: if True return only the image volume, otherwise return (image, affine). 
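        Example (a short sketch based on the test cases in this patch: rotate a
        2x2 image by 90 degrees inside a 4x4 output; with the default
        ``image_only=False`` both the image and the affine matrix are returned)::

            import numpy as np
            from monai.transforms import Affine

            affine = Affine(rotate_params=[np.pi / 2], padding_mode="zeros")
            new_img, mat = affine(np.arange(4).reshape((1, 2, 2)), spatial_size=(4, 4))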
""" @@ -1389,18 +1368,17 @@ def __init__( shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, - as_tensor_output=True, device=device, ) self.image_only = image_only - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, @@ -1504,10 +1482,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid @@ -1586,12 +1563,11 @@ def __call__( """ self.randomize() # if not doing transform and spatial size doesn't change, nothing to do - # except convert to float and convert numpy/torch + # except convert to float and device sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) if not do_resampling: - img = img.float() if isinstance(img, torch.Tensor) else img.astype("float32") - return torch.Tensor(img) if self.resampler.as_tensor_output else np.array(img) + img, *_ = convert_data_type(img, dtype=torch.float32, device=self.resampler.device) grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid) @@ -1678,10 +1654,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) @@ -1811,8 +1786,14 @@ def __init__( - :py:class:`Affine` for the affine transformation parameters configurations. 
""" RandomizableTransform.__init__(self, prob) - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.rand_affine_grid = RandAffineGrid( + rotate_range=rotate_range, + shear_range=shear_range, + translate_range=translate_range, + scale_range=scale_range, + device=device, + ) + self.resampler = Resample(device=device) self.sigma_range = sigma_range self.magnitude_range = magnitude_range diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index c09d8e8011..d4bcacd71c 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -56,6 +56,7 @@ ) from monai.utils.enums import InverseKeys from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") @@ -635,7 +636,6 @@ def __init__( translate_params=translate_params, scale_params=scale_params, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -822,10 +822,9 @@ def __call__( d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) # if not doing transform and and spatial size is unchanged, only need to do numpy/torch conversion else: - if self.rand_affine.resampler.as_tensor_output and not isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]) - elif not self.rand_affine.resampler.as_tensor_output and isinstance(d[key], torch.Tensor): - d[key] = d[key].detach().cpu().numpy() # type: ignore[union-attr] + d[key], *_ = convert_data_type( + d[key], torch.Tensor, dtype=torch.float32, device=self.rand_affine.resampler.device + ) return d diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 24772b9a21..18a5f89075 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -16,88 +16,106 @@ from parameterized import parameterized from monai.transforms import AffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - {"as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (2, 2)}, - np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [ - {"as_tensor_output": True, "device": None}, - {"spatial_size": (2, 2)}, - torch.tensor([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [{"as_tensor_output": False, "device": None}, {"grid": np.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [{"as_tensor_output": True, "device": torch.device("cpu:0")}, {"grid": np.ones((3, 3, 3))}, torch.ones((3, 3, 3))], - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"as_tensor_output": True, "device": torch.device("cpu:0")}, - {"grid": torch.ones((3, 3, 3))}, - torch.ones((3, 3, 3)), - ], - [ - { - "rotate_params": (1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((3, 3, 3))}, - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208]], - [[-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264]], - [[1.0, 1.0, 1.0], [1.0, 
1.0, 1.0], [1.0, 1.0, 1.0]], + {"device": device}, + {"spatial_size": (2, 2)}, + np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), ] - ), - ], - [ - { - "rotate_params": (1.0, 1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((4, 3, 3, 3))}, - torch.tensor( + ) + + TESTS.append([{"device": device}, {"grid": p(np.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( + [ + { + "rotate_params": (1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), + ] + ) + TESTS.append( [ - [ - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - ], - [ - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - ], - [ - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - ], - [ - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - ], + { + "rotate_params": (1.0, 1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((4, 3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + ], + [ + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + ], + [ + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + ], + [ + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + ], + ] + ) + ), ] - ), - ], -] + ) class 
TestAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) result, _ = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, rtol=1e-3, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 605d0a30ba..0670b4d042 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -16,182 +16,192 @@ from parameterized import parameterized from monai.transforms import RandAffineGrid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, - {"grid": torch.arange(0, 27).reshape((3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [-32.81998, -33.910976, -35.001972], - [-36.092968, -37.183964, -38.27496], - [-39.36596, -40.456955, -41.54795], - ], - [[2.1380205, 3.1015975, 4.0651755], [5.028752, 5.9923296, 6.955907], [7.919484, 8.883063, 9.84664]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], - ] - ) - ), - ], - [ - {"translate_range": (3, 3, 3), "as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (3, 3, 3)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( [ - [ - [ - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - ], - [ - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - ], - [ - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - ], - ], - [ - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - ], - [ - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - ], - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ], - ] - ), - ], - [ - {"rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, - {"grid": torch.arange(0, 108).reshape((4, 3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [ - 
[-9.4201e00, -8.1672e00, -6.9143e00], - [-5.6614e00, -4.4085e00, -3.1556e00], - [-1.9027e00, -6.4980e-01, 6.0310e-01], - ], - [ - [1.8560e00, 3.1089e00, 4.3618e00], - [5.6147e00, 6.8676e00, 8.1205e00], - [9.3734e00, 1.0626e01, 1.1879e01], - ], + {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, + {"grid": p(torch.arange(0, 27).reshape((3, 3, 3)))}, + p( + np.array( [ - [1.3132e01, 1.4385e01, 1.5638e01], - [1.6891e01, 1.8144e01, 1.9397e01], - [2.0650e01, 2.1902e01, 2.3155e01], - ], - ], - [ - [ - [9.9383e-02, -4.8845e-01, -1.0763e00], - [-1.6641e00, -2.2519e00, -2.8398e00], - [-3.4276e00, -4.0154e00, -4.6032e00], - ], - [ - [-5.1911e00, -5.7789e00, -6.3667e00], - [-6.9546e00, -7.5424e00, -8.1302e00], - [-8.7180e00, -9.3059e00, -9.8937e00], - ], - [ - [-1.0482e01, -1.1069e01, -1.1657e01], - [-1.2245e01, -1.2833e01, -1.3421e01], - [-1.4009e01, -1.4596e01, -1.5184e01], - ], - ], + [ + [-32.81998, -33.910976, -35.001972], + [-36.092968, -37.183964, -38.27496], + [-39.36596, -40.456955, -41.54795], + ], + [ + [2.1380205, 3.1015975, 4.0651755], + [5.028752, 5.9923296, 6.955907], + [7.919484, 8.883063, 9.84664], + ], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], + ] + ) + ), + ] + ) + TESTS.append( + [ + {"translate_range": (3, 3, 3), "device": device}, + {"spatial_size": (3, 3, 3)}, + np.array( [ [ - [5.9635e01, 6.1199e01, 6.2764e01], - [6.4328e01, 6.5892e01, 6.7456e01], - [6.9021e01, 7.0585e01, 7.2149e01], - ], - [ - [7.3714e01, 7.5278e01, 7.6842e01], - [7.8407e01, 7.9971e01, 8.1535e01], - [8.3099e01, 8.4664e01, 8.6228e01], + [ + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + ], + [ + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + ], + [ + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + ], ], [ - [8.7792e01, 8.9357e01, 9.0921e01], - [9.2485e01, 9.4049e01, 9.5614e01], - [9.7178e01, 9.8742e01, 1.0031e02], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], ], - ], - [ [ - [8.1000e01, 8.2000e01, 8.3000e01], - [8.4000e01, 8.5000e01, 8.6000e01], - [8.7000e01, 8.8000e01, 8.9000e01], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], ], [ - [9.0000e01, 9.1000e01, 9.2000e01], - [9.3000e01, 9.4000e01, 9.5000e01], - [9.6000e01, 9.7000e01, 9.8000e01], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ], + ] + ), + ] + ) + TESTS.append( + [ + {"device": device, "rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, + {"grid": p(torch.arange(0, 108).reshape((4, 3, 3, 3)))}, + p( + np.array( [ - [9.9000e01, 1.0000e02, 1.0100e02], - [1.0200e02, 1.0300e02, 1.0400e02], - [1.0500e02, 1.0600e02, 1.0700e02], 
- ], - ], - ] - ) - ), - ], -] + [ + [ + [-9.4201e00, -8.1672e00, -6.9143e00], + [-5.6614e00, -4.4085e00, -3.1556e00], + [-1.9027e00, -6.4980e-01, 6.0310e-01], + ], + [ + [1.8560e00, 3.1089e00, 4.3618e00], + [5.6147e00, 6.8676e00, 8.1205e00], + [9.3734e00, 1.0626e01, 1.1879e01], + ], + [ + [1.3132e01, 1.4385e01, 1.5638e01], + [1.6891e01, 1.8144e01, 1.9397e01], + [2.0650e01, 2.1902e01, 2.3155e01], + ], + ], + [ + [ + [9.9383e-02, -4.8845e-01, -1.0763e00], + [-1.6641e00, -2.2519e00, -2.8398e00], + [-3.4276e00, -4.0154e00, -4.6032e00], + ], + [ + [-5.1911e00, -5.7789e00, -6.3667e00], + [-6.9546e00, -7.5424e00, -8.1302e00], + [-8.7180e00, -9.3059e00, -9.8937e00], + ], + [ + [-1.0482e01, -1.1069e01, -1.1657e01], + [-1.2245e01, -1.2833e01, -1.3421e01], + [-1.4009e01, -1.4596e01, -1.5184e01], + ], + ], + [ + [ + [5.9635e01, 6.1199e01, 6.2764e01], + [6.4328e01, 6.5892e01, 6.7456e01], + [6.9021e01, 7.0585e01, 7.2149e01], + ], + [ + [7.3714e01, 7.5278e01, 7.6842e01], + [7.8407e01, 7.9971e01, 8.1535e01], + [8.3099e01, 8.4664e01, 8.6228e01], + ], + [ + [8.7792e01, 8.9357e01, 9.0921e01], + [9.2485e01, 9.4049e01, 9.5614e01], + [9.7178e01, 9.8742e01, 1.0031e02], + ], + ], + [ + [ + [8.1000e01, 8.2000e01, 8.3000e01], + [8.4000e01, 8.5000e01, 8.6000e01], + [8.7000e01, 8.8000e01, 8.9000e01], + ], + [ + [9.0000e01, 9.1000e01, 9.2000e01], + [9.3000e01, 9.4000e01, 9.5000e01], + [9.6000e01, 9.7000e01, 9.8000e01], + ], + [ + [9.9000e01, 1.0000e02, 1.0100e02], + [1.0200e02, 1.0300e02, 1.0400e02], + [1.0500e02, 1.0600e02, 1.0700e02], + ], + ], + ] + ) + ), + ] + ) class TestRandAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data[device]) + assert_allclose(result, expected_val, rtol=1e-2, atol=1e-2) if __name__ == "__main__": diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 2be94acebd..af23421ecc 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -17,69 +17,146 @@ from monai.transforms import Resample from monai.transforms.utils import create_grid +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((2, 2)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]]), - ], - [ - dict(padding_mode="reflection", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2)), "mode": "nearest"}, - np.array([[[3.0, 2.0, 3.0, 2.0], 
[1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((2, 2))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q(np.array([[[0.0, 1.0], [2.0, 3.0]]])), ] - ] - ), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( - [ + ) + TESTS.append( [ - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], + dict(padding_mode="zeros", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] + ) + ), ] - ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q( + np.array( + [[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="reflection", device=device), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"}, + q( + np.array( + [[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="border", device=device), + { + "grid": p(create_grid((4, 4, 4))), + "img": q(np.arange(8).reshape((1, 2, 2, 2))), + "mode": "bilinear", + }, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], 
+ ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + ] + ] + ) + ), + ] + ) class TestResample(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_resample(self, input_param, input_data, expected_val): g = Resample(**input_param) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "device" in input_data: + self.assertEqual(result.device, input_data["device"]) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": From e13bf30409293cc4de76d4dfbd3955d680cdd994 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 11:25:53 +0100 Subject: [PATCH 02/12] Affine, RandAffine, Affined, RandAffined Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 9 +- monai/transforms/spatial/dictionary.py | 35 +-- tests/test_affine.py | 179 ++++++++---- tests/test_affine_grid.py | 2 +- tests/test_affined.py | 193 ++++++++----- tests/test_inverse_collation.py | 2 - tests/test_rand_affine.py | 222 ++++++++------- tests/test_rand_affine_grid.py | 2 +- tests/test_rand_affined.py | 360 +++++++++++++------------ tests/utils.py | 29 +- 10 files changed, 587 insertions(+), 446 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 00a24e940a..162befe8d0 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1069,7 +1069,7 @@ def __call__( if self.device is not None: grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device) - grid, *_ = convert_data_type(grid, dtype=torch.float32) + grid, *_ = convert_data_type(grid, dtype=float) affine, *_ = convert_to_dst_type(affine, grid) grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) @@ -1329,6 +1329,8 @@ class Affine(Transform): """ + backend = list(set(AffineGrid.backend) & set(Resample.backend)) + def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, @@ -1425,6 +1427,8 @@ class RandAffine(RandomizableTransform): """ + backend = Affine.backend + def __init__( self, prob: float = 0.1, @@ -1436,7 +1440,6 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1480,8 +1483,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. 
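        Example (a sketch drawn from the updated tests; with ``prob=1.0`` the
        random affine always fires, and the exact output values depend on the
        random seed)::

            import numpy as np
            import torch
            from monai.transforms import RandAffine

            rand_affine = RandAffine(prob=1.0, rotate_range=(np.pi / 2,), padding_mode="zeros")
            out = rand_affine(torch.ones(1, 3, 3, 3), spatial_size=(2, 2, 2))  # (1, 2, 2, 2)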
See also: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index c0375f10b6..a8041697c7 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -574,6 +574,8 @@ class Affined(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. """ + backend = Affine.backend + def __init__( self, keys: KeysCollection, @@ -584,7 +586,6 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -621,8 +622,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -642,9 +641,7 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] @@ -661,7 +658,7 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -677,10 +674,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) @@ -693,6 +687,8 @@ class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ + backend = Affine.backend + def __init__( self, keys: KeysCollection, @@ -705,7 +701,6 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - as_tensor_output: bool = True, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -753,8 +748,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. 
allow_missing_keys: don't raise exception if key is missing. @@ -772,7 +765,6 @@ def __init__( scale_range=scale_range, spatial_size=spatial_size, cache_grid=cache_grid, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -789,9 +781,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.rand_affine.randomize() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize() @@ -829,14 +819,14 @@ def __call__( return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - out: Union[np.ndarray, torch.Tensor] = d[key] + out: NdarrayOrTensor = d[key] else: orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform @@ -849,10 +839,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) - - # Convert to numpy - d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) # Remove the applied transform self.pop_transform(d, key) diff --git a/tests/test_affine.py b/tests/test_affine.py index dd82d72e23..bd89f1a436 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -16,78 +16,139 @@ from parameterized import parameterized from monai.transforms import Affine +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (-1, 0, 0)}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for 
device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device, image_only=True), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (-1, 0, 0)}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class 
TestAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) result = g(**input_data) if isinstance(result, tuple): result = result[0] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 18a5f89075..ac7c2741b0 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -115,7 +115,7 @@ def test_affine_grid(self, input_param, input_data, expected_val): result, _ = g(**input_data) if "device" in input_data: self.assertEqual(result.device, input_data[device]) - assert_allclose(result, expected_val, rtol=1e-3, atol=1e-3) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affined.py b/tests/test_affined.py index 850f12905d..142cedc8d9 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -16,85 +16,142 @@ from parameterized import parameterized from monai.transforms import Affined +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, spatial_size=(-1, 0), device=None), - {"img": np.arange(9).reshape((1, 3, 3))}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device), + {"img": p(np.arange(9).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( + ) + TESTS.append( [ - [ - 
[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4), + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3)))}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4, 4), + device=device, + ), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affined(**input_param) result = g(input_data)["img"] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index c5dd9f1210..fb6a3a1e80 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -61,7 +61,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] @@ -85,7 +84,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] diff 
--git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 1e1a23bc09..c88aa538ed 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -16,114 +16,132 @@ from parameterized import parameterized from monai.transforms import RandAffine +from monai.utils.type_conversion import convert_data_type +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=-1), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3)), "spatial_size": (2, 2)}, - np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]]), - ], - [ - dict(as_tensor_output=True, device=None), - {"img": torch.ones((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), cache_grid=True), - {"img": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - cache_grid=True, - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "spatial_size": (3, 3)}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - spatial_size=(3, 3), - cache_grid=True, - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=-1), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]])), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.ones((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), 
cache_grid=True), + {"img": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + cache_grid=True, + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) -ARR_NUMPY = np.arange(9 * 10).reshape(1, 9, 10) -ARR_TORCH = torch.Tensor(ARR_NUMPY) TEST_CASES_SKIPPED_CONSISTENCY = [] -for im in (ARR_NUMPY, ARR_TORCH): - for as_tensor_output in (True, False): - for in_dtype_is_int in (True, False): - TEST_CASES_SKIPPED_CONSISTENCY.append((im, as_tensor_output, in_dtype_is_int)) +for p in TEST_NDARRAYS: + for in_dtype in (np.int32, np.float32): + TEST_CASES_SKIPPED_CONSISTENCY.append((p(np.arange(9 * 10).reshape(1, 9, 10)), in_dtype)) class TestRandAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): @@ -132,15 +150,11 @@ def test_ill_cache(self): RandAffine(cache_grid=True, spatial_size=(1, 1, -1)) @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY) - def test_skipped_transform_consistency(self, im, as_tensor_output, in_dtype_is_int): - t1 = RandAffine(prob=0, as_tensor_output=as_tensor_output) - t2 = RandAffine(prob=1, spatial_size=(10, 11), as_tensor_output=as_tensor_output) + def test_skipped_transform_consistency(self, im, in_dtype): + t1 = RandAffine(prob=0) + t2 = RandAffine(prob=1, spatial_size=(10, 11)) - # change dtype to int32 or float32 - if in_dtype_is_int: - im = im.astype("int32") if isinstance(im, np.ndarray) else im.int() - else: - im = im.astype("float32") if isinstance(im, 
np.ndarray) else im.float() + im, *_ = convert_data_type(im, dtype=in_dtype) out1 = t1(im) out2 = t2(im) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 0670b4d042..64c32c8d75 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -201,7 +201,7 @@ def test_rand_affine_grid(self, input_param, input_data, expected_val): result = g(**input_data) if "device" in input_data: self.assertEqual(result.device, input_data[device]) - assert_allclose(result, expected_val, rtol=1e-2, atol=1e-2) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index d2f8a60665..bec9602d62 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -17,179 +17,188 @@ from monai.transforms import RandAffined from monai.utils import GridSampleMode +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None, spatial_size=None, keys=("img", "seg")), - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), keys=("img", "seg")), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=False, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - cache_grid=True, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - np.array([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - spatial_size=(3, 3), - keys=("img", "seg"), - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - mode=("bilinear", "nearest"), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), 
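
For readers without tests/utils.py open: the `for p in TEST_NDARRAYS` loops above run every case once per input container type. The exact definition lives in tests/utils.py; a sketch of the assumed shape of that tuple, for illustration only:

    from typing import Callable, Tuple

    import numpy as np
    import torch

    TEST_NDARRAYS: Tuple[Callable, ...] = (np.array, torch.as_tensor)
    if torch.cuda.is_available():
        TEST_NDARRAYS = TEST_NDARRAYS + (lambda x: torch.as_tensor(x, device="cuda"),)

Each case wraps its inputs and expected values with `p`, so one logical test now covers numpy arrays, CPU tensors and, where available, GPU tensors, replacing the old `as_tensor_output` switch.
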
- "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode=GridSampleMode.BILINEAR, - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - cache_grid=True, - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device, spatial_size=None, keys=("img", "seg")), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), keys=("img", "seg")), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode="bilinear", + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 
21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=("bilinear", "nearest"), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode=GridSampleMode.BILINEAR, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) class TestRandAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affined(self, input_param, input_data, expected_val): g = RandAffined(**input_param).set_random_state(123) res = g(input_data) @@ -200,23 +209,16 @@ def test_rand_affined(self, input_param, input_data, expected_val): if "_transforms" in key: continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): # spatial size is None - RandAffined( - as_tensor_output=False, device=None, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg") - ) + RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg")) with self.assertWarns(UserWarning): # spatial size is dynamic 
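
Both warnings in this test assert the same rule: `cache_grid=True` precomputes the identity sampling grid once, which is only possible when `spatial_size` is fully specified up front; `None` or any `-1` entry makes the output shape depend on each input, so the cache must be skipped. A sketch of the assumed guard, not the verbatim MONAI code (`maybe_cache_grid` is a hypothetical helper):

    import warnings
    from typing import Optional, Sequence, Union

    from monai.transforms.utils import create_grid
    from monai.utils import ensure_tuple

    def maybe_cache_grid(cache_grid: bool, spatial_size: Optional[Union[Sequence[int], int]]):
        """Return a precomputed grid only when the output shape is static (illustrative)."""
        if not cache_grid:
            return None
        if spatial_size is None or any(s < 0 for s in ensure_tuple(spatial_size)):
            warnings.warn("cache_grid=True is not compatible with a dynamic spatial_size, skipping the cache.")
            return None
        return create_grid(spatial_size=ensure_tuple(spatial_size))
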
RandAffined( - as_tensor_output=False, - device=None, + device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, diff --git a/tests/utils.py b/tests/utils.py index 1375cd2d72..e8bf727e80 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -40,6 +40,7 @@ from monai.utils.module import version_leq nib, _ = optional_import("nibabel") +plt, has_matplotlib = optional_import("matplotlib.pyplot") quick_test_var = "QUICKTEST" @@ -57,17 +58,37 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs): +def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, visualize=True, *args, **kwargs): """ Assert that all values of two data objects are close. Args: a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against + visualize: if the assert fails, display the two images (take central slices to reduce to 2D if necessary) + args: extra arguments to pass on to `np.testing.assert_allclose` + kwargs: extra arguments to pass on to `np.testing.assert_allclose` """ - a = a.cpu() if isinstance(a, torch.Tensor) else a - b = b.cpu() if isinstance(b, torch.Tensor) else b - np.testing.assert_allclose(a, b, *args, **kwargs) + a = a.cpu().numpy() if isinstance(a, torch.Tensor) else a + b = b.cpu().numpy() if isinstance(b, torch.Tensor) else b + try: + np.testing.assert_allclose(a, b, *args, **kwargs) + except AssertionError as e: + if visualize and has_matplotlib and a.shape == b.shape and a.ndim > 2 and b.ndim > 2: + while a.ndim > 2: + _slice = a.shape[-1] // 2 + a, b = a[..., _slice], b[..., _slice] + diff = a - b + ims_to_show = [a, b, diff] + titles = ["a", "b", "a - b"] + fig, axes = plt.subplots(1, 3) + for im, title, ax in zip(ims_to_show, titles, axes): + im_show = ax.imshow(im) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + raise e def test_pretrained_networks(network, input_param, device): From ee223baf1b4900676d89d79d3be9c7ec03097c98 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 11:53:32 +0100 Subject: [PATCH 03/12] deprecate as_tensor_output Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 22 +++++++++++++++++----- monai/transforms/spatial/dictionary.py | 5 +++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 162befe8d0..d2ac34e09f 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -46,6 +46,7 @@ issequenceiterable, optional_import, ) +from monai.utils.deprecated import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, convert_to_dst_type @@ -1016,12 +1017,14 @@ class AffineGrid(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, + as_tensor_output: bool = True, device: Optional[torch.device] = None, affine: Optional[NdarrayOrTensor] = None, ) -> None: @@ 
-1082,12 +1085,14 @@ class RandAffineGrid(Randomizable, Transform): """ + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_range: RandRange = None, shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, + as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1173,8 +1178,9 @@ def __call__( scale_params=self.scale_params, device=self.device, ) - grid, self.affine = affine_grid(spatial_size, grid) - return grid + _grid: NdarrayOrTensor + _grid, self.affine = affine_grid(spatial_size, grid) + return _grid def get_transformation_matrix(self) -> Optional[NdarrayOrTensor]: """Get the most recently applied transformation matrix""" @@ -1235,10 +1241,12 @@ class Resample(Transform): backend = [TransformBackends.TORCH] + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1331,6 +1339,7 @@ class Affine(Transform): backend = list(set(AffineGrid.backend) & set(Resample.backend)) + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, rotate_params: Optional[Union[Sequence[float], float]] = None, @@ -1340,6 +1349,7 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, + as_tensor_output: bool = True, device: Optional[torch.device] = None, image_only: bool = False, ) -> None: @@ -1397,7 +1407,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): + ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1429,6 +1439,7 @@ class RandAffine(RandomizableTransform): backend = Affine.backend + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, prob: float = 0.1, @@ -1440,6 +1451,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, cache_grid: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1555,11 +1567,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> NdarrayOrTensor: """ Args: img: shape must be (num_channels, H, W[, D]), diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a8041697c7..20e0320b79 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -55,6 +55,7 @@ ensure_tuple_rep, fall_back_tuple, ) +from monai.utils.deprecated import deprecated_arg from monai.utils.enums import InverseKeys from monai.utils.module import optional_import from monai.utils.type_conversion import convert_data_type @@ -576,6 +577,7 @@ class Affined(MapTransform, InvertibleTransform): backend = Affine.backend + 
@deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -586,6 +588,7 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, + as_tensor_output: bool = True, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -689,6 +692,7 @@ class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): backend = Affine.backend + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -701,6 +705,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, cache_grid: bool = False, + as_tensor_output: bool = True, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: From 2117a1d7176b9fab079d5241bf7a099a72a4780a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 14:14:21 +0100 Subject: [PATCH 04/12] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d2ac34e09f..0c177cadf3 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1407,7 +1407,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, NdarrayOrTensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1571,7 +1571,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> NdarrayOrTensor: + ) -> torch.Tensor: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1597,9 +1597,10 @@ def __call__( grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid) - return self.resampler( + out: torch.Tensor = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + return out class Rand2DElastic(RandomizableTransform): @@ -1703,11 +1704,11 @@ def randomize(self, spatial_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> torch.Tensor: """ Args: img: shape must be (num_channels, H, W), @@ -1736,7 +1737,10 @@ def __call__( grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) else: grid = create_grid(spatial_size=sp_size) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + out: torch.Tensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out class Rand3DElastic(RandomizableTransform): @@ -1878,7 +1882,10 @@ def __call__( offset = 
torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid = self.rand_affine_grid(grid=grid) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + out: torch.Tensor = self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ) + return out class AddCoordinateChannels(Transform): From 1615eb64fc93eb802c48a9e49c0a90399cddd51b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 14:23:20 +0100 Subject: [PATCH 05/12] remove plotting Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/utils.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index e8bf727e80..929d762903 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -40,7 +40,6 @@ from monai.utils.module import version_leq nib, _ = optional_import("nibabel") -plt, has_matplotlib = optional_import("matplotlib.pyplot") quick_test_var = "QUICKTEST" @@ -58,37 +57,19 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, visualize=True, *args, **kwargs): +def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs): """ Assert that all values of two data objects are close. Args: a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against - visualize: if the assert fails, display the two images (take central slices to reduce to 2D if necessary) args: extra arguments to pass on to `np.testing.assert_allclose` kwargs: extra arguments to pass on to `np.testing.assert_allclose` """ a = a.cpu().numpy() if isinstance(a, torch.Tensor) else a b = b.cpu().numpy() if isinstance(b, torch.Tensor) else b - try: - np.testing.assert_allclose(a, b, *args, **kwargs) - except AssertionError as e: - if visualize and has_matplotlib and a.shape == b.shape and a.ndim > 2 and b.ndim > 2: - while a.ndim > 2: - _slice = a.shape[-1] // 2 - a, b = a[..., _slice], b[..., _slice] - diff = a - b - ims_to_show = [a, b, diff] - titles = ["a", "b", "a - b"] - fig, axes = plt.subplots(1, 3) - for im, title, ax in zip(ims_to_show, titles, axes): - im_show = ax.imshow(im) - ax.set_title(title, fontsize=25) - ax.axis("off") - fig.colorbar(im_show, ax=ax) - plt.show() - raise e + np.testing.assert_allclose(a, b, *args, **kwargs) def test_pretrained_networks(network, input_param, device): From be99a7e6332e356a37e05ae3266b6ad3c08e2e52 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 16:15:51 +0100 Subject: [PATCH 06/12] Rand2dElastic, Rand3dElastic, Rand2dElasticd, Rand3dElasticd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 14 +- monai/transforms/spatial/dictionary.py | 20 +- tests/test_rand_elastic_2d.py | 151 ++++++++------- tests/test_rand_elastic_3d.py | 130 +++++++------ tests/test_rand_elasticd_2d.py | 248 +++++++++++++------------ tests/test_rand_elasticd_3d.py | 212 +++++++++++---------- 6 files changed, 423 insertions(+), 352 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0c177cadf3..f84f38e671 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1610,6 +1610,9 
@@ class Rand2DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, spacing: Union[Tuple[float, float], float], @@ -1664,8 +1667,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1750,6 +1751,9 @@ class Rand3DElastic(RandomizableTransform): """ + backend = Resample.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, sigma_range: Tuple[float, float], @@ -1807,8 +1811,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1853,11 +1855,11 @@ def randomize(self, grid_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: NdarrayOrTensor, spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> torch.Tensor: """ Args: img: shape must be (num_channels, H, W, D), diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 20e0320b79..7648b17925 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -857,6 +857,9 @@ class Rand2DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ + backend = Rand2DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -917,8 +920,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. 
@@ -937,7 +938,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -954,9 +954,7 @@ def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) self.rand_2d_elastic.randomize(spatial_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) @@ -986,6 +984,9 @@ class Rand3DElasticd(RandomizableTransform, MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ + backend = Rand3DElastic.backend + + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -1048,8 +1049,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -1068,7 +1067,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -1085,9 +1083,7 @@ def randomize(self, grid_size: Sequence[int]) -> None: super().randomize(None) self.rand_3d_elastic.randomize(grid_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index fbfb7d5761..c414eb1ffd 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -16,90 +16,101 @@ from parameterized import parameterized from monai.transforms import Rand2DElastic +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2)}, - np.ones((3, 2, 2)), - ], - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "padding_mode": "zeros", - }, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2), "mode": "bilinear"}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + {"spacing": (0.3, 0.3), 
"magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), ] - ), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "padding_mode": "zeros", + }, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2), "mode": "bilinear"}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], -] + ) + TESTS.append( + [ + { + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) class TestRand2DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index c63282d571..d44324746f 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -16,69 +16,89 @@ from parameterized import parameterized from monai.transforms import Rand3DElastic +from tests.utils import 
TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(72).reshape((2, 3, 3, 4))}, - np.arange(72).reshape((2, 3, 3, 4)), - ], - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.ones((2, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(72).reshape((2, 3, 3, 4)))}, + p(np.arange(72).reshape((2, 3, 3, 4))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + }, + {"img": p(torch.ones((2, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "mode": "bilinear"}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) class TestRand3DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..84f18120e1 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -16,127 +16,147 @@ from 
parameterized import parameterized from monai.transforms import Rand2DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.3, 0.3), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(4).reshape((1, 2, 2)), "seg": torch.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape((1, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "padding_mode": "zeros", - "device": None, - "spatial_size": (2, 2), - "mode": "bilinear", - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.3, 0.3), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(4).reshape((1, 2, 2))), "seg": p(torch.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape((1, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "padding_mode": "zeros", + "device": device, + "spatial_size": (2, 2), + "mode": "bilinear", + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 
0.9, - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - { - "img": torch.tensor( - [ - [[1.3584, 1.9251], [5.6266, 6.6427]], - [[10.3584, 10.9251], [14.6266, 15.6427]], - [[19.3584, 19.9251], [23.6266, 24.6427]], - ] - ), - "seg": torch.tensor([[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]]), - }, - ], -] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + { + "img": p( + torch.tensor( + [ + [[1.3584, 1.9251], [5.6266, 6.6427]], + [[10.3584, 10.9251], [14.6266, 15.6427]], + [[19.3584, 19.9251], [23.6266, 24.6427]], + ] + ) + ), + "seg": p( + torch.tensor( + [[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]] + ) + ), + }, + ] + ) class TestRand2DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g = Rand2DElasticd(**input_param) g.set_random_state(123) @@ -144,11 +164,7 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..5f8a5f47ed 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -16,98 +16,128 @@ from parameterized import parameterized from monai.transforms import Rand3DElasticd +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, 
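
A reading aid for these elastic cases: `sigma_range` controls the Gaussian smoothing of a voxel-wise random offset field, `magnitude_range` scales its amplitude, and the `prob=0.0` rows exercise only the identity/resize path. Condensing the array-transform hunk earlier in this series into a runnable approximation (the glue variables `sp_size`, `sigma` and `magnitude` are made up for illustration):

    import torch
    from monai.networks.layers import GaussianFilter
    from monai.transforms import Resample
    from monai.transforms.utils import create_grid

    img = torch.arange(27.0).reshape((1, 3, 3, 3))
    sp_size, sigma, magnitude = (2, 2, 2), 1.0, 0.3

    grid = torch.as_tensor(create_grid(spatial_size=sp_size))    # identity grid, shape (4, 2, 2, 2)
    offset = torch.rand(1, 3, *sp_size) * 2 - 1                  # voxel-wise offsets in [-1, 1)
    grid[:3] += GaussianFilter(3, sigma)(offset)[0] * magnitude  # smooth, then scale
    out = Resample()(img=img, grid=grid, mode="bilinear", padding_mode="border")
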
- "spatial_size": (2, 2, 2), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, -1, -1), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 3, 3)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(8).reshape((1, 2, 2, 2)), "seg": torch.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - "mode": "bilinear", - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": True, - "device": torch.device("cpu:0"), - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - { - "img": torch.tensor([[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]), - "seg": torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, -1, -1), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 3, 3))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(8).reshape((1, 2, 2, 2))), "seg": p(torch.arange(8).reshape((1, 2, 2, 2)))}, + p(np.arange(8).reshape((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + 
[[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + "mode": "bilinear", + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + { + "img": p( + torch.tensor( + [[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]] + ) + ), + "seg": p(torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]])), + }, + ] + ) class TestRand3DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g = Rand3DElasticd(**input_param) g.set_random_state(123) @@ -115,11 +145,7 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": From 135c3aed56687868d7c27f47c0a9be16206e4aa9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 17:05:03 +0100 Subject: [PATCH 07/12] CropForeground Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 2 ++ monai/transforms/croppad/dictionary.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 7e3bc835dd..4d2a62b390 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -754,6 +754,8 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. 
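
The line added just below is the pattern the rest of this series converges on: numpy-backed transforms coerce their input once, on entry, so callers may pass either numpy arrays or torch tensors. The helper's contract, as assumed from its use throughout this patch (see monai.utils.type_conversion for the real signature):

    import numpy as np
    import torch

    from monai.utils.type_conversion import convert_data_type

    img = torch.arange(4.0).reshape((1, 2, 2))  # could equally be an np.ndarray
    img_np, orig_type, orig_device = convert_data_type(img, np.ndarray)
    assert isinstance(img_np, np.ndarray) and orig_type is torch.Tensor
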
""" + img, *_ = convert_data_type(img, np.ndarray) # type: ignore + box_start, box_end = self.compute_bounding_box(img) cropped = self.crop_pad(img, box_start, box_end, mode) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 5c846b8d04..233f1b6edf 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -51,6 +51,7 @@ from monai.utils import ImageMetaKey as Key from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple from monai.utils.enums import InverseKeys +from monai.utils.type_conversion import convert_data_type __all__ = [ "PadModeSequence", @@ -848,7 +849,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) + img: np.ndarray + img, *_ = convert_data_type(d[self.source_key], np.ndarray) # type: ignore + box_start, box_end = self.cropper.compute_bounding_box(img=img) d[self.start_coord_key] = box_start d[self.end_coord_key] = box_end for key, m in self.key_iterator(d, self.mode): From 449c038c538c860aa5f2ec021be5e5b255d0b35b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 13 Sep 2021 18:29:35 +0100 Subject: [PATCH 08/12] fix Invertd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7648b17925..64e7b9b6c1 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -794,8 +794,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) - # to be consistent with the self._do_transform case (dtype and device) - affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) + affine = np.eye(len(sp_size) + 1) grid = None if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) @@ -816,7 +815,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N # do the transform if do_resampling: d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) - # if not doing transform and and spatial size is unchanged, only need to do numpy/torch conversion + + # if not doing transform and spatial size is unchanged, only need to do convert to torch else: d[key], *_ = convert_data_type( d[key], torch.Tensor, dtype=torch.float32, device=self.rand_affine.resampler.device From 7481a2e14830c5ceaa8b64bccf408b2eaacfb80f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 14 Sep 2021 14:57:40 +0100 Subject: [PATCH 09/12] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7671b55724..5109ff986c 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -794,7 +794,9 @@ def __call__(self, data: Mapping[Hashable, 
NdarrayOrTensor]) -> Dict[Hashable, N # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) - affine = np.eye(len(sp_size) + 1) + affine = torch.as_tensor( + np.eye(len(sp_size) + 1), dtype=torch.float64, device=self.rand_affine.resampler.device + ) grid = None if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) From 4c595539ff559601b276fcbc3115c1fbb5e7a6ce Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 14 Sep 2021 15:26:34 +0100 Subject: [PATCH 10/12] transforms convert to necessary type Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 18 ++++++++++++++++++ monai/transforms/intensity/array.py | 16 +++++++++++++++- monai/transforms/spatial/array.py | 9 +++++++-- monai/transforms/utility/array.py | 15 ++++++++++++++- 4 files changed, 54 insertions(+), 4 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 4d2a62b390..d3cec35d93 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -421,6 +421,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) sd = min(len(self.slices), len(img.shape[1:])) # spatial dims slices = [slice(None)] + self.slices[:sd] return img[tuple(slices)] @@ -449,6 +450,7 @@ def __call__(self, img: np.ndarray): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore roi_size = fall_back_tuple(self.roi_size, img.shape[1:]) center = [i // 2 for i in img.shape[1:]] cropper = SpatialCrop(roi_center=center, roi_size=roi_size) @@ -469,6 +471,7 @@ def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -530,6 +533,7 @@ def __call__(self, img: np.ndarray): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize(img.shape[1:]) if self._size is None: raise AssertionError @@ -576,6 +580,7 @@ def __call__(self, img: np.ndarray): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_size = img.shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -645,6 +650,7 @@ def __call__(self, img: np.ndarray) -> List[np.ndarray]: Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. 
""" + img, *_ = convert_data_type(img, np.ndarray) # type: ignore return [self.cropper(img) for _ in range(self.num_samples)] @@ -801,12 +807,16 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> Returns: A list of image patches """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if weight_map is None: weight_map = self.weight_map if weight_map is None: raise ValueError("weight map must be provided for weighted patch sampling.") if img.shape[1:] != weight_map.shape[1:]: raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") + + weight_map, *_ = convert_data_type(weight_map, np.ndarray) # type: ignore + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) results = [] @@ -942,6 +952,9 @@ def __call__( if image is None: image = self.image + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + self.randomize(label, fg_indices, bg_indices, image) results: List[np.ndarray] = [] if self.centers is not None: @@ -1075,6 +1088,9 @@ def __call__( if image is None: image = self.image + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + self.randomize(label, indices, image) results: List[np.ndarray] = [] if self.centers is not None: @@ -1127,6 +1143,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N If None, defaults to the ``mode`` in construction. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore return self.padder(self.cropper(img), mode=mode) @@ -1161,6 +1178,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore bbox = [] for channel in range(img.shape[0]): diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6c45f0d52b..a1423c8ee5 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -531,6 +531,7 @@ def __call__(self, img: np.ndarray): """ Apply the transform to `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize(data=img) if not self._do_transform: return img @@ -731,6 +732,7 @@ def __call__(self, img: np.ndarray): """ Apply the transform to `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore epsilon = 1e-7 img_min = img.min() img_range = img.max() - img_min @@ -773,6 +775,7 @@ def __call__(self, img: np.ndarray): """ Apply the transform to `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if self.gamma_value is None: raise ValueError("gamma_value is not set.") @@ -910,10 +913,13 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n - ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel. 
""" + img, *_ = convert_data_type(img, np.ndarray) # type: ignore mask_data = self.mask_data if mask_data is None else mask_data if mask_data is None: raise ValueError("must provide the mask_data when initializing the transform or at runtime.") + mask_data, *_ = convert_data_type(mask_data, np.ndarray) # type: ignore + mask_data = np.asarray(self.select_fn(mask_data)) if mask_data.shape[0] != 1 and mask_data.shape[0] != img.shape[0]: raise ValueError( @@ -936,7 +942,7 @@ class SavitzkyGolaySmooth(Transform): or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. """ - backend = [TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): @@ -1000,6 +1006,7 @@ def __call__(self, img: np.ndarray): np.ndarray containing envelope of data in img along the specified axis. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform @@ -1026,6 +1033,7 @@ def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "er self.approx = approx def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx) input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) return gaussian_filter(input_data).squeeze(0).detach().numpy() @@ -1070,6 +1078,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if not self._do_transform: return img @@ -1117,6 +1126,7 @@ def __init__( self.approx = approx def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx) gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx) input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) @@ -1183,6 +1193,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1]) def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if not self._do_transform: return img @@ -1227,6 +1238,7 @@ def randomize(self, data: Optional[Any] = None) -> None: ) def __call__(self, img: np.ndarray) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if not self._do_transform: return img @@ -1713,6 +1725,7 @@ def _transform_holes(self, img: np.ndarray) -> np.ndarray: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") def __call__(self, img: np.ndarray): + img, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize(img.shape[1:]) if self._do_transform: img = self._transform_holes(img=img) @@ -1871,6 +1884,7 @@ def __init__( self.dtype = dtype def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore return equalize_hist( img=img, mask=mask if mask is not None else self.mask, diff --git 
a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a463c893d7..8083ac071c 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -171,6 +171,7 @@ def __call__( data_array (resampled into `self.pixdim`), original affine, current affine. """ + data_array, *_ = convert_data_type(data_array, np.ndarray) # type: ignore _dtype = dtype or self.dtype or data_array.dtype sr = data_array.ndim - 1 if sr <= 0: @@ -275,6 +276,7 @@ def __call__( data_array (reoriented in `self.axcodes`), original axcodes, current axcodes. """ + data_array, *_ = convert_data_type(data_array, np.ndarray) # type: ignore sr = data_array.ndim - 1 if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") @@ -392,6 +394,7 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if self.size_mode == "all": input_ndim = img.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) @@ -642,8 +645,7 @@ def __call__( zoomed = padder(zoomed) zoomed = zoomed[tuple(slice_vec)] - out, *_ = convert_to_dst_type(zoomed, dst=img) - return out + return zoomed class Rotate90(Transform): @@ -1098,6 +1100,8 @@ class RandAffineGrid(Randomizable, Transform): """ + backend = AffineGrid.backend + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, @@ -1930,6 +1934,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]): Args: img: data to be transformed, assuming `img` is channel first. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if max(self.spatial_channels) > img.ndim - 1: raise ValueError( f"input has {img.ndim-1} spatial dimensions, cannot add AddCoordinateChannels channel for " diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 7f94b50044..e53f0e1fe3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -777,6 +777,9 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. """ + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore if output_shape is None: output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) @@ -826,6 +829,10 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. 
""" + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) # type: ignore + if output_shape is None: output_shape = self.output_shape indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) @@ -846,6 +853,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ def __call__(self, img: np.ndarray) -> np.ndarray: + img, *_ = convert_data_type(img, np.ndarray) # type: ignore # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: img = np.squeeze(img, axis=0) @@ -912,6 +920,9 @@ def __call__( if label.shape[0] != 1: raise ValueError("Only supports single channel labels!") + img, *_ = convert_data_type(img, np.ndarray) # type: ignore + label, *_ = convert_data_type(label, np.ndarray) # type: ignore + # Generate extreme points self.randomize(label[0, :]) @@ -948,6 +959,7 @@ def __call__(self, img: torch.Tensor): img: PyTorch Tensor data for the TorchVision transform. """ + img, *_ = convert_data_type(img, torch.Tensor) # type: ignore return self.trans(img) @@ -978,7 +990,7 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.dtype = dtype def __call__(self, img: np.ndarray): - img = np.asarray(img) + img, *_ = convert_data_type(img, np.ndarray) # type: ignore img_flat = img.flatten() try: out_flat = np.copy(img_flat).astype(self.dtype) @@ -1034,6 +1046,7 @@ def __call__( mask must have the same shape as input `img`. """ + img, *_ = convert_data_type(img, np.ndarray) # type: ignore if meta_data is None: meta_data = {} From 86ef6b0a712f140b2cf55299ec68a09a24b9f65b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 14 Sep 2021 15:57:46 +0100 Subject: [PATCH 11/12] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a463c893d7..2053070028 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1083,7 +1083,7 @@ def __call__( else: affine = self.affine - if self.device is not None: + if self.device not in (None, torch.device("cpu"), "cpu"): grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device) grid, *_ = convert_data_type(grid, dtype=float) affine, *_ = convert_to_dst_type(affine, grid) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 5109ff986c..801f4316de 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -790,13 +790,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d = dict(data) self.randomize() + device = self.rand_affine.resampler.device + sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) - affine = torch.as_tensor( - np.eye(len(sp_size) + 1), dtype=torch.float64, device=self.rand_affine.resampler.device - ) + affine: NdarrayOrTensor = np.eye(len(sp_size) + 1, dtype=np.float64) + if device not in (None, torch.device("cpu"), "cpu"): + affine, *_ = convert_data_type(affine, torch.Tensor, device=device) grid = None if do_resampling: # need to prepare 
grid
            grid = self.rand_affine.get_identity_grid(sp_size)
@@ -820,9 +822,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
             # if not doing transform and spatial size is unchanged, only need to convert to torch
             else:
-                d[key], *_ = convert_data_type(
-                    d[key], torch.Tensor, dtype=torch.float32, device=self.rand_affine.resampler.device
-                )
+                d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float32, device=device)
         return d

From 0c35d67a41ef6acf5e15025e263a233e3758b47b Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Tue, 14 Sep 2021 18:58:02 +0100
Subject: [PATCH 12/12] fixes RandZoom

Signed-off-by: Wenqi Li
---
 monai/transforms/spatial/array.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index ffd66e4bf2..c0bb686149 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -645,7 +645,8 @@ def __call__(
             zoomed = padder(zoomed)
         zoomed = zoomed[tuple(slice_vec)]

-        return zoomed
+        out, *_ = convert_to_dst_type(zoomed, dst=img)
+        return out


 class Rotate90(Transform):
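
Taken together, the series converges on one pattern: each transform coerces its
input to the backend it actually computes with via convert_data_type, then returns
output matching the caller's container type via convert_to_dst_type (exactly what
the RandZoom fix above restores). Below is a minimal sketch of that round trip;
the two MONAI helpers are real and used as in the hunks above, but halve_spatial
is a hypothetical stand-in transform, for illustration only:

    import numpy as np
    import torch

    from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

    def halve_spatial(img):
        """Hypothetical numpy-backed transform: keep every second voxel."""
        # coerce to the numpy backend, whatever container the caller passed
        img_np, *_ = convert_data_type(img, np.ndarray)
        out = img_np[:, ::2]  # channel-first input: subsample the first spatial dim
        # convert back so tensor in -> tensor out, array in -> array out
        out, *_ = convert_to_dst_type(out, dst=img)
        return out

    t = torch.arange(8.0).reshape(1, 8)
    assert isinstance(halve_spatial(t), torch.Tensor)
    assert isinstance(halve_spatial(t.numpy()), np.ndarray)

The same intent shows up in patch 11, which gates the torch conversion on a
non-CPU device so plain CPU/numpy inputs avoid an unnecessary tensor round trip.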