2231 Fixes tutorial 353 (#2954)
* fixes tutorial 353

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* adding type tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* improves type checks

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes flake8

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes as channel first

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* type test option

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* ndarray support

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes unit tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* bash option for windows test

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes unit tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* enhance norm intensity tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes merge tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
wyli authored Sep 15, 2021
1 parent 2f4b582 commit 0f17aa9
Showing 48 changed files with 195 additions and 233 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
@@ -152,6 +152,7 @@ jobs:
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
python -c "import monai; monai.config.print_config()"
./runtests.sh --min
shell: bash
env:
QUICKTEST: True

35 changes: 17 additions & 18 deletions monai/transforms/spatial/array.py
@@ -471,7 +471,7 @@ def __call__(
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: Union[DtypeLike, torch.dtype] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
@@ -526,13 +526,11 @@ def __call__(
align_corners=self.align_corners if align_corners is None else align_corners,
reverse_indexing=True,
)
output: torch.Tensor = xform(
img_t.unsqueeze(0),
transform_t,
spatial_size=output_shape,
)
output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).squeeze(0)
self._rotation_matrix = transform
return output.squeeze(0).detach().float()
out: NdarrayOrTensor
out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype)
return out

def get_rotation_matrix(self) -> Optional[np.ndarray]:
"""
@@ -799,7 +797,7 @@ def __call__(
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: Union[DtypeLike, torch.dtype] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
@@ -1290,7 +1288,7 @@ def __call__(
grid: Optional[NdarrayOrTensor] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W[, D]).
@@ -1344,8 +1342,9 @@ def __call__(
padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value,
align_corners=True,
)[0]

return out
out_val: NdarrayOrTensor
out_val, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype)
return out_val


class Affine(Transform):
@@ -1425,7 +1424,7 @@ def __call__(
spatial_size: Optional[Union[Sequence[int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, NdarrayOrTensor]]:
) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]:
"""
Args:
img: shape must be (num_channels, H, W[, D]),
@@ -1589,7 +1588,7 @@ def __call__(
spatial_size: Optional[Union[Sequence[int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W[, D]),
@@ -1615,7 +1614,7 @@ def __call__(
grid = self.get_identity_grid(sp_size)
if self._do_transform:
grid = self.rand_affine_grid(grid=grid)
out: torch.Tensor = self.resampler(
out: NdarrayOrTensor = self.resampler(
img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
)
return out
@@ -1727,7 +1726,7 @@ def __call__(
spatial_size: Optional[Union[Tuple[int, int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W),
@@ -1756,7 +1755,7 @@ def __call__(
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
else:
grid = create_grid(spatial_size=sp_size)
out: torch.Tensor = self.resampler(
out: NdarrayOrTensor = self.resampler(
img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
)
return out
@@ -1877,7 +1876,7 @@ def __call__(
spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W, D),
@@ -1902,7 +1901,7 @@ def __call__(
offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0)
grid[:3] += gaussian(offset)[0] * self.magnitude
grid = self.rand_affine_grid(grid=grid)
out: torch.Tensor = self.resampler(
out: NdarrayOrTensor = self.resampler(
img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
)
return out
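Taken together, the return-type changes above mean the spatial array transforms now give back the same container type they receive, via `convert_to_dst_type`, instead of always returning a `torch.Tensor`. A minimal sketch of the expected behaviour after this commit; the transform choice, shapes, and angle are illustrative, not part of the diff:

```python
import numpy as np
import torch
from monai.transforms import Rotate

rotate = Rotate(angle=np.pi / 4, keep_size=True)

img_np = np.zeros((1, 16, 16), dtype=np.float32)  # channel-first numpy image
img_t = torch.zeros((1, 16, 16))                   # channel-first torch image

# Output container should match the input container.
print(type(rotate(img_np)))  # expected: <class 'numpy.ndarray'>
print(type(rotate(img_t)))   # expected: <class 'torch.Tensor'>
```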
10 changes: 1 addition & 9 deletions monai/transforms/spatial/dictionary.py
@@ -820,10 +820,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
if do_resampling:
d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)

# if not doing transform and spatial size is unchanged, only need to do convert to torch
else:
d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float32, device=device)

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -1442,10 +1438,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
self.randomize()
d = dict(data)
angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z)
rotator = Rotate(
angle=angle,
keep_size=self.keep_size,
)
rotator = Rotate(angle=angle, keep_size=self.keep_size)
for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners, self.dtype
):
@@ -1460,7 +1453,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
)
rot_mat = rotator.get_rotation_matrix()
else:
d[key], *_ = convert_data_type(d[key], torch.Tensor)
rot_mat = np.eye(d[key].ndim)
self.push_transform(
d,
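With the conversions removed above, a random dictionary transform that does not fire now passes the input through in its original container type rather than converting it to a `torch.Tensor`. A hedged sketch of that behaviour; the parameter values are illustrative only:

```python
import numpy as np
from monai.transforms import RandRotated

# prob=0.0 so the rotation never fires; range_x is illustrative.
t = RandRotated(keys="img", range_x=0.5, prob=0.0)
out = t({"img": np.zeros((1, 16, 16), dtype=np.float32)})

# With the else-branch conversion removed, the value should stay a numpy array.
print(type(out["img"]))  # expected: <class 'numpy.ndarray'>
```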
14 changes: 10 additions & 4 deletions monai/utils/type_conversion.py
@@ -248,11 +248,14 @@ def convert_data_type(
return data, orig_type, orig_device


def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
def convert_to_dst_type(
src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None
) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
"""
If `dst` is `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,
if `dst` is `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,
If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,
if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,
otherwise, convert to the type of `dst` directly.
`dtype` is an optional argument if the target `dtype` is different from the original `dst`'s data type.
See Also:
:func:`convert_data_type`
@@ -261,11 +264,14 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
if isinstance(dst, torch.Tensor):
device = dst.device

if dtype is None:
dtype = dst.dtype

output_type: Any
if isinstance(dst, torch.Tensor):
output_type = torch.Tensor
elif isinstance(dst, np.ndarray):
output_type = np.ndarray
else:
output_type = type(dst)
return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype)
return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype)
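A short usage sketch of the extended helper, based on the signature shown above; the values are illustrative only:

```python
import numpy as np
import torch
from monai.utils.type_conversion import convert_to_dst_type

src = np.arange(4, dtype=np.int64)
dst = torch.zeros(4, dtype=torch.float32)

# Default: follow dst's container type, dtype (float32 here), and device.
out, orig_type, device = convert_to_dst_type(src, dst)

# New optional argument: override the target dtype while still matching dst's container type.
out16, *_ = convert_to_dst_type(src, dst, dtype=torch.float16)
```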
2 changes: 1 addition & 1 deletion tests/test_affine_grid.py
@@ -115,7 +115,7 @@ def test_affine_grid(self, input_param, input_data, expected_val):
result, _ = g(**input_data)
if "device" in input_data:
self.assertEqual(result.device, input_data[device])
assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)
assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
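The test updates in this and the following files pass `type_test=False` to the shared `assert_allclose` helper so that a result in one container type can be checked against an expectation in another. A hedged sketch of that usage, assuming the helper lives in the MONAI test suite's `tests/utils.py` as in the diff:

```python
import numpy as np
import torch
from tests.utils import assert_allclose

# Values compare equal; type_test=False skips the container-type check so a torch
# result can be validated against a numpy expectation.
assert_allclose(torch.ones(3), np.ones(3), type_test=False, rtol=1e-4, atol=1e-4)
```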
2 changes: 1 addition & 1 deletion tests/test_as_channel_first.py
@@ -34,7 +34,7 @@ def test_value(self, in_type, input_param, expected_shape):
if isinstance(test_data, torch.Tensor):
test_data = test_data.cpu().numpy()
expected = np.moveaxis(test_data, input_param["channel_dim"], 0)
assert_allclose(expected, result)
assert_allclose(result, expected, type_test=False)


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions tests/test_ensure_type.py
@@ -29,7 +29,7 @@ def test_array_input(self):
if dtype == "NUMPY":
self.assertTrue(result.dtype == np.float32)
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertTupleEqual(result.shape, (2, 2))

def test_single_input(self):
@@ -43,7 +43,7 @@ def test_single_input(self):
if isinstance(test_data, bool):
self.assertFalse(result)
else:
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertEqual(result.ndim, 0)

def test_string(self):
4 changes: 2 additions & 2 deletions tests/test_ensure_typed.py
@@ -34,7 +34,7 @@ def test_array_input(self):
if dtype == "NUMPY":
self.assertTrue(result.dtype == np.float32)
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertTupleEqual(result.shape, (2, 2))

def test_single_input(self):
@@ -48,7 +48,7 @@ def test_single_input(self):
if isinstance(test_data, bool):
self.assertFalse(result)
else:
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertEqual(result.ndim, 0)

def test_string(self):
6 changes: 2 additions & 4 deletions tests/test_flip.py
@@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis):
for p in TEST_NDARRAYS:
im = p(self.imt[0])
flip = Flip(spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
expected = np.stack(expected)
result = flip(im)
assert_allclose(expected, result)
assert_allclose(result, p(expected))


if __name__ == "__main__":
6 changes: 2 additions & 4 deletions tests/test_flipd.py
@@ -33,12 +33,10 @@ def test_invalid_cases(self, _, spatial_axis, raises):
def test_correct_results(self, _, spatial_axis):
for p in TEST_NDARRAYS:
flip = Flipd(keys="img", spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
expected = np.stack(expected)
result = flip({"img": p(self.imt[0])})["img"]
assert_allclose(expected, result)
assert_allclose(result, p(expected))


if __name__ == "__main__":
5 changes: 1 addition & 4 deletions tests/test_inverse_collation.py
@@ -115,10 +115,7 @@ def tearDown(self):

@parameterized.expand(TESTS_2D + TESTS_3D)
def test_collation(self, _, transform, collate_fn, ndim):
if ndim == 3:
data = self.data_3d
else:
data = self.data_2d
data = self.data_3d if ndim == 3 else self.data_2d
if collate_fn:
modified_transform = transform
else:
2 changes: 1 addition & 1 deletion tests/test_label_to_mask.py
@@ -64,7 +64,7 @@ def test_value(self, argments, image, expected_data):
self.assertEqual(type(result), type(image))
if isinstance(result, torch.Tensor):
self.assertEqual(result.device, image.device)
assert_allclose(result, expected_data)
assert_allclose(result, expected_data, type_test=False)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/test_label_to_maskd.py
@@ -65,7 +65,7 @@ def test_value(self, argments, input_data, expected_data):
self.assertEqual(type(r), type(i))
if isinstance(r, torch.Tensor):
self.assertEqual(r.device, i.device)
assert_allclose(r, expected_data)
assert_allclose(r, expected_data, type_test=False)


if __name__ == "__main__":