From a8302eca3fc31554d4bd30fae84150766ce13472 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Mon, 13 Mar 2023 08:56:36 +0000 Subject: [PATCH] 6109 no mutate ratio /user inputs croppad (#6127) Fixes #6109 ### Description - use tuples for user inputs to avoid changes - enhance the type checks - fixes issue of `ratios` in `RandCropByLabelClasses ` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li --- monai/apps/detection/transforms/dictionary.py | 2 +- monai/apps/detection/utils/detector_utils.py | 2 +- monai/transforms/croppad/array.py | 75 ++++++++++--------- monai/transforms/croppad/dictionary.py | 4 +- monai/transforms/croppad/functional.py | 28 +++---- monai/transforms/inverse.py | 2 +- monai/transforms/utils.py | 18 ++--- ...est_generate_label_classes_crop_centers.py | 4 +- ...est_generate_pos_neg_label_crop_centers.py | 4 +- tests/test_rand_crop_by_label_classes.py | 4 +- tests/test_rand_crop_by_label_classesd.py | 4 +- 11 files changed, 77 insertions(+), 70 deletions(-) diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index 0116044f22..f77c5f4c48 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -1071,7 +1071,7 @@ def __init__( if len(self.image_keys) != len(self.meta_keys): raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.image_keys)) - self.centers: list[list[int]] | None = None + self.centers: tuple[tuple] | None = None self.allow_smaller = allow_smaller def generate_fg_center_boxes_np(self, boxes: NdarrayOrTensor, image_size: Sequence[int]) -> np.ndarray: diff --git a/monai/apps/detection/utils/detector_utils.py b/monai/apps/detection/utils/detector_utils.py index 7938e1e908..493ce5b216 100644 --- a/monai/apps/detection/utils/detector_utils.py +++ b/monai/apps/detection/utils/detector_utils.py @@ -149,7 +149,7 @@ def pad_images( max_spatial_size = compute_divisible_spatial_size(spatial_shape=list(max_spatial_size_t), k=size_divisible) # allocate memory for the padded images - images = torch.zeros([len(image_sizes), in_channels] + max_spatial_size, dtype=dtype, device=device) + images = torch.zeros([len(image_sizes), in_channels] + list(max_spatial_size), dtype=dtype, device=device) # Use `SpatialPad` to match sizes, padding in the end will not affect boxes padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 318110848b..aa13d54c51 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -106,13 +106,13 @@ class Pad(InvertibleTransform, LazyTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( - self, to_pad: list[tuple[int, int]] | None = None, mode: str = PytorchPadMode.CONSTANT, **kwargs + self, to_pad: tuple[tuple[int, int]] | None 
= None, mode: str = PytorchPadMode.CONSTANT, **kwargs ) -> None: self.to_pad = to_pad self.mode = mode self.kwargs = kwargs - def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: """ dynamically compute the pad width according to the spatial shape. the output is the amount of padding for all dimensions including the channel. @@ -123,8 +123,8 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int """ raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.") - def __call__( # type: ignore - self, img: torch.Tensor, to_pad: list[tuple[int, int]] | None = None, mode: str | None = None, **kwargs + def __call__( # type: ignore[override] + self, img: torch.Tensor, to_pad: tuple[tuple[int, int]] | None = None, mode: str | None = None, **kwargs ) -> torch.Tensor: """ Args: @@ -150,7 +150,7 @@ def __call__( # type: ignore kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) # type: ignore + return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) @@ -200,7 +200,7 @@ def __init__( self.method: Method = look_up_option(method, Method) super().__init__(mode=mode, **kwargs) - def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: """ dynamically compute the pad width according to the spatial shape. @@ -213,10 +213,10 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int pad_width = [] for i, sp_i in enumerate(spatial_size): width = max(sp_i - spatial_shape[i], 0) - pad_width.append((width // 2, width - (width // 2))) + pad_width.append((int(width // 2), int(width - (width // 2)))) else: - pad_width = [(0, max(sp_i - spatial_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - return [(0, 0)] + pad_width + pad_width = [(0, int(max(sp_i - spatial_shape[i], 0))) for i, sp_i in enumerate(spatial_size)] + return tuple([(0, 0)] + pad_width) # type: ignore class BorderPad(Pad): @@ -249,24 +249,26 @@ def __init__(self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMo self.spatial_border = spatial_border super().__init__(mode=mode, **kwargs) - def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: spatial_border = ensure_tuple(self.spatial_border) if not all(isinstance(b, int) for b in spatial_border): raise ValueError(f"self.spatial_border must contain only ints, got {spatial_border}.") spatial_border = tuple(max(0, b) for b in spatial_border) if len(spatial_border) == 1: - data_pad_width = [(spatial_border[0], spatial_border[0]) for _ in spatial_shape] + data_pad_width = [(int(spatial_border[0]), int(spatial_border[0])) for _ in spatial_shape] elif len(spatial_border) == len(spatial_shape): - data_pad_width = [(sp, sp) for sp in spatial_border[: len(spatial_shape)]] + data_pad_width = [(int(sp), int(sp)) for sp in spatial_border[: len(spatial_shape)]] elif len(spatial_border) == len(spatial_shape) * 2: - data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(len(spatial_shape))] + data_pad_width = [ + 
(int(spatial_border[2 * i]), int(spatial_border[2 * i + 1])) for i in range(len(spatial_shape)) + ] else: raise ValueError( f"Unsupported spatial_border length: {len(spatial_border)}, available options are " f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." ) - return [(0, 0)] + data_pad_width + return tuple([(0, 0)] + data_pad_width) # type: ignore class DivisiblePad(Pad): @@ -301,7 +303,7 @@ def __init__( self.method: Method = Method(method) super().__init__(mode=mode, **kwargs) - def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int]]: + def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]: new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k) spatial_pad = SpatialPad(spatial_size=new_size, method=self.method) return spatial_pad.compute_pad_width(spatial_shape) @@ -322,7 +324,7 @@ def compute_slices( roi_start: Sequence[int] | NdarrayOrTensor | None = None, roi_end: Sequence[int] | NdarrayOrTensor | None = None, roi_slices: Sequence[slice] | None = None, - ): + ) -> tuple[slice]: """ Compute the crop slices based on specified `center & size` or `start & end` or `slices`. @@ -340,8 +342,8 @@ def compute_slices( if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): - raise ValueError("only slice steps of 1/None are currently supported") - return list(roi_slices) + raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.") + return ensure_tuple(roi_slices) # type: ignore else: if roi_center is not None and roi_size is not None: roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu") @@ -363,11 +365,12 @@ def compute_slices( roi_end_t = torch.maximum(roi_end_t, roi_start_t) # convert to slices (accounting for 1d) if roi_start_t.numel() == 1: - return [slice(int(roi_start_t.item()), int(roi_end_t.item()))] - else: - return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] + return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) # type: ignore + return ensure_tuple( # type: ignore + [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] + ) - def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
@@ -378,10 +381,10 @@ def __call__(self, img: torch.Tensor, slices: tuple[slice, ...]) -> torch.Tensor if len(slices_) < sd: slices_ += [slice(None)] * (sd - len(slices_)) # Add in the channel (no cropping) - slices = tuple([slice(None)] + slices_[:sd]) + slices_ = list([slice(None)] + slices_[:sd]) img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) - return crop_func(img_t, slices, self.get_transform_info()) # type: ignore + return crop_func(img_t, tuple(slices_), self.get_transform_info()) def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.pop_transform(img) @@ -429,13 +432,13 @@ def __init__( roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices ) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. """ - return super().__call__(img=img, slices=self.slices) + return super().__call__(img=img, slices=ensure_tuple(self.slices)) class CenterSpatialCrop(Crop): @@ -456,12 +459,12 @@ class CenterSpatialCrop(Crop): def __init__(self, roi_size: Sequence[int] | int) -> None: self.roi_size = roi_size - def compute_slices(self, spatial_size: Sequence[int]): # type: ignore + def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]: # type: ignore[override] roi_size = fall_back_tuple(self.roi_size, spatial_size) roi_center = [i // 2 for i in spatial_size] return super().compute_slices(roi_center=roi_center, roi_size=roi_size) - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -486,7 +489,7 @@ class CenterScaleCrop(Crop): def __init__(self, roi_scale: Sequence[float] | float): self.roi_scale = roi_scale - def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore + def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore[override] img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -771,7 +774,7 @@ def lazy_evaluation(self, _val: bool): self._lazy_evaluation = _val self.padder.lazy_evaluation = _val - def compute_bounding_box(self, img: torch.Tensor): + def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]: """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. @@ -794,7 +797,7 @@ def compute_bounding_box(self, img: torch.Tensor): def crop_pad( self, img: torch.Tensor, box_start: np.ndarray, box_end: np.ndarray, mode: str | None = None, **pad_kwargs - ): + ) -> torch.Tensor: """ Crop and pad based on the bounding box. @@ -817,7 +820,9 @@ def crop_pad( ret.applied_operations[-1][TraceKeys.EXTRA_INFO]["pad_info"] = ret.applied_operations.pop() return ret - def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): # type: ignore + def __call__( # type: ignore[override] + self, img: torch.Tensor, mode: str | None = None, **pad_kwargs + ) -> torch.Tensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. 
@@ -826,7 +831,7 @@ def __call__(self, img: torch.Tensor, mode: str | None = None, **pad_kwargs): # cropped = self.crop_pad(img, box_start, box_end, mode, **pad_kwargs) if self.return_coords: - return cropped, box_start, box_end + return cropped, box_start, box_end # type: ignore[return-value] return cropped def inverse(self, img: MetaTensor) -> MetaTensor: @@ -995,7 +1000,7 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: list[list[int]] | None = None + self.centers: tuple[tuple] | None = None self.fg_indices = fg_indices self.bg_indices = bg_indices self.allow_smaller = allow_smaller @@ -1173,7 +1178,7 @@ def __init__( self.num_samples = num_samples self.image = image self.image_threshold = image_threshold - self.centers: list[list[int]] | None = None + self.centers: tuple[tuple] | None = None self.indices = indices self.allow_smaller = allow_smaller self.warn = warn diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 1aa710d018..ab4ce28941 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -698,9 +698,9 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc self.cropper: CropForeground box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) if self.start_coord_key is not None: - d[self.start_coord_key] = box_start + d[self.start_coord_key] = box_start # type: ignore if self.end_coord_key is not None: - d[self.end_coord_key] = box_end + d[self.end_coord_key] = box_end # type: ignore for key, m in self.key_iterator(d, self.mode): d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index adb11edea6..fa95958bd5 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -147,7 +147,9 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, return img -def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transform_info: dict, kwargs): +def pad_func( + img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs +) -> torch.Tensor: """ Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). @@ -166,17 +168,17 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. 
""" - extra_info = {"padded": to_pad, "mode": str(mode)} + extra_info = {"padded": to_pad, "mode": f"{mode}"} img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3 do_pad = np.asarray(to_pad).any() if do_pad: - to_pad = list(to_pad) - if len(to_pad) < len(img.shape): - to_pad = list(to_pad) + [(0, 0)] * (len(img.shape) - len(to_pad)) - to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad + to_pad_list = [(int(p[0]), int(p[1])) for p in to_pad] + if len(to_pad_list) < len(img.shape): + to_pad_list += [(0, 0)] * (len(img.shape) - len(to_pad_list)) + to_shift = [-s[0] for s in to_pad_list[1:]] # skipping the channel pad xform = create_translate(spatial_rank, to_shift) - shape = [d + s + e for d, (s, e) in zip(img_size, to_pad[1:])] + shape = [d + s + e for d, (s, e) in zip(img_size, to_pad_list[1:])] else: shape = img_size xform = torch.eye(int(spatial_rank) + 1, device=torch.device("cpu"), dtype=torch.float64) @@ -191,13 +193,13 @@ def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transf ) out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta()) if transform_info.get(TraceKeys.LAZY_EVALUATION, False): - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info - out = pad_nd(out, to_pad, mode, **kwargs) if do_pad else out + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore + out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out out = convert_to_tensor(out, track_meta=get_track_meta()) - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore -def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict): +def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict) -> torch.Tensor: """ Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). 
@@ -229,6 +231,6 @@ def crop_func(img: torch.Tensor, slices: tuple[slice, ...], transform_info: dict ) out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta()) if transform_info.get(TraceKeys.LAZY_EVALUATION, False): - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore out = out[slices] - return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index ab74a9813b..7ac4e572d9 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -190,7 +190,7 @@ def track_transform_meta( return data return out_obj # return with data_t as tensor if get_track_meta() is False - info = transform_info + info = transform_info.copy() # track the current spatial shape if orig_size is not None: info[TraceKeys.ORIG_SIZE] = orig_size diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 111bfa1102..6db45a6fae 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -409,7 +409,7 @@ def weighted_patch_samples( s = tuple(slice(w // 2, m - w + w // 2) if m > w else slice(m // 2, m // 2 + 1) for w, m in zip(win_size, img_size)) v = w[s] # weight map in the 'valid' mode v_size = v.shape - v = ravel(v) + v = ravel(v) # always copy if (v < 0).any(): v -= v.min() # shifting to non-negative v = cumsum(v) @@ -430,7 +430,7 @@ def correct_crop_centers( spatial_size: Sequence[int] | int, label_spatial_shape: Sequence[int], allow_smaller: bool = False, -): +) -> tuple[Any]: """ Utility to correct the crop center if the crop size and centers are not compatible with the image size. @@ -466,7 +466,7 @@ def correct_crop_centers( for c, v_s, v_e in zip(centers, valid_start, valid_end): center_i = min(max(c, v_s), v_e - 1) valid_centers.append(int(center_i)) - return valid_centers + return ensure_tuple(valid_centers) # type: ignore def generate_pos_neg_label_crop_centers( @@ -478,7 +478,7 @@ def generate_pos_neg_label_crop_centers( bg_indices: NdarrayOrTensor, rand_state: np.random.RandomState | None = None, allow_smaller: bool = False, -) -> list[list[int]]: +) -> tuple[tuple]: """ Generate valid sample locations based on the label with option for specifying foreground ratio Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -524,7 +524,7 @@ def generate_pos_neg_label_crop_centers( # shift center to range of valid centers centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) - return centers + return ensure_tuple(centers) # type: ignore def generate_label_classes_crop_centers( @@ -536,7 +536,7 @@ def generate_label_classes_crop_centers( rand_state: np.random.RandomState | None = None, allow_smaller: bool = False, warn: bool = True, -) -> list[list[int]]: +) -> tuple[tuple]: """ Generate valid sample locations based on the specified ratios of label classes. 
Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -560,7 +560,7 @@ def generate_label_classes_crop_centers( if num_samples < 1: raise ValueError(f"num_samples must be an int number and greater than 0, got {num_samples}.") - ratios_: list[float | int] = ([1] * len(indices)) if ratios is None else ratios + ratios_: list[float | int] = list(ensure_tuple([1] * len(indices) if ratios is None else ratios)) if len(ratios_) != len(indices): raise ValueError( f"random crop ratios must match the number of indices of classes, got {len(ratios_)} and {len(indices)}." @@ -584,7 +584,7 @@ def generate_label_classes_crop_centers( # shift center to range of valid centers centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) - return centers + return ensure_tuple(centers) # type: ignore def create_grid( @@ -1397,7 +1397,7 @@ def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Sequence[int new_dim = int(np.ceil(dim / k_d) * k_d) if k_d > 0 else dim new_size.append(new_dim) - return new_size + return tuple(new_size) def equalize_hist( diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py index 5a4a7140a3..c276171bd5 100644 --- a/tests/test_generate_label_classes_crop_centers.py +++ b/tests/test_generate_label_classes_crop_centers.py @@ -28,7 +28,7 @@ "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], }, - list, + tuple, 2, 3, ] @@ -41,7 +41,7 @@ "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], }, - list, + tuple, 1, 3, ] diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 8c1729ef29..13b7b728b4 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -30,7 +30,7 @@ "fg_indices": [1, 9, 18], "bg_indices": [3, 12, 21], }, - list, + tuple, 2, 3, ], @@ -43,7 +43,7 @@ "fg_indices": [], "bg_indices": [3, 12, 21], }, - list, + tuple, 2, 3, ], diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index 20a3876ed0..f31aa947bb 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -109,7 +109,7 @@ "label": None, "num_classes": 2, "spatial_size": [4, 4, 4], - "ratios": [1, 1], + "ratios": (1, 1), # test no assignment "num_samples": 2, "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), "image_threshold": 0, @@ -117,7 +117,7 @@ }, { "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), - "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "label": p(np.random.randint(0, 1, size=[1, 3, 3, 3])), "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), }, list, diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index bcd5577e16..8061f03dff 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -77,7 +77,7 @@ "label_key": "label", "num_classes": 2, "spatial_size": [4, 4, 2], - "ratios": [1, 1], + "ratios": (1, 1), # test no assignment "num_samples": 2, "image_key": "image", "image_threshold": 0, @@ -86,7 +86,7 @@ { "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), - "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "label": p(np.random.randint(0, 1, size=[1, 3, 3, 3])), }, list, (3, 3, 3, 2),
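
---

The core of this patch is avoiding in-place mutation of caller-supplied sequences: user inputs such as `ratios` are copied (e.g. via `list(ensure_tuple(...))`) before being modified, and computed results such as crop centers and pad widths are returned as tuples so they cannot be changed downstream. The following standalone sketch is illustrative only (it is not MONAI code and the function names are made up); it reproduces the failure mode the fix guards against and the copy-then-return-tuple pattern the patch adopts:

```python
from collections.abc import Sequence


def pick_ratios_buggy(ratios: Sequence[float] | None, indices: Sequence[Sequence[int]]) -> list[float]:
    """Mutates the caller's ``ratios`` when a class has no indices (the reported bug)."""
    ratios_ = ratios if ratios is not None else [1] * len(indices)  # no copy made
    for i, idx in enumerate(indices):
        if len(idx) == 0:
            ratios_[i] = 0  # silently changes the user's list in place
    return ratios_


def pick_ratios_fixed(ratios: Sequence[float] | None, indices: Sequence[Sequence[int]]) -> tuple[float, ...]:
    """Copies the input first, so repeated calls always see the original ratios."""
    ratios_ = list(ratios) if ratios is not None else [1] * len(indices)  # defensive copy
    for i, idx in enumerate(indices):
        if len(idx) == 0:
            ratios_[i] = 0  # only the local copy is modified
    return tuple(ratios_)  # immutable result cannot be mutated by later transforms


user_ratios = [1, 1]
pick_ratios_buggy(user_ratios, [[3, 12], []])
print(user_ratios)  # [1, 0] -- user input changed as a side effect

user_ratios = [1, 1]
pick_ratios_fixed(user_ratios, [[3, 12], []])
print(user_ratios)  # [1, 1] -- unchanged across repeated random crops
```

The same reasoning explains the annotation changes throughout the diff (`list[tuple[int, int]]` to `tuple[tuple[int, int]]`, `list[list[int]]` to `tuple[tuple]`) and the `transform_info.copy()` in `inverse.py`: returning or storing immutable/copied objects keeps repeated or randomized calls from observing state leaked by a previous call.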