Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1452 adjust-warp-layer-grid-pull-user-case #1470

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from torch import nn
from torch.nn import functional as F

from monai.config import USE_COMPILED
from monai.networks.layers import grid_pull
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import GridSamplePadMode


class Warp(nn.Module):
Expand All @@ -17,30 +15,38 @@ class Warp(nn.Module):
def __init__(
self,
spatial_dims: int,
mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
mode: int = 1,
padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS,
):
"""
Args:
spatial_dims: {2, 3}. number of spatial dimensions
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
mode: interpolation mode to calculate output values, defaults to 1.
Possible values are::

- 0 or 'nearest' or InterpolationType.nearest
- 1 or 'linear' or InterpolationType.linear
- 2 or 'quadratic' or InterpolationType.quadratic
- 3 or 'cubic' or InterpolationType.cubic
- 4 or 'fourth' or InterpolationType.fourth
- etc.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
"""
super(Warp, self).__init__()
if spatial_dims not in [2, 3]:
raise ValueError(f"got unsupported spatial_dims = {spatial_dims}, only support 2-d and 3-d input")
raise ValueError(f"got unsupported spatial_dims={spatial_dims}, only support 2-d and 3-d input")
self.spatial_dims = spatial_dims
self.mode: GridSampleMode = GridSampleMode(mode)
if mode < 0:
raise ValueError(f"do not support negative mode, got mode={mode}")
self.mode = mode
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)

@staticmethod
def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
grid = torch.stack(torch.meshgrid(*mesh_points[::-1]), dim=0) # (spatial_dims, ...)
grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...)
grid = grid.to(ddf)
return grid
Expand Down Expand Up @@ -77,27 +83,31 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
grid = self.get_reference_grid(ddf) + ddf
grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims)

if USE_COMPILED:
_padding_mode = self.padding_mode.value
if _padding_mode == "zeros":
bound = 7
elif _padding_mode == "border":
bound = 0
else:
bound = 1
_interp_mode = self.mode.value
warped_image: torch.Tensor = grid_pull(
image,
grid,
bound=bound,
extrapolate=True,
interpolation=1 if _interp_mode == "bilinear" else _interp_mode,
)
if self.mode > 1:
raise ValueError(f"{self.mode}-order interpolation not yet implemented.")
# if not USE_COMPILED:
# raise ValueError(f"cannot perform {self.mode}-order interpolation without C compile.")
# _padding_mode = self.padding_mode.value
# if _padding_mode == "zeros":
# bound = 7
# elif _padding_mode == "border":
# bound = 0
# else:
# bound = 1
# warped_image: torch.Tensor = grid_pull(
# image,
# grid,
# bound=bound,
# extrapolate=True,
# interpolation=self.mode,
# )
else:
grid = self.normalize_grid(grid)
index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1))
grid = grid[..., index_ordering] # z, y, x -> x, y, z
_interp_mode = "bilinear" if self.mode == 1 else "nearest"
warped_image = F.grid_sample(
image, grid, mode=self.mode.value, padding_mode=self.padding_mode.value, align_corners=True
image, grid, mode=_interp_mode, padding_mode=self.padding_mode.value, align_corners=True
)

return warped_image
51 changes: 12 additions & 39 deletions tests/test_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,39 @@

from monai.networks.blocks.warp import Warp

TEST_CASE = [
LOW_POWER_TEST_CASES = [
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"spatial_dims": 2, "mode": 0, "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
torch.arange(4).reshape((1, 1, 2, 2)),
],
[
{"spatial_dims": 2, "mode": "nearest", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "border"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 2, "mode": "nearest", "padding_mode": "reflection"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 3, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 2, 2, 2)},
torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
],
]

TEST_CASES = [
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
torch.arange(4).reshape((1, 1, 2, 2)),
],
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"spatial_dims": 2, "mode": 1, "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)},
torch.tensor([[[[3, 0], [0, 0]]]]),
],
]

HIGH_POWER_TEST_CASES = [
[
{"spatial_dims": 3, "mode": "nearest", "padding_mode": "border"},
{"spatial_dims": 3, "mode": 2, "padding_mode": "border"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]),
],
[
{"spatial_dims": 3, "mode": "nearest", "padding_mode": "reflection"},
{"spatial_dims": 3, "mode": 3, "padding_mode": "reflection"},
{"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)},
torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]),
],
]

TEST_CASES = LOW_POWER_TEST_CASES
# if USE_COMPILED:
# TEST_CASES += HIGH_POWER_TEST_CASES


class TestWarp(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand Down