Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dtype to ToCupy #2950

Merged
merged 6 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis
from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import
from monai.utils import (
convert_to_cupy,
convert_to_numpy,
convert_to_tensor,
ensure_tuple,
look_up_option,
min_version,
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_data_type
Expand Down Expand Up @@ -393,15 +401,22 @@ def __call__(self, img: NdarrayOrTensor) -> np.ndarray:
class ToCupy(Transform):
"""
Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor.

Args:
dtype: data type specifier. It is inferred from the input by default.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
def __init__(self, dtype=None) -> None:
super().__init__()
self.dtype = dtype

def __call__(self, data: NdarrayOrTensor):
"""
Apply the transform to `img` and make it contiguous.
Create a CuPy array from `data` and make it contiguous
"""
return cp.ascontiguousarray(cp.asarray(img)) # type: ignore
return convert_to_cupy(data, self.dtype)


class ToPIL(Transform):
Expand Down
16 changes: 8 additions & 8 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,19 +549,19 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
class ToCupyd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`.

Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
allow_missing_keys: don't raise exception if key is missing.
dtype: data type specifier. It is inferred from the input by default.
"""

backend = ToCupy.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
allow_missing_keys: don't raise exception if key is missing.
"""
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, dtype=None) -> None:
super().__init__(keys, allow_missing_keys)
self.converter = ToCupy()
self.converter = ToCupy(dtype=dtype)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from .state_cacher import StateCacher
from .type_conversion import (
convert_data_type,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
convert_to_tensor,
Expand Down
43 changes: 43 additions & 0 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"get_equivalent_dtype",
"convert_data_type",
"get_dtype",
"convert_to_cupy",
"convert_to_numpy",
"convert_to_tensor",
"convert_to_dst_type",
Expand Down Expand Up @@ -154,6 +155,42 @@ def convert_to_numpy(data, wrap_sequence: bool = False):
return data


def convert_to_cupy(data, dtype, wrap_sequence: bool = True):
"""
Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple,
recursively check every item and convert it to cupy array.

Args:
data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc.
Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays

for dictionary, list or tuple, convert every item to a numpy array if applicable.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
"""

# direct calls
if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)):
data = cp.asarray(data, dtype)
# recursive calls
elif isinstance(data, Sequence) and wrap_sequence:
return cp.asarray(data)
bhashemian marked this conversation as resolved.
Show resolved Hide resolved
bhashemian marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(data, list):
return [convert_to_cupy(i, dtype) for i in data]
elif isinstance(data, tuple):
return tuple(convert_to_cupy(i, dtype) for i in data)
elif isinstance(data, dict):
return {k: convert_to_cupy(v, dtype) for k, v in data.items()}
# make it contiguous
if isinstance(data, cp.ndarray):
if data.ndim > 0:
data = cp.ascontiguousarray(data)
else:
raise ValueError(f"The input data type [{type(data)}] cannot be converted into cupy arrays!")

return data


def convert_data_type(
data: Any,
output_type: Optional[type] = None,
Expand All @@ -178,6 +215,8 @@ def convert_data_type(
orig_type = torch.Tensor
elif isinstance(data, np.ndarray):
orig_type = np.ndarray
elif has_cp and isinstance(data, cp.ndarray):
orig_type = cp.ndarray
else:
orig_type = type(data)

Expand All @@ -199,6 +238,10 @@ def convert_data_type(
data = convert_to_numpy(data)
if data is not None and dtype != data.dtype:
data = data.astype(dtype)
elif has_cp and output_type is cp.ndarray:
if data is not None:
data = convert_to_cupy(data, dtype)

else:
raise ValueError(f"Unsupported output type: {output_type}")
return data, orig_type, orig_device
Expand Down
54 changes: 43 additions & 11 deletions tests/test_to_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,81 @@
cp, has_cp = optional_import("cupy")


@skipUnless(has_cp, "CuPy is required.")
class TestToCupy(unittest.TestCase):
@skipUnless(has_cp, "CuPy is required.")
def test_cupy_input(self):
test_data = cp.array([[1, 2], [3, 4]])
test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32)
bhashemian marked this conversation as resolved.
Show resolved Hide resolved
test_data = cp.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToCupy()(test_data)
self.assertTrue(result.dtype == cp.float32)
self.assertTrue(isinstance(result, cp.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)

def test_cupy_input_dtype(self):
test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32)
test_data = cp.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToCupy(cp.uint8)(test_data)
self.assertTrue(result.dtype == cp.uint8)
self.assertTrue(isinstance(result, cp.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)

@skipUnless(has_cp, "CuPy is required.")
def test_numpy_input(self):
test_data = np.array([[1, 2], [3, 4]])
test_data = np.array([[1, 2], [3, 4]], dtype=np.float32)
test_data = np.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToCupy()(test_data)
self.assertTrue(result.dtype == cp.float32)
self.assertTrue(isinstance(result, cp.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)

def test_numpy_input_dtype(self):
test_data = np.array([[1, 2], [3, 4]], dtype=np.float32)
test_data = np.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToCupy(np.uint8)(test_data)
self.assertTrue(result.dtype == cp.uint8)
self.assertTrue(isinstance(result, cp.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)

@skipUnless(has_cp, "CuPy is required.")
def test_tensor_input(self):
test_data = torch.tensor([[1, 2], [3, 4]])
test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())
result = ToCupy()(test_data)
self.assertTrue(result.dtype == cp.float32)
self.assertTrue(isinstance(result, cp.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data.numpy())
cp.testing.assert_allclose(result, test_data)

@skipUnless(has_cp, "CuPy is required.")
@skip_if_no_cuda
def test_tensor_cuda_input(self):
test_data = torch.tensor([[1, 2], [3, 4]]).cuda()
test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda()
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())
result = ToCupy()(test_data)
self.assertTrue(result.dtype == cp.float32)
self.assertTrue(isinstance(result, cp.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data.cpu().numpy())
cp.testing.assert_allclose(result, test_data)

@skip_if_no_cuda
def test_tensor_cuda_input_dtype(self):
test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint8).cuda()
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())

result = ToCupy(dtype="float32")(test_data)
self.assertTrue(result.dtype == cp.float32)
self.assertTrue(isinstance(result, cp.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)

@skipUnless(has_cp, "CuPy is required.")
def test_list_tuple(self):
test_data = [[1, 2], [3, 4]]
result = ToCupy()(test_data)
Expand Down