Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2231 Fixes tutorial 353 #2954

Merged
merged 14 commits into from
Sep 15, 2021
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ jobs:
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
python -c "import monai; monai.config.print_config()"
./runtests.sh --min
shell: bash
env:
QUICKTEST: True

Expand Down
35 changes: 17 additions & 18 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def __call__(
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: Union[DtypeLike, torch.dtype] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
Expand Down Expand Up @@ -526,13 +526,11 @@ def __call__(
align_corners=self.align_corners if align_corners is None else align_corners,
reverse_indexing=True,
)
output: torch.Tensor = xform(
img_t.unsqueeze(0),
transform_t,
spatial_size=output_shape,
)
output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).squeeze(0)
self._rotation_matrix = transform
return output.squeeze(0).detach().float()
out: NdarrayOrTensor
out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype)
return out

def get_rotation_matrix(self) -> Optional[np.ndarray]:
"""
Expand Down Expand Up @@ -799,7 +797,7 @@ def __call__(
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: Union[DtypeLike, torch.dtype] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
Expand Down Expand Up @@ -1290,7 +1288,7 @@ def __call__(
grid: Optional[NdarrayOrTensor] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W[, D]).
Expand Down Expand Up @@ -1344,8 +1342,9 @@ def __call__(
padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value,
align_corners=True,
)[0]

return out
out_val: NdarrayOrTensor
out_val, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype)
wyli marked this conversation as resolved.
Show resolved Hide resolved
return out_val


class Affine(Transform):
Expand Down Expand Up @@ -1425,7 +1424,7 @@ def __call__(
spatial_size: Optional[Union[Sequence[int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, NdarrayOrTensor]]:
) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]:
"""
Args:
img: shape must be (num_channels, H, W[, D]),
Expand Down Expand Up @@ -1589,7 +1588,7 @@ def __call__(
spatial_size: Optional[Union[Sequence[int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W[, D]),
Expand All @@ -1615,7 +1614,7 @@ def __call__(
grid = self.get_identity_grid(sp_size)
if self._do_transform:
grid = self.rand_affine_grid(grid=grid)
out: torch.Tensor = self.resampler(
out: NdarrayOrTensor = self.resampler(
img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
)
return out
Expand Down Expand Up @@ -1727,7 +1726,7 @@ def __call__(
spatial_size: Optional[Union[Tuple[int, int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W),
Expand Down Expand Up @@ -1756,7 +1755,7 @@ def __call__(
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
else:
grid = create_grid(spatial_size=sp_size)
out: torch.Tensor = self.resampler(
out: NdarrayOrTensor = self.resampler(
img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
)
return out
Expand Down Expand Up @@ -1877,7 +1876,7 @@ def __call__(
spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: shape must be (num_channels, H, W, D),
Expand All @@ -1902,7 +1901,7 @@ def __call__(
offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0)
grid[:3] += gaussian(offset)[0] * self.magnitude
grid = self.rand_affine_grid(grid=grid)
out: torch.Tensor = self.resampler(
out: NdarrayOrTensor = self.resampler(
img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
)
return out
Expand Down
10 changes: 1 addition & 9 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,10 +820,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
if do_resampling:
d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)

# if not doing transform and spatial size is unchanged, only need to do convert to torch
else:
d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float32, device=device)

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
Expand Down Expand Up @@ -1442,10 +1438,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
self.randomize()
d = dict(data)
angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z)
rotator = Rotate(
angle=angle,
keep_size=self.keep_size,
)
rotator = Rotate(angle=angle, keep_size=self.keep_size)
for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners, self.dtype
):
Expand All @@ -1460,7 +1453,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
)
rot_mat = rotator.get_rotation_matrix()
else:
d[key], *_ = convert_data_type(d[key], torch.Tensor)
rot_mat = np.eye(d[key].ndim)
self.push_transform(
d,
Expand Down
14 changes: 10 additions & 4 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,14 @@ def convert_data_type(
return data, orig_type, orig_device


def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
def convert_to_dst_type(
src: Any, dst: NdarrayOrTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None
) -> Tuple[NdarrayOrTensor, type, Optional[torch.device]]:
"""
If `dst` is `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,
if `dst` is `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,
If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,
if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,
otherwise, convert to the type of `dst` directly.
`dtype` is an optional argument if the target `dtype` is different from the original `dst`'s data type.

See Also:
:func:`convert_data_type`
Expand All @@ -260,11 +263,14 @@ def convert_to_dst_type(src: Any, dst: NdarrayOrTensor) -> Tuple[NdarrayOrTensor
if isinstance(dst, torch.Tensor):
device = dst.device

if dtype is None:
dtype = dst.dtype

output_type: Any
if isinstance(dst, torch.Tensor):
output_type = torch.Tensor
elif isinstance(dst, np.ndarray):
output_type = np.ndarray
else:
output_type = type(dst)
return convert_data_type(data=src, output_type=output_type, device=device, dtype=dst.dtype)
return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype)
2 changes: 1 addition & 1 deletion tests/test_affine_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_affine_grid(self, input_param, input_data, expected_val):
result, _ = g(**input_data)
if "device" in input_data:
self.assertEqual(result.device, input_data[device])
assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)
assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_as_channel_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_value(self, in_type, input_param, expected_shape):
if isinstance(test_data, torch.Tensor):
test_data = test_data.cpu().numpy()
expected = np.moveaxis(test_data, input_param["channel_dim"], 0)
assert_allclose(expected, result)
assert_allclose(result, expected, type_test=False)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ensure_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_array_input(self):
for dtype in ("tensor", "NUMPY"):
result = EnsureType(data_type=dtype)(test_data)
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertTupleEqual(result.shape, (2, 2))

def test_single_input(self):
Expand All @@ -41,7 +41,7 @@ def test_single_input(self):
if isinstance(test_data, bool):
self.assertFalse(result)
else:
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertEqual(result.ndim, 0)

def test_string(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ensure_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_array_input(self):
for dtype in ("tensor", "NUMPY"):
result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"]
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertTupleEqual(result.shape, (2, 2))

def test_single_input(self):
Expand All @@ -41,7 +41,7 @@ def test_single_input(self):
if isinstance(test_data, bool):
self.assertFalse(result)
else:
assert_allclose(result, test_data)
assert_allclose(result, test_data, type_test=False)
self.assertEqual(result.ndim, 0)

def test_string(self):
Expand Down
6 changes: 2 additions & 4 deletions tests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@ def test_correct_results(self, _, spatial_axis):
for p in TEST_NDARRAYS:
im = p(self.imt[0])
flip = Flip(spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
expected = np.stack(expected)
result = flip(im)
assert_allclose(expected, result)
assert_allclose(result, p(expected))


if __name__ == "__main__":
Expand Down
6 changes: 2 additions & 4 deletions tests/test_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@ def test_invalid_cases(self, _, spatial_axis, raises):
def test_correct_results(self, _, spatial_axis):
for p in TEST_NDARRAYS:
flip = Flipd(keys="img", spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
expected = np.stack(expected)
result = flip({"img": p(self.imt[0])})["img"]
assert_allclose(expected, result)
assert_allclose(result, p(expected))


if __name__ == "__main__":
Expand Down
5 changes: 1 addition & 4 deletions tests/test_inverse_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,7 @@ def tearDown(self):

@parameterized.expand(TESTS_2D + TESTS_3D)
def test_collation(self, _, transform, collate_fn, ndim):
if ndim == 3:
data = self.data_3d
else:
data = self.data_2d
data = self.data_3d if ndim == 3 else self.data_2d
if collate_fn:
modified_transform = transform
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_label_to_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_value(self, argments, image, expected_data):
self.assertEqual(type(result), type(image))
if isinstance(result, torch.Tensor):
self.assertEqual(result.device, image.device)
assert_allclose(result, expected_data)
assert_allclose(result, expected_data, type_test=False)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_label_to_maskd.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_value(self, argments, input_data, expected_data):
self.assertEqual(type(r), type(i))
if isinstance(r, torch.Tensor):
self.assertEqual(r.device, i.device)
assert_allclose(r, expected_data)
assert_allclose(r, expected_data, type_test=False)


if __name__ == "__main__":
Expand Down
Loading