Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4636 4637 backward compatible types #4638

Merged
merged 10 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from copy import deepcopy
from typing import Any, Sequence

import numpy as np
import torch

from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import PostFix
from monai.utils.type_conversion import convert_to_tensor
from monai.utils.type_conversion import convert_data_type, convert_to_tensor

__all__ = ["MetaTensor"]

Expand Down Expand Up @@ -307,6 +309,33 @@ def as_dict(self, key: str) -> dict:
PostFix.transforms(key): deepcopy(self.applied_operations),
}

def astype(self, dtype, device=None, *unused_args, **unused_kwargs):
"""
Cast to ``dtype``, sharing data whenever possible.

Args:
dtype: dtypes such as np.float32, torch.float, "np.float32", float.
device: the device if `dtype` is a torch data type.
unused_args: additional args (currently unused).
unused_kwargs: additional kwargs (currently unused).

Returns:
data array instance
"""
if isinstance(dtype, str):
mod_str, *dtype = dtype.split(".", 1)
dtype = mod_str if not dtype else dtype[0]
else:
mod_str = getattr(dtype, "__module__", "torch")
mod_str = look_up_option(mod_str, {"torch", "numpy", "np"}, default="numpy")
if mod_str == "torch":
out_type = torch.Tensor
elif mod_str in ("numpy", "np"):
out_type = np.ndarray
else:
out_type = None
return convert_data_type(self, output_type=out_type, device=device, dtype=dtype, wrap_sequence=True)[0]

@property
def affine(self) -> torch.Tensor:
"""Get the affine."""
Expand Down
18 changes: 16 additions & 2 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ class EnsureType(Transform):
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor".
If False, the output data type will be `torch.Tensor`. Default to the return value of ``get_track_meta``.

"""

Expand All @@ -446,11 +448,16 @@ def __init__(
dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
device: Optional[torch.device] = None,
wrap_sequence: bool = True,
track_meta: Optional[bool] = None,
) -> None:
self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"})
self.dtype = dtype
self.device = device
self.wrap_sequence = wrap_sequence
if track_meta is None:
self.track_meta = get_track_meta()
else:
self.track_meta = bool(track_meta)

def __call__(self, data: NdarrayOrTensor):
"""
Expand All @@ -461,10 +468,17 @@ def __call__(self, data: NdarrayOrTensor):
if applicable and `wrap_sequence=False`.

"""
output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray
if self.data_type == "tensor":
output_type = MetaTensor if self.track_meta else torch.Tensor
else:
output_type = np.ndarray # type: ignore
out: NdarrayOrTensor
out, *_ = convert_data_type(
data=data, output_type=output_type, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence
data=data,
output_type=output_type, # type: ignore
dtype=self.dtype,
device=self.device,
wrap_sequence=self.wrap_sequence,
)
return out

Expand Down
22 changes: 8 additions & 14 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
from monai.transforms.utils_pytorch_numpy_unification import concatenate
from monai.utils import convert_to_numpy, deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep
from monai.utils import deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix, TraceKeys, TransformBackends
from monai.utils.type_conversion import convert_to_dst_type

Expand Down Expand Up @@ -519,7 +519,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
return d


class EnsureTyped(MapTransform, InvertibleTransform):
class EnsureTyped(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.EnsureType`.

Expand All @@ -541,6 +541,7 @@ def __init__(
dtype: Union[DtypeLike, torch.dtype] = None,
device: Optional[torch.device] = None,
wrap_sequence: bool = True,
track_meta: Optional[bool] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Expand All @@ -552,28 +553,21 @@ def __init__(
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor".
If False, the output data type will be `torch.Tensor`. Default to the return value of `get_track_meta`.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence)
self.converter = EnsureType(
data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.converter(d[key])
return d

def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
# FIXME: currently, only convert tensor data to numpy array or scalar number,
# need to also invert numpy array but it's not easy to determine the previous data type
d[key] = convert_to_numpy(d[key])
# Remove the applied transform
self.pop_transform(d, key)
return d

wyli marked this conversation as resolved.
Show resolved Hide resolved

class ToNumpyd(MapTransform):
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def convert_data_type(

wyli marked this conversation as resolved.
Show resolved Hide resolved
Args:
data: data to be converted
output_type: `torch.Tensor` or `np.ndarray` (if `None`, unchanged)
device: if output is `torch.Tensor`, select device (if `None`, unchanged)
output_type: `monai.data.MetaTensor`, `torch.Tensor`, or `np.ndarray` (if `None`, unchanged)
device: if output is `MetaTensor` or `torch.Tensor`, select device (if `None`, unchanged)
dtype: dtype of output data. Converted to correct library type (e.g.,
`np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
If left blank, it remains unchanged.
Expand Down
7 changes: 4 additions & 3 deletions tests/test_ensure_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from monai.data import MetaTensor
from monai.transforms import EnsureType
from tests.utils import assert_allclose

Expand Down Expand Up @@ -59,9 +60,9 @@ def test_string(self):

def test_list_tuple(self):
for dtype in ("tensor", "numpy"):
result = EnsureType(data_type=dtype, wrap_sequence=False)([[1, 2], [3, 4]])
result = EnsureType(data_type=dtype, wrap_sequence=False, track_meta=True)([[1, 2], [3, 4]])
self.assertTrue(isinstance(result, list))
self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray))
self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray))
torch.testing.assert_allclose(result[1][0], torch.as_tensor(3))
# tuple of numpy arrays
result = EnsureType(data_type=dtype, wrap_sequence=False)((np.array([1, 2]), np.array([3, 4])))
Expand All @@ -77,7 +78,7 @@ def test_dict(self):
"extra": None,
}
for dtype in ("tensor", "numpy"):
result = EnsureType(data_type=dtype)(test_data)
result = EnsureType(data_type=dtype, track_meta=False)(test_data)
self.assertTrue(isinstance(result, dict))
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]))
Expand Down
7 changes: 5 additions & 2 deletions tests/test_ensure_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from monai.data import MetaTensor
from monai.transforms import EnsureTyped
from tests.utils import assert_allclose

Expand Down Expand Up @@ -61,9 +62,11 @@ def test_string(self):

def test_list_tuple(self):
for dtype in ("tensor", "numpy"):
result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)({"data": [[1, 2], [3, 4]]})["data"]
result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False, track_meta=True)(
{"data": [[1, 2], [3, 4]]}
)["data"]
self.assertTrue(isinstance(result, list))
self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray))
self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray))
torch.testing.assert_allclose(result[1][0], torch.as_tensor(3))
# tuple of numpy arrays
result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)(
Expand Down
11 changes: 9 additions & 2 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from multiprocessing.reduction import ForkingPickler
from typing import Optional, Union

import numpy as np
import torch
import torch.multiprocessing
from parameterized import parameterized
Expand Down Expand Up @@ -433,6 +434,14 @@ def test_str(self):
for s in (s1, s2):
self.assertEqual(s, expected_out)

def test_astype(self):
t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.compat.long):
self.assertIsInstance(t.astype(np_types), np.ndarray)
for pt_types in ("torch.float", torch.float, "torch.float64"):
self.assertIsInstance(t.astype(pt_types), torch.Tensor)
self.assertIsInstance(t.astype("torch.float", device="cpu"), torch.Tensor)

def test_transforms(self):
key = "im"
_, im = self.get_im()
Expand All @@ -441,7 +450,6 @@ def test_transforms(self):
data = {key: im, PostFix.meta(key): {"affine": torch.eye(4)}}

# apply one at a time
is_meta = isinstance(im, MetaTensor)
for i, _tr in enumerate(tr.transforms):
data = _tr(data)
is_meta = isinstance(_tr, (ToMetaTensord, BorderPadd, DivisiblePadd))
Expand All @@ -458,7 +466,6 @@ def test_transforms(self):
self.assertEqual(n_applied, i + 1)

# inverse one at a time
is_meta = isinstance(im, MetaTensor)
for i, _tr in enumerate(tr.transforms[::-1]):
data = _tr.inverse(data)
is_meta = isinstance(_tr, (FromMetaTensord, BorderPadd, DivisiblePadd))
Expand Down