diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2eb6c447c6..add47e27ca 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -334,11 +334,15 @@ class ToTensor(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, device: Optional[torch.device] = None) -> None: + super().__init__() + self.device = device + def __call__(self, img: NdarrayOrTensor) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. """ - return convert_to_tensor(img, wrap_sequence=True) # type: ignore + return convert_to_tensor(img, wrap_sequence=True, device=self.device) # type: ignore class EnsureType(Transform): @@ -399,8 +403,6 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() return cp.ascontiguousarray(cp.asarray(img)) # type: ignore diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 3688b02d26..47b48aa2b8 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -83,7 +83,7 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data, wrap_sequence: bool = False): +def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.device] = None): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -97,26 +97,26 @@ def convert_to_tensor(data, wrap_sequence: bool = False): """ if isinstance(data, torch.Tensor): - return data.contiguous() + return data.contiguous().to(device) if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) + return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data), device=device) elif has_cp and isinstance(data, cp_ndarray): - return torch.as_tensor(data) + return torch.as_tensor(data, device=device) elif isinstance(data, (float, int, bool)): - return torch.as_tensor(data) + return torch.as_tensor(data, device=device) elif isinstance(data, Sequence) and wrap_sequence: - return torch.as_tensor(data) + return torch.as_tensor(data, device=device) elif isinstance(data, list): - return [convert_to_tensor(i) for i in data] + return [convert_to_tensor(i, device=device) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_tensor(i) for i in data) + return tuple(convert_to_tensor(i, device=device) for i in data) elif isinstance(data, dict): - return {k: convert_to_tensor(v) for k, v in data.items()} + return {k: convert_to_tensor(v, device=device) for k, v in data.items()} return data