
6109 no mutate ratio /user inputs croppad (#6127)
Fixes #6109

### Description
- use tuples for user inputs to avoid unintended changes (see the sketch below)
- enhance the type checks
- fix the `ratios` issue in `RandCropByLabelClasses`
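
A minimal sketch of the "tuples instead of lists" idea (made-up helper names, not the actual MONAI code):

```python
# Sketch only: returning/storing tuples means downstream code cannot mutate
# in place a value that the caller (or another transform) still holds.

def compute_pad_width_list(spatial_shape, roi_size):
    """Old style: returns a mutable list that later code may alter."""
    pad_width = [(0, max(r - s, 0)) for r, s in zip(roi_size, spatial_shape)]
    return [(0, 0)] + pad_width

def compute_pad_width_tuple(spatial_shape, roi_size):
    """New style: returns an immutable tuple, so the result stays as computed."""
    pad_width = [(0, int(max(r - s, 0))) for r, s in zip(roi_size, spatial_shape)]
    return tuple([(0, 0)] + pad_width)

pads = compute_pad_width_tuple([2, 3], [4, 4])
print(pads)  # ((0, 0), (0, 2), (0, 1))
# Appending to `pads` now raises AttributeError instead of silently changing
# shared state.
```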

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
wyli authored Mar 13, 2023
1 parent f754928 commit a8302ec
Showing 11 changed files with 77 additions and 70 deletions.
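
The `ratios` fix in `RandCropByLabelClasses` targets the same bug class: adjusting a caller-supplied list in place leaks state back to the user and changes later calls. A hypothetical, simplified illustration (not MONAI's actual sampling code):

```python
import random

def pick_class_mutating(ratios, class_has_fg):
    for i, has_fg in enumerate(class_has_fg):
        if not has_fg:
            ratios[i] = 0  # BAD: rewrites the caller's own list
    return random.choices(range(len(ratios)), weights=ratios)[0]

def pick_class_safe(ratios, class_has_fg):
    working = [r if has_fg else 0 for r, has_fg in zip(ratios, class_has_fg)]  # copy
    return random.choices(range(len(working)), weights=working)[0]

user_ratios = [1, 2, 3]
pick_class_mutating(user_ratios, [True, False, True])
print(user_ratios)  # [1, 0, 3] -- the user's input was silently modified

user_ratios = [1, 2, 3]
pick_class_safe(user_ratios, [True, False, True])
print(user_ratios)  # [1, 2, 3] -- untouched
```
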
2 changes: 1 addition & 1 deletion monai/apps/detection/transforms/dictionary.py
@@ -1071,7 +1071,7 @@ def __init__(
if len(self.image_keys) != len(self.meta_keys):
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.image_keys))
self.centers: list[list[int]] | None = None
self.centers: tuple[tuple] | None = None
self.allow_smaller = allow_smaller

def generate_fg_center_boxes_np(self, boxes: NdarrayOrTensor, image_size: Sequence[int]) -> np.ndarray:
2 changes: 1 addition & 1 deletion monai/apps/detection/utils/detector_utils.py
@@ -149,7 +149,7 @@ def pad_images(
max_spatial_size = compute_divisible_spatial_size(spatial_shape=list(max_spatial_size_t), k=size_divisible)

# allocate memory for the padded images
images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size, dtype=dtype, device=device)
images = torch.zeros([len(image_sizes), in_channels] + list(max_spatial_size), dtype=dtype, device=device)

# Use `SpatialPad` to match sizes, padding in the end will not affect boxes
padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs)
75 changes: 40 additions & 35 deletions monai/transforms/croppad/array.py
@@ -106,13 +106,13 @@ class Pad(InvertibleTransform, LazyTransform):
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, to_pad: list[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs
self, to_pad: tuple[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs
) -> None:
self.to_pad = to_pad
self.mode = mode
self.kwargs = kwargs

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
"""
dynamically compute the pad width according to the spatial shape.
the output is the amount of padding for all dimensions including the channel.
@@ -123,8 +123,8 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
"""
raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")

def __call__( # type: ignore
self, img: torch.Tensor, to_pad: list[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
def __call__( # type: ignore[override]
self, img: torch.Tensor, to_pad: tuple[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
) -> torch.Tensor:
"""
Args:
@@ -150,7 +150,7 @@ def __call__( # type: ignore
kwargs_.update(kwargs)

img_t = convert_to_tensor(data=img, track_meta=get_track_meta())
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) # type: ignore
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_)

def inverse(self, data: MetaTensor) -> MetaTensor:
transform = self.pop_transform(data)
@@ -200,7 +200,7 @@ def __init__(
self.method: Method = look_up_option(method, Method)
super().__init__(mode=mode, **kwargs)

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
"""
dynamically compute the pad width according to the spatial shape.
@@ -213,10 +213,10 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
pad_width = []
for i, sp_i in enumerate(spatial_size):
width = max(sp_i - spatial_shape[i], 0)
pad_width.append((width // 2, width - (width // 2)))
pad_width.append((int(width // 2), int(width - (width // 2))))
else:
pad_width = [(0, max(sp_i - spatial_shape[i], 0)) for i, sp_i in enumerate(spatial_size)]
return [(0, 0)] + pad_width
pad_width = [(0, int(max(sp_i - spatial_shape[i], 0))) for i, sp_i in enumerate(spatial_size)]
return tuple([(0, 0)] + pad_width) # type: ignore


class BorderPad(Pad):
@@ -249,24 +249,26 @@ def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMo
self.spatial_border = spatial_border
super().__init__(mode=mode, **kwargs)

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
spatial_border = ensure_tuple(self.spatial_border)
if not all(isinstance(b, int) for b in spatial_border):
raise ValueError(f"self.spatial_border must contain only ints, got {spatial_border}.")
spatial_border = tuple(max(0, b) for b in spatial_border)

if len(spatial_border) == 1:
data_pad_width = [(spatial_border[0], spatial_border[0]) for _ in spatial_shape]
data_pad_width = [(int(spatial_border[0]), int(spatial_border[0])) for _ in spatial_shape]
elif len(spatial_border) == len(spatial_shape):
data_pad_width = [(sp, sp) for sp in spatial_border[: len(spatial_shape)]]
data_pad_width = [(int(sp), int(sp)) for sp in spatial_border[: len(spatial_shape)]]
elif len(spatial_border) == len(spatial_shape) * 2:
data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(len(spatial_shape))]
data_pad_width = [
(int(spatial_border[2 * i]), int(spatial_border[2 * i + 1])) for i in range(len(spatial_shape))
]
else:
raise ValueError(
f"Unsupported spatial_border length: {len(spatial_border)}, available options are "
f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]."
)
return [(0, 0)] + data_pad_width
return tuple([(0, 0)] + data_pad_width) # type: ignore


class DivisiblePad(Pad):
@@ -301,7 +303,7 @@ def __init__(
self.method: Method = Method(method)
super().__init__(mode=mode, **kwargs)

def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]:
def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:
new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k)
spatial_pad = SpatialPad(spatial_size=new_size, method=self.method)
return spatial_pad.compute_pad_width(spatial_shape)
@@ -322,7 +324,7 @@ def compute_slices(
roi_start: Sequence[int] | NdarrayOrTensor | None = None,
roi_end: Sequence[int] | NdarrayOrTensor | None = None,
roi_slices: Sequence[slice] | None = None,
):
) -> tuple[slice]:
"""
Compute the crop slices based on specified `center & size` or `start & end` or `slices`.
@@ -340,8 +342,8 @@ def compute_slices(

if roi_slices:
if not all(s.step is None or s.step == 1 for s in roi_slices):
raise ValueError("only slice steps of 1/None are currently supported")
return list(roi_slices)
raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.")
return ensure_tuple(roi_slices) # type: ignore
else:
if roi_center is not None and roi_size is not None:
roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
@@ -363,11 +365,12 @@ def compute_slices(
roi_end_t = torch.maximum(roi_end_t, roi_start_t)
# convert to slices (accounting for 1d)
if roi_start_t.numel() == 1:
return [slice(int(roi_start_t.item()), int(roi_end_t.item()))]
else:
return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]
return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) # type: ignore
return ensure_tuple( # type: ignore
[slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]
)

def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore[override]
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
@@ -378,10 +381,10 @@ def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor
if len(slices_) < sd:
slices_ += [slice(None)] * (sd - len(slices_))
# Add in the channel (no cropping)
slices = tuple([slice(None)] + slices_[:sd])
slices_ = list([slice(None)] + slices_[:sd])

img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta())
return crop_func(img_t, slices, self.get_transform_info()) # type: ignore
return crop_func(img_t, tuple(slices_), self.get_transform_info())

def inverse(self, img: MetaTensor) -> MetaTensor:
transform = self.pop_transform(img)
@@ -429,13 +432,13 @@ def __init__(
roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices
)

def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
"""
return super().__call__(img=img, slices=self.slices)
return super().__call__(img=img, slices=ensure_tuple(self.slices))


class CenterSpatialCrop(Crop):
@@ -456,12 +459,12 @@ class CenterSpatialCrop(Crop):
def __init__(self, roi_size: Sequence[int] | int) -> None:
self.roi_size = roi_size

def compute_slices(self, spatial_size: Sequence[int]): # type: ignore
def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: ignore[override]
roi_size = fall_back_tuple(self.roi_size, spatial_size)
roi_center = [i // 2 for i in spatial_size]
return super().compute_slices(roi_center=roi_center, roi_size=roi_size)

def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
@@ -486,7 +489,7 @@ class CenterScaleCrop(Crop):
def __init__(self, roi_scale: Sequence[float] | float):
self.roi_scale = roi_scale

def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore
def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override]
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
ndim = len(img_size)
roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
@@ -771,7 +774,7 @@ def lazy_evaluation(self, _val: bool):
self._lazy_evaluation = _val
self.padder.lazy_evaluation = _val

def compute_bounding_box(self, img: torch.Tensor):
def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the start points and end points of bounding box to crop.
And adjust bounding box coords to be divisible by `k`.
@@ -794,7 +797,7 @@ def compute_bounding_box(self, img: torch.Tensor):

def crop_pad(
self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: str | None = None, **pad_kwargs
):
) -> torch.Tensor:
"""
Crop and pad based on the bounding box.
@@ -817,7 +820,9 @@ def crop_pad(
ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = ret.applied_operations.pop()
return ret

def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): # type: ignore
def __call__( # type: ignore[override]
self, img: torch.Tensor, mode: str | None = None, **pad_kwargs
) -> torch.Tensor:
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't change the channel dim.
@@ -826,7 +831,7 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): #
cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs)

if self.return_coords:
return cropped, box_start, box_end
return cropped, box_start, box_end # type: ignore[return-value]
return cropped

def inverse(self, img: MetaTensor) -> MetaTensor:
@@ -995,7 +1000,7 @@ def __init__(
self.num_samples = num_samples
self.image = image
self.image_threshold = image_threshold
self.centers: list[list[int]] | None = None
self.centers: tuple[tuple] | None = None
self.fg_indices = fg_indices
self.bg_indices = bg_indices
self.allow_smaller = allow_smaller
@@ -1173,7 +1178,7 @@ def __init__(
self.num_samples = num_samples
self.image = image
self.image_threshold = image_threshold
self.centers: list[list[int]] | None = None
self.centers: tuple[tuple] | None = None
self.indices = indices
self.allow_smaller = allow_smaller
self.warn = warn
4 changes: 2 additions & 2 deletions monai/transforms/croppad/dictionary.py
@@ -698,9 +698,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
self.cropper: CropForeground
box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key])
if self.start_coord_key is not None:
d[self.start_coord_key] = box_start
d[self.start_coord_key] = box_start # type: ignore
if self.end_coord_key is not None:
d[self.end_coord_key] = box_end
d[self.end_coord_key] = box_end # type: ignore
for key, m in self.key_iterator(d, self.mode):
d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m)
return d
28 changes: 15 additions & 13 deletions monai/transforms/croppad/functional.py
@@ -147,7 +147,9 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
return img


def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transform_info: dict, kwargs):
def pad_func(
img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs
) -> torch.Tensor:
"""
Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
@@ -166,17 +168,17 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
"""
extra_info = {"padded": to_pad, "mode": str(mode)}
extra_info = {"padded": to_pad, "mode": f"{mode}"}
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3
do_pad = np.asarray(to_pad).any()
if do_pad:
to_pad = list(to_pad)
if len(to_pad) < len(img.shape):
to_pad = list(to_pad) + [(0, 0)] * (len(img.shape) - len(to_pad))
to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad
to_pad_list = [(int(p[0]), int(p[1])) for p in to_pad]
if len(to_pad_list) < len(img.shape):
to_pad_list += [(0, 0)] * (len(img.shape) - len(to_pad_list))
to_shift = [-s[0] for s in to_pad_list[1:]] # skipping the channel pad
xform = create_translate(spatial_rank, to_shift)
shape = [d + s + e for d, (s, e) in zip(img_size, to_pad[1:])]
shape = [d + s + e for d, (s, e) in zip(img_size, to_pad_list[1:])]
else:
shape = img_size
xform = torch.eye(int(spatial_rank) + 1, device=torch.device("cpu"), dtype=torch.float64)
@@ -191,13 +193,13 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf
)
out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
out = pad_nd(out, to_pad, mode, **kwargs) if do_pad else out
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out
out = convert_to_tensor(out, track_meta=get_track_meta())
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore


def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict):
def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict) -> torch.Tensor:
"""
Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
@@ -229,6 +231,6 @@ def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict
)
out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
out = out[slices]
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore
2 changes: 1 addition & 1 deletion monai/transforms/inverse.py
@@ -190,7 +190,7 @@ def track_transform_meta(
return data
return out_obj # return with data_t as tensor if get_track_meta() is False

info = transform_info
info = transform_info.copy()
# track the current spatial shape
if orig_size is not None:
info[TraceKeys.ORIG_SIZE] = orig_size
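
The `monai/transforms/inverse.py` hunk above switches to `transform_info.copy()` before per-call keys such as `TraceKeys.ORIG_SIZE` are written into it. A minimal, simplified sketch of why the copy matters (hypothetical dict contents, not the actual tracking code):

```python
# Sketch only: writing per-call keys into a dict shared across calls leaks
# state between samples; copying first keeps the shared dict pristine.

shared_info = {"class": "Crop", "lazy": False}

def track_without_copy(orig_size):
    info = shared_info              # same dict object on every call
    info["orig_size"] = orig_size   # pollutes the shared state
    return info

def track_with_copy(orig_size):
    info = shared_info.copy()       # each call works on its own dict
    info["orig_size"] = orig_size
    return info

track_without_copy((64, 64))
print(shared_info)  # {'class': 'Crop', 'lazy': False, 'orig_size': (64, 64)}

shared_info = {"class": "Crop", "lazy": False}  # reset
track_with_copy((32, 32))
print(shared_info)  # {'class': 'Crop', 'lazy': False}
```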