Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1452 warp layer #1463

Merged
merged 4 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ Blocks
.. autoclass:: LocalNetFeatureExtractorBlock
:members:

`Warp`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: Warp
:members:

Layers
------
Expand Down
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
SEResNeXtBottleneck,
)
from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample
from .warp import Warp
103 changes: 103 additions & 0 deletions monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.config import USE_COMPILED
from monai.networks.layers import grid_pull
from monai.utils import GridSampleMode, GridSamplePadMode


class Warp(nn.Module):
"""
Warp an image with given DDF.
"""

def __init__(
self,
spatial_dims: int,
mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.ZEROS,
):
"""
Args:
spatial_dims: {2, 3}. number of spatial dimensions
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
"""
super(Warp, self).__init__()
if spatial_dims not in [2, 3]:
raise ValueError(f"got unsupported spatial_dims = {spatial_dims}, only support 2-d and 3-d input")
self.spatial_dims = spatial_dims
self.mode: GridSampleMode = GridSampleMode(mode)
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)

@staticmethod
def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
grid = torch.stack(torch.meshgrid(*mesh_points[::-1]), dim=0) # (spatial_dims, ...)
grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...)
grid = grid.to(ddf)
return grid

@staticmethod
def normalize_grid(grid: torch.Tensor) -> torch.Tensor:
# (batch, ..., self.spatial_dims)
for i, dim in enumerate(grid.shape[1:-1]):
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
return grid

def forward(self, image: torch.Tensor, ddf: torch.Tensor) -> torch.Tensor:
"""
Args:
image: Tensor in shape (batch, num_channels, H, W[, D])
ddf: Tensor in the same spatial size as image, in shape (batch, spatial_dims, H, W[, D])

Returns:
warped_image in the same shape as image (batch, num_channels, H, W[, D])
"""
if len(image.shape) != 2 + self.spatial_dims:
raise ValueError(f"expecting {self.spatial_dims + 2}-d input, " f"got input in shape {image.shape}")
if len(ddf.shape) != 2 + self.spatial_dims or ddf.shape[1] != self.spatial_dims:
raise ValueError(
f"expecting {self.spatial_dims + 2}-d ddf with {self.spatial_dims} channels, "
f"got ddf in shape {ddf.shape}"
)
if image.shape[0] != ddf.shape[0] or image.shape[2:] != ddf.shape[2:]:
raise ValueError(
"expecting image and ddf of same batch size and spatial size, "
f"got image of shape {image.shape}, ddf of shape {ddf.shape}"
)

grid = self.get_reference_grid(ddf) + ddf
grid = grid.permute([0] + list(range(2, 2 + self.spatial_dims)) + [1]) # (batch, ..., self.spatial_dims)

if USE_COMPILED:
_padding_mode = self.padding_mode.value
if _padding_mode == "zeros":
bound = 7
elif _padding_mode == "border":
bound = 0
else:
bound = 1
_interp_mode = self.mode.value
warped_image: torch.Tensor = grid_pull(
image,
grid,
bound=bound,
extrapolate=True,
interpolation=1 if _interp_mode == "bilinear" else _interp_mode,
)
else:
grid = self.normalize_grid(grid)
index_ordering: List[int] = list(range(self.spatial_dims - 1, -1, -1))
grid = grid[..., index_ordering] # z, y, x -> x, y, z
warped_image = F.grid_sample(
image, grid, mode=self.mode.value, padding_mode=self.padding_mode.value, align_corners=True
)
return warped_image
96 changes: 96 additions & 0 deletions tests/test_warp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.networks.blocks.warp import Warp

TEST_CASE = [
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
torch.arange(4).reshape((1, 1, 2, 2)),
],
[
{"spatial_dims": 2, "mode": "nearest", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "border"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 2, "mode": "nearest", "padding_mode": "reflection"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 4, 4)},
torch.tensor([[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]])
.unsqueeze(0)
.unsqueeze(0),
],
[
{"spatial_dims": 3, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 3, 2, 2, 2)},
torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
],
]

TEST_CASES = [
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.zeros(1, 2, 2, 2)},
torch.arange(4).reshape((1, 1, 2, 2)),
],
[
{"spatial_dims": 2, "mode": "bilinear", "padding_mode": "zeros"},
{"image": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 2, 2, 2)},
torch.tensor([[[[3, 0], [0, 0]]]]),
],
[
{"spatial_dims": 3, "mode": "nearest", "padding_mode": "border"},
{
"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),
"ddf": torch.ones(1, 3, 2, 2, 2) * -1,
},
torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]]),
],
[
{"spatial_dims": 3, "mode": "nearest", "padding_mode": "reflection"},
{"image": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), "ddf": torch.ones(1, 3, 2, 2, 2)},
torch.tensor([[[[[7, 6], [5, 4]], [[3, 2], [1, 0]]]]]),
],
]


class TestWarp(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_resample(self, input_param, input_data, expected_val):
warp_layer = Warp(**input_param)
result = warp_layer(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)

def test_ill_shape(self):
warp_layer = Warp(spatial_dims=2)
with self.assertRaisesRegex(ValueError, ""):
warp_layer(
image=torch.arange(4).reshape((1, 1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 2, 2)
)
with self.assertRaisesRegex(ValueError, ""):
warp_layer(
image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 1, 2, 2)
)
with self.assertRaisesRegex(ValueError, ""):
warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3))

def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
Warp(spatial_dims=4)


if __name__ == "__main__":
unittest.main()