enhance affinegrid to use torch backend #2969

Merged: 23 commits, Sep 20, 2021
6 changes: 3 additions & 3 deletions monai/data/png_writer.py
@@ -48,7 +48,7 @@ def write_png(

"""
if not isinstance(data, np.ndarray):
raise AssertionError("input data must be numpy array.")
raise ValueError("input data must be numpy array.")
if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel
data = data.squeeze(2)
if output_spatial_shape is not None:
@@ -59,11 +59,11 @@ def write_png(
_min, _max = np.min(data), np.max(data)
if len(data.shape) == 3:
data = np.moveaxis(data, -1, 0) # to channel first
data = xform(data)
data = xform(data) # type: ignore
data = np.moveaxis(data, 0, -1)
else: # (H, W)
data = np.expand_dims(data, 0) # make a channel
data = xform(data)[0] # first channel
data = xform(data)[0] # type: ignore
if mode != InterpolateMode.NEAREST:
data = np.clip(data, _min, _max) # type: ignore
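
Side note on the block above (not part of this diff): the pre-resize `_min`/`_max` are recorded because interpolating resize modes can overshoot the original intensity range, which is then clipped back afterwards. A minimal sketch of that flow, assuming `monai.transforms.Resize` with a `spatial_size` argument:

```python
import numpy as np
from monai.transforms import Resize

# Illustration only, not the write_png source: record the intensity range,
# resize with an interpolating mode (which may overshoot), then clip back.
data = np.random.rand(64, 64).astype(np.float32)      # (H, W)
_min, _max = np.min(data), np.max(data)

xform = Resize(spatial_size=(128, 128), mode="bilinear")
resized = np.asarray(xform(data[None]))[0]             # add/remove the channel dim
resized = np.clip(resized, _min, _max)                  # restore the original bounds
```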

35 changes: 18 additions & 17 deletions monai/transforms/croppad/array.py
@@ -128,10 +128,7 @@ def __call__(
# all zeros, skip padding
return img
mode = convert_pad_mode(dst=img, mode=mode or self.mode).value
if isinstance(img, torch.Tensor):
pad = self._pt_pad
else:
pad = self._np_pad # type: ignore
pad = self._pt_pad if isinstance(img, torch.Tensor) else self._np_pad
return pad(img, self.to_pad, mode, **self.kwargs) # type: ignore
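
The one-liner above dispatches between the numpy and torch padding backends. A standalone sketch of that pattern (hypothetical helper, not the `Pad` class itself):

```python
import numpy as np
import torch

def pad_channel_first(img, to_pad):
    """Pad with zeros; `to_pad` is a per-dimension list of (before, after) widths."""
    if isinstance(img, torch.Tensor):
        # torch.nn.functional.pad expects a flat list ordered from the last dim backwards
        flat = [p for pair in reversed(to_pad) for p in pair]
        return torch.nn.functional.pad(img, flat, mode="constant", value=0)
    return np.pad(img, to_pad, mode="constant", constant_values=0)

print(pad_channel_first(np.zeros((1, 4, 4)), [(0, 0), (2, 2), (2, 2)]).shape)    # (1, 8, 8)
print(pad_channel_first(torch.zeros(1, 4, 4), [(0, 0), (2, 2), (2, 2)]).shape)   # (1, 8, 8)
```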


@@ -449,15 +446,16 @@ class CenterSpatialCrop(Transform):
the spatial size of output data will be [32, 40, 40].
"""

backend = SpatialCrop.backend

def __init__(self, roi_size: Union[Sequence[int], int]) -> None:
self.roi_size = roi_size

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
roi_size = fall_back_tuple(self.roi_size, img.shape[1:])
center = [i // 2 for i in img.shape[1:]]
cropper = SpatialCrop(roi_center=center, roi_size=roi_size)
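
With the torch backend declared above, the transform can keep tensors as tensors instead of forcing a numpy conversion. A usage sketch, assuming the post-PR behaviour:

```python
import numpy as np
import torch
from monai.transforms import CenterSpatialCrop

crop = CenterSpatialCrop(roi_size=(32, 40, 40))

# channel-first inputs; the output type matches the input type
print(crop(np.zeros((1, 64, 64, 64), dtype=np.float32)).shape)   # (1, 32, 40, 40)
print(crop(torch.zeros(1, 64, 64, 64)).shape)                    # torch.Size([1, 32, 40, 40])
```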
@@ -474,11 +472,12 @@ class CenterScaleCrop(Transform):

"""

backend = CenterSpatialCrop.backend

def __init__(self, roi_scale: Union[Sequence[float], float]):
self.roi_scale = roi_scale

def __call__(self, img: np.ndarray):
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
img_size = img.shape[1:]
ndim = len(img_size)
roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
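
For reference, the `roi_size` computation above scales each spatial dimension independently and rounds up, e.g. with illustrative values:

```python
from math import ceil

img_size = (100, 120, 80)     # spatial shape of a channel-first image
roi_scale = (0.5, 0.4, 1.0)
roi_size = [ceil(r * s) for r, s in zip(roi_scale, img_size)]
print(roi_size)               # [50, 48, 80]
```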
@@ -510,6 +509,8 @@ class RandSpatialCrop(Randomizable, Transform):
if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`.
"""

backend = CenterSpatialCrop.backend

def __init__(
self,
roi_size: Union[Sequence[int], int],
@@ -535,15 +536,14 @@ def randomize(self, img_size: Sequence[int]) -> None:
valid_size = get_valid_patch_size(img_size, self._size)
self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R)

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
self.randomize(img.shape[1:])
if self._size is None:
raise AssertionError
raise RuntimeError("self._size not specified.")
if self.random_center:
return img[self._slices]
cropper = CenterSpatialCrop(self._size)
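
A usage sketch of the randomized crop (assuming post-PR tensor support), where `random_size=True` samples the actual edge lengths from `randint(roi_size, max_roi_size + 1)` as described in the docstring:

```python
import torch
from monai.transforms import RandSpatialCrop

cropper = RandSpatialCrop(roi_size=(16, 16, 16), max_roi_size=(32, 32, 32), random_size=True)
cropper.set_random_state(seed=0)            # reproducible sampling

patch = cropper(torch.zeros(1, 64, 64, 64))
print(patch.shape)                          # each spatial edge falls in [16, 32]
```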
@@ -582,12 +582,11 @@ def __init__(
self.roi_scale = roi_scale
self.max_roi_scale = max_roi_scale

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
img_size = img.shape[1:]
ndim = len(img_size)
self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
@@ -629,6 +628,8 @@ class RandSpatialCropSamples(Randomizable, Transform):

"""

backend = RandScaleCrop.backend

def __init__(
self,
roi_size: Union[Sequence[int], int],
@@ -652,12 +653,11 @@ def set_random_state(
def randomize(self, data: Optional[Any] = None) -> None:
pass

def __call__(self, img: np.ndarray) -> List[np.ndarray]:
def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]:
"""
Apply the transform to `img`, assuming `img` is channel-first and
cropping doesn't change the channel dim.
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
return [self.cropper(img) for _ in range(self.num_samples)]


@@ -1128,6 +1128,8 @@ class ResizeWithPadOrCrop(Transform):

"""

backend = list(set(SpatialPad.backend) & set(CenterSpatialCrop.backend))

def __init__(
self,
spatial_size: Union[Sequence[int], int],
@@ -1138,7 +1140,7 @@ def __init__(
self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **np_kwargs)
self.cropper = CenterSpatialCrop(roi_size=spatial_size)

def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayOrTensor:
"""
Args:
img: data to pad or crop, assuming `img` is channel-first and
@@ -1149,7 +1151,6 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
If None, defaults to the ``mode`` in construction.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
return self.padder(self.cropper(img), mode=mode) # type: ignore
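
A usage sketch of the combined pad-or-crop (assuming post-PR tensor support): dimensions larger than `spatial_size` are center-cropped and smaller ones are symmetrically padded, so the output always has the requested spatial shape:

```python
import torch
from monai.transforms import ResizeWithPadOrCrop

resizer = ResizeWithPadOrCrop(spatial_size=(32, 32, 32))

out = resizer(torch.zeros(1, 48, 20, 32))   # channel-first input
print(out.shape)                            # torch.Size([1, 32, 32, 32])
```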


32 changes: 22 additions & 10 deletions monai/transforms/croppad/dictionary.py
@@ -416,13 +416,15 @@ class CenterSpatialCropd(MapTransform, InvertibleTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = CenterSpatialCrop.backend

def __init__(
self, keys: KeysCollection, roi_size: Union[Sequence[int], int], allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys)
self.cropper = CenterSpatialCrop(roi_size)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
orig_size = d[key].shape[1:]
@@ -466,13 +468,15 @@ class CenterScaleCropd(MapTransform, InvertibleTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = CenterSpatialCrop.backend

def __init__(
self, keys: KeysCollection, roi_scale: Union[Sequence[float], float], allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys=allow_missing_keys)
self.roi_scale = roi_scale

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
# use the spatial size of first image to scale, expect all images have the same spatial size
img_size = data[self.keys[0]].shape[1:]
@@ -537,6 +541,8 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = CenterSpatialCrop.backend

def __init__(
self,
keys: KeysCollection,
@@ -565,11 +571,11 @@ def randomize(self, img_size: Sequence[int]) -> None:
valid_size = get_valid_patch_size(img_size, self._size)
self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key
if self._size is None:
raise AssertionError
raise RuntimeError("self._size not specified.")
for key in self.key_iterator(d):
if self.random_center:
self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore
@@ -638,6 +644,8 @@ class RandScaleCropd(RandSpatialCropd):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = RandSpatialCropd.backend

def __init__(
self,
keys: KeysCollection,
@@ -659,7 +667,7 @@ def __init__(
self.roi_scale = roi_scale
self.max_roi_scale = max_roi_scale

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
img_size = data[self.keys[0]].shape[1:]
ndim = len(img_size)
self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
@@ -723,6 +731,8 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform):

"""

backend = RandSpatialCropd.backend

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -755,7 +765,7 @@ def set_random_state(
def randomize(self, data: Optional[Any] = None) -> None:
pass

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
ret = []
for i in range(self.num_samples):
d = dict(data)
@@ -765,14 +775,14 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
cropped = self.cropper(d)
# self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd
for key in self.key_iterator(cropped):
cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__
cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self)
cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore
cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) # type: ignore
# add `patch_index` to the meta data
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
if meta_key not in cropped:
cropped[meta_key] = {} # type: ignore
cropped[meta_key][Key.PATCH_INDEX] = i
cropped[meta_key][Key.PATCH_INDEX] = i # type: ignore
ret.append(cropped)
return ret
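
A usage sketch of the dictionary-based sampler (assuming post-PR tensor support): one call returns `num_samples` cropped copies, each with `patch_index` written into its meta dictionary as in the loop above:

```python
import torch
from monai.transforms import RandSpatialCropSamplesd

sampler = RandSpatialCropSamplesd(keys=["image"], roi_size=(16, 16, 16), num_samples=4, random_size=False)
sampler.set_random_state(seed=0)

samples = sampler({"image": torch.zeros(1, 64, 64, 64), "image_meta_dict": {}})
print(len(samples))                                    # 4
print(samples[0]["image"].shape)                       # torch.Size([1, 16, 16, 16])
print(samples[0]["image_meta_dict"]["patch_index"])    # 0
```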

@@ -1377,6 +1387,8 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform):

"""

backend = ResizeWithPadOrCrop.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -1390,7 +1402,7 @@ def __init__(
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
orig_size = d[key].shape[1:]