diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 56b289e394..60e23f6750 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -4,9 +4,7 @@ from torch import nn from torch.nn import functional as F -from monai.config import USE_COMPILED -from monai.networks.layers import grid_pull -from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils import GridSamplePadMode class Warp(nn.Module): @@ -17,30 +15,38 @@ class Warp(nn.Module): def __init__( self, spatial_dims: int, - mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, + mode: int = 1, padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS, ): """ Args: spatial_dims: {2, 3}. number of spatial dimensions - mode: {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + mode: interpolation mode to calculate output values, defaults to 1. + Possible values are:: + + - 0 or 'nearest' or InterpolationType.nearest + - 1 or 'linear' or InterpolationType.linear + - 2 or 'quadratic' or InterpolationType.quadratic + - 3 or 'cubic' or InterpolationType.cubic + - 4 or 'fourth' or InterpolationType.fourth + - etc. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ super(Warp, self).__init__() if spatial_dims not in [2, 3]: - raise ValueError(f"got unsupported spatial_dims = {spatial_dims}, only support 2-d and 3-d input") + raise ValueError(f"got unsupported spatial_dims={spatial_dims}, only support 2-d and 3-d input") self.spatial_dims = spatial_dims - self.mode: GridSampleMode = GridSampleMode(mode) + if mode < 0: + raise ValueError(f"do not support negative mode, got mode={mode}") + self.mode = mode self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) @staticmethod def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor: mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]] - grid = torch.stack(torch.meshgrid(*mesh_points[::-1]), dim=0) # (spatial_dims, ...) + grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...) grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) grid = grid.to(ddf) return grid @@ -77,27 +83,31 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor: grid = self.get_reference_grid(ddf) + ddf grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims) - if USE_COMPILED: - _padding_mode = self.padding_mode.value - if _padding_mode == "zeros": - bound = 7 - elif _padding_mode == "border": - bound = 0 - else: - bound = 1 - _interp_mode = self.mode.value - warped_image: torch.Tensor = grid_pull( - image, - grid, - bound=bound, - extrapolate=True, - interpolation=1 if _interp_mode == "bilinear" else _interp_mode, - ) + if self.mode > 1: + raise ValueError(f"{self.mode}-order interpolation not yet implemented.") + # if not USE_COMPILED: + # raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.") + # _padding_mode = self.padding_mode.value + # if _padding_mode == "zeros": + # bound = 7 + # elif _padding_mode == "border": + # bound = 0 + # else: + # bound = 1 + # warped_image: torch.Tensor = grid_pull( + # image, + # grid, + # bound=bound, + # extrapolate=True, + # interpolation=self.mode, + # ) else: grid = self.normalize_grid(grid) index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1)) grid = grid[..., index_ordering] # z, y, x -> x, y, z + _interp_mode = "bilinear" if self.mode == 1 else "nearest" warped_image = F.grid_sample( - image, grid, mode=self.mode.value, padding_mode=self.padding_mode.value, align_corners=True + image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True ) + return warped_image diff --git a/tests/test_warp.py b/tests/test_warp.py index ba8bc9a994..69ae997e38 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -6,53 +6,22 @@ from monai.networks.blocks.warp import Warp -TEST_CASE = [ +LOW_POWER_TEST_CASES = [ [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, + {"spatial_dims": 2, "mode": 0, "padding_mode": "zeros"}, {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, torch.arange(4).reshape((1, 1, 2, 2)), ], [ - {"spatial_dims": 2, "mode": "nearest", "padding_mode": "zeros"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, - torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]) - .unsqueeze(0) - .unsqueeze(0), - ], - [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "border"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, - torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]) - .unsqueeze(0) - .unsqueeze(0), - ], - [ - {"spatial_dims": 2, "mode": "nearest", "padding_mode": "reflection"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)}, - torch.tensor([[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]) - .unsqueeze(0) - .unsqueeze(0), - ], - [ - {"spatial_dims": 3, "mode": "bilinear", "padding_mode": "zeros"}, - {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 2, 2, 2)}, - torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), - ], -] - -TEST_CASES = [ - [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, - {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)}, - torch.arange(4).reshape((1, 1, 2, 2)), - ], - [ - {"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"}, + {"spatial_dims": 2, "mode": 1, "padding_mode": "zeros"}, {"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)}, torch.tensor([[[[3, 0], [0, 0]]]]), ], +] + +HIGH_POWER_TEST_CASES = [ [ - {"spatial_dims": 3, "mode": "nearest", "padding_mode": "border"}, + {"spatial_dims": 3, "mode": 2, "padding_mode": "border"}, { "image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2) * -1, @@ -60,12 +29,16 @@ torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]), ], [ - {"spatial_dims": 3, "mode": "nearest", "padding_mode": "reflection"}, + {"spatial_dims": 3, "mode": 3, "padding_mode": "reflection"}, {"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)}, torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]), ], ] +TEST_CASES = LOW_POWER_TEST_CASES +# if USE_COMPILED: +# TEST_CASES += HIGH_POWER_TEST_CASES + class TestWarp(unittest.TestCase): @parameterized.expand(TEST_CASES)