Skip to content

Commit

Permalink
1452 adjust-warp-layer-grid-pull-user-case (#1470)
Browse files Browse the repository at this point in the history
* 1442 use pull-grid only for above linear interpolation

Signed-off-by: kate-sann5100 <yiwen.li@st-annes.ox.ac.uk>
  • Loading branch information
kate-sann5100 authored Jan 28, 2021
1 parent 8207e1e commit f75b67a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 66 deletions.
64 changes: 37 additions & 27 deletions monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from torch import nn
from torch.nn import functional as F

from monai.config import USE_COMPILED
from monai.networks.layers import grid_pull
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import GridSamplePadMode


class Warp(nn.Module):
Expand All @@ -17,30 +15,38 @@ class Warp(nn.Module):
def __init__(
self,
spatial_dims: int,
mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
mode: int = 1,
padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS,
):
"""
Args:
spatial_dims: {2, 3}. number of spatial dimensions
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
mode: interpolation mode to calculate output values, defaults to 1.
Possible values are::
- 0 or 'nearest' or InterpolationType.nearest
- 1 or 'linear' or InterpolationType.linear
- 2 or 'quadratic' or InterpolationType.quadratic
- 3 or 'cubic' or InterpolationType.cubic
- 4 or 'fourth' or InterpolationType.fourth
- etc.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
"""
super(Warp, self).__init__()
if spatial_dims not in [2, 3]:
raise ValueError(f"got unsupported spatial_dims = {spatial_dims}, only support 2-d and 3-d input")
raise ValueError(f"got unsupported spatial_dims={spatial_dims}, only support 2-d and 3-d input")
self.spatial_dims = spatial_dims
self.mode: GridSampleMode = GridSampleMode(mode)
if mode < 0:
raise ValueError(f"do not support negative mode, got mode={mode}")
self.mode = mode
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)

@staticmethod
def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
grid = torch.stack(torch.meshgrid(*mesh_points[::-1]), dim=0) # (spatial_dims, ...)
grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...)
grid = grid.to(ddf)
return grid
Expand Down Expand Up @@ -77,27 +83,31 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
grid = self.get_reference_grid(ddf) + ddf
grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims)

if USE_COMPILED:
_padding_mode = self.padding_mode.value
if _padding_mode == "zeros":
bound = 7
elif _padding_mode == "border":
bound = 0
else:
bound = 1
_interp_mode = self.mode.value
warped_image: torch.Tensor = grid_pull(
image,
grid,
bound=bound,
extrapolate=True,
interpolation=1 if _interp_mode == "bilinear" else _interp_mode,
)
if self.mode > 1:
raise ValueError(f"{self.mode}-order interpolation not yet implemented.")
# if not USE_COMPILED:
# raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.")
# _padding_mode = self.padding_mode.value
# if _padding_mode == "zeros":
# bound = 7
# elif _padding_mode == "border":
# bound = 0
# else:
# bound = 1
# warped_image: torch.Tensor = grid_pull(
# image,
# grid,
# bound=bound,
# extrapolate=True,
# interpolation=self.mode,
# )
else:
grid = self.normalize_grid(grid)
index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1))
grid = grid[..., index_ordering] # z, y, x -> x, y, z
_interp_mode = "bilinear" if self.mode == 1 else "nearest"
warped_image = F.grid_sample(
image, grid, mode=self.mode.value, padding_mode=self.padding_mode.value, align_corners=True
image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True
)

return warped_image
51 changes: 12 additions & 39 deletions tests/test_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,39 @@

from monai.networks.blocks.warp import Warp

TEST_CASE = [
LOW_POWER_TEST_CASES = [
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"spatial_dims": 2, "mode": 0, "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
torch.arange(4).reshape((1, 1, 2, 2)),
],
[
{"spatial_dims": 2, "mode": "nearest", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "border"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 2, "mode": "nearest", "padding_mode": "reflection"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 3, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 2, 2, 2)},
torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
],
]

TEST_CASES = [
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
torch.arange(4).reshape((1, 1, 2, 2)),
],
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"spatial_dims": 2, "mode": 1, "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)},
torch.tensor([[[[3, 0], [0, 0]]]]),
],
]

HIGH_POWER_TEST_CASES = [
[
{"spatial_dims": 3, "mode": "nearest", "padding_mode": "border"},
{"spatial_dims": 3, "mode": 2, "padding_mode": "border"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]),
],
[
{"spatial_dims": 3, "mode": "nearest", "padding_mode": "reflection"},
{"spatial_dims": 3, "mode": 3, "padding_mode": "reflection"},
{"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)},
torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]),
],
]

TEST_CASES = LOW_POWER_TEST_CASES
# if USE_COMPILED:
# TEST_CASES += HIGH_POWER_TEST_CASES


class TestWarp(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand Down

0 comments on commit f75b67a

Please sign in to comment.