From 3f91409efeda5c6d1713953040f219bfb9efc638 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 6 Aug 2024 14:48:44 +0800 Subject: [PATCH 01/19] update mlp block Signed-off-by: Yiheng Wang --- monai/networks/blocks/mlp.py | 14 ++++++++++---- tests/test_mlp.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index d3510b64d3..5d207920b1 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -14,9 +14,10 @@ import torch.nn as nn from monai.networks.layers import get_act_layer +from monai.networks.layers.factories import split_args from monai.utils import look_up_option -SUPPORTED_DROPOUT_MODE = {"vit", "swin"} +SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"} class MLPBlock(nn.Module): @@ -39,7 +40,7 @@ def __init__( https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 "swin" corresponds to one instance as implemented in https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23 - + "vista3d" mode does not use dropout. """ @@ -48,15 +49,20 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") mlp_dim = mlp_dim or hidden_size - self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) + act_name, _ = split_args(act) + self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) - self.drop1 = nn.Dropout(dropout_rate) dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) if dropout_opt == "vit": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = nn.Dropout(dropout_rate) elif dropout_opt == "swin": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = self.drop1 + elif dropout_opt == "vista3d": + self.drop1 = nn.Identity() + self.drop2 = nn.Identity() else: raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}") diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 54f70d3318..af6eb5b6b8 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -15,10 +15,12 @@ import numpy as np import torch +import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode from monai.networks.blocks.mlp import MLPBlock +from monai.networks.layers.factories import split_args TEST_CASE_MLP = [] for dropout_rate in np.linspace(0, 1, 4): @@ -31,6 +33,14 @@ ] TEST_CASE_MLP.append(test_case) +# test different activation layers +TEST_CASE_ACT = [] +for act in ["GELU", "GEGLU", ("GELU", {"approximate": "tanh"}), ("GEGLU", {})]: + TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) + +# test different dropout modes +TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]] + class TestMLPBlock(unittest.TestCase): @@ -45,6 +55,24 @@ def test_ill_arg(self): with self.assertRaises(ValueError): MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0) + @parameterized.expand(TEST_CASE_ACT) + def test_act(self, input_param, input_shape, expected_shape): + net = MLPBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + act_name, _ = split_args(input_param["act"]) + if act_name == "GEGLU": + self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2) + else: + self.assertEqual(net.linear1.in_features, net.linear1.out_features) + + @parameterized.expand(TEST_CASE_DROP) + def test_dropout_mode(self, dropout_mode, dropout_layer): + net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode) + self.assertTrue(isinstance(net.drop1, dropout_layer)) + self.assertTrue(isinstance(net.drop2, dropout_layer)) + if __name__ == "__main__": unittest.main() From d241db5af6604a4b56029f43831cc10cc67ad2aa Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 6 Aug 2024 15:38:46 +0800 Subject: [PATCH 02/19] add mypy fix Signed-off-by: Yiheng Wang --- monai/networks/blocks/mlp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index 5d207920b1..8771711d25 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import Union + import torch.nn as nn from monai.networks.layers import get_act_layer @@ -53,6 +55,10 @@ def __init__( self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) + # Use Union[nn.Dropout, nn.Identity] for type annotations + self.drop1: Union[nn.Dropout, nn.Identity] + self.drop2: Union[nn.Dropout, nn.Identity] + dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) if dropout_opt == "vit": self.drop1 = nn.Dropout(dropout_rate) From 1af03f6468de122bc32b69794d4b8f0c60b64460 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 7 Aug 2024 11:52:14 +0800 Subject: [PATCH 03/19] remove gelu approximate Signed-off-by: Yiheng Wang --- tests/test_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index af6eb5b6b8..a34a3b89ac 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -35,7 +35,7 @@ # test different activation layers TEST_CASE_ACT = [] -for act in ["GELU", "GEGLU", ("GELU", {"approximate": "tanh"}), ("GEGLU", {})]: +for act in ["GELU", "GEGLU", ("GEGLU", {})]: TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) # test different dropout modes From 4dc89e51d3ab700897e8c39fa409e7c260be8d29 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 13:15:24 +0800 Subject: [PATCH 04/19] free space Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/pythonapp.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index fe04f96a80..7040e32f14 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -101,6 +101,7 @@ jobs: python -m pip install --pre -U itk - name: Install the dependencies run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --user --upgrade pip wheel python -m pip install torch==1.13.1 torchvision==0.14.1 cat "requirements-dev.txt" From b5400b9ca4796bed2b84ca1fb07dfd7c69c02393 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 7 Aug 2024 13:28:15 +0800 Subject: [PATCH 05/19] ignore test case type annotation Signed-off-by: Yiheng Wang --- tests/test_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index a34a3b89ac..2598d8877d 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -35,7 +35,7 @@ # test different activation layers TEST_CASE_ACT = [] -for act in ["GELU", "GEGLU", ("GEGLU", {})]: +for act in ["GELU", "GEGLU", ("GEGLU", {})]: # type: ignore TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) # test different dropout modes From c53f038f66111d7763245d3c76bdcab6a5ed747b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 7 Aug 2024 13:41:19 +0800 Subject: [PATCH 06/19] try to fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/pythonapp.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 7040e32f14..65f9a4dcf2 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -99,9 +99,9 @@ jobs: name: Install itk pre-release (Linux only) run: | python -m pip install --pre -U itk + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; - name: Install the dependencies run: | - find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; python -m pip install --user --upgrade pip wheel python -m pip install torch==1.13.1 torchvision==0.14.1 cat "requirements-dev.txt" From 90c7888c4251e229cda0387755e9d7ca1086183f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 7 Aug 2024 14:30:51 +0800 Subject: [PATCH 07/19] Add utils for vista3d Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 177 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index d8461d927b..a51be1b815 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,6 +22,7 @@ import numpy as np import torch +import torch.nn.functional as F import monai from monai.config import DtypeLike, IndexSelection @@ -65,6 +66,8 @@ min_version, optional_import, pytorch_after, + unsqueeze_right, + unsqueeze_left ) from monai.utils.enums import TransformBackends from monai.utils.type_conversion import ( @@ -103,6 +106,10 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", + "get_largest_connected_component_mask_point", + "sample_points_from_label", + "erode3d", + "sample" "remove_small_objects", "img_bounds", "in_bounds", @@ -1171,6 +1178,176 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] +def get_largest_connected_component_mask_point( + img_pos: NdarrayTensor, + img_neg: NdarrayTensor, + pos_val: list=[1, 3], + neg_val: list=[0, 2], + point_coords: None = None, + point_labels: None = None, + margins: int = 3, +) -> NdarrayTensor: + """ + Gets the largest connected component mask of an image that include the point_coords. + Args: + img_pos: [1, B, H, W, D] + point_coords [B, N, 3] + point_labels [B, N] + """ + + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) + label = measure.label + lib = np + + features_pos, num_features = label(img_pos_, connectivity=3, return_num=True) + features_neg, num_features = label(img_neg_, connectivity=3, return_num=True) + + outs = np.zeros_like(img_pos_) + for bs in range(point_coords.shape[0]): + for i, p in enumerate(point_coords[bs]): + if point_labels[bs, i] in pos_val: + features = features_pos + elif point_labels[bs, i] in neg_val: + features = features_neg + else: + # if -1 padding point, skip + continue + for margin in range(margins): + x, y, z = p.round().int().tolist() + l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) + t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) + f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) + if (features[bs, 0, l:r, t:d, f:b] > 0).any(): + index = features[bs, 0, l:r, t:d, f:b].max() + outs[[bs]] += lib.isin(features[[bs]], index) + break + outs[outs > 1] = 1 + return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] + +def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False): + """ + Convert a 3D point coordinates into image mask. The returned mask has the same spatial + size as `image_size` while the batch dimension is the same as point' batch dimension. + The point is converted to a mask ball with radius defined by `radius`. The output + contains two channels each for negative (first channel) and positive points. + Args: + image_size: The output size of th + point: [b, N, 3] + point_label: [b, N], 0 or 2 means negative points, 1 or 3 means postive points. + radius: disc ball radius size + disc: If true, use regular disc other other use gaussian. + """ + if not torch.is_tensor(point): + point = torch.from_numpy(point) + masks = torch.zeros( + [point.shape[0], 2, image_size[0], image_size[1], image_size[2]], + device=point.device, + ) + _array = [torch.arange( + start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device + ) for i in range(3)] + coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0]) + # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] + coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) + coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) + for b in range(point.shape[0]): + for n in range(point.shape[1]): + point_bn = unsqueeze_right(point[b, n], 6) + if point_label[b, n] > -1: + channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 + pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) + if disc: + masks[b, channel] += pow_diff.sum(0) < radius**2 + else: + masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) + return masks + +def sample_points_from_label( + labels, label_set=None, max_ppoint=1, max_npoint=0, device="cpu", use_center=False +): + """Sample points from labels. + Args: + labels: [1, 1, H, W, D] + label_set: local index, must match values in labels. + max_ppoint: maximum positive point samples. + max_npoint: maximum negative point samples. + device: returned tensor device. + use_center: whether to sample points from center. + Returns: + point: point coordinates of [B, N, 3]. + point_label: [B, N], always 0 for negative, 1 for positive. + """ + assert labels.shape[0] == 1, "only support batch size 1" + labels = labels[0, 0] + unique_labels = labels.unique().cpu().numpy().tolist() + _point = [] + _point_label = [] + for id in label_set: + if id in unique_labels: + plabels = labels == int(id) + nlabels = ~plabels + _plabels = get_largest_connected_component_mask(erode3d(plabels)) + plabelpoints = torch.nonzero(_plabels).to(device) + if len(plabelpoints) == 0: + plabelpoints = torch.nonzero(plabels).to(device) + nlabelpoints = torch.nonzero(nlabels).to(device) + Np = min(len(plabelpoints), max_ppoint) + Nn = min(len(nlabelpoints), max_npoint) + pad = max_ppoint + max_npoint - Np - Nn + if use_center: + pmean = plabelpoints.float().mean(0) + pdis = ((plabelpoints - pmean) ** 2).sum(-1) + _, sorted_indices = torch.sort(pdis) + else: + sorted_indices = list(range(len(plabelpoints))) + random.shuffle(sorted_indices) + _point.append( + torch.stack([plabelpoints[sorted_indices[i]] for i in Np] + + random.choices(nlabelpoints, k=Nn) + + [torch.tensor([0, 0, 0], device=device)] * pad + ) + ) + _point_label.append( + torch.tensor([1] * Np + [0] * Nn + [-1] * pad).to(device)) + else: + # pad the background labels + _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) + _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) + point = torch.stack(_point) + point_label = torch.stack(_point_label) + return point, point_label + +def erode3d(input_tensor, erosion=3): + # Define the structuring element + erosion = ensure_tuple_rep(erosion, 3) + structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to( + input_tensor.device + ) + + # Pad the input tensor to handle border pixels + input_padded = F.pad( + input_tensor.float().unsqueeze(0).unsqueeze(0), + ( + erosion[2] // 2, + erosion[2] // 2, + erosion[1] // 2, + erosion[1] // 2, + erosion[0] // 2, + erosion[0] // 2, + ), + mode="constant", + value=1.0, + ) + + # Apply erosion operation + output = F.conv3d(input_padded, structuring_element, padding=0) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + def remove_small_objects( img: NdarrayTensor, From 12a1e66e5dcf674cdb0c68ff724818e238996f92 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 7 Aug 2024 15:52:51 +0800 Subject: [PATCH 08/19] add convert_points_to_disc only and fix type issue Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 164 ++++---------------------------------- 1 file changed, 14 insertions(+), 150 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index a51be1b815..2551f6897c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,7 +22,7 @@ import numpy as np import torch -import torch.nn.functional as F +from torch import Tensor import monai from monai.config import DtypeLike, IndexSelection @@ -66,8 +66,8 @@ min_version, optional_import, pytorch_after, + unsqueeze_left, unsqueeze_right, - unsqueeze_left ) from monai.utils.enums import TransformBackends from monai.utils.type_conversion import ( @@ -106,10 +106,7 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", - "get_largest_connected_component_mask_point", - "sample_points_from_label", - "erode3d", - "sample" + "convert_points_to_disc", "remove_small_objects", "img_bounds", "in_bounds", @@ -1178,75 +1175,27 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] -def get_largest_connected_component_mask_point( - img_pos: NdarrayTensor, - img_neg: NdarrayTensor, - pos_val: list=[1, 3], - neg_val: list=[0, 2], - point_coords: None = None, - point_labels: None = None, - margins: int = 3, -) -> NdarrayTensor: - """ - Gets the largest connected component mask of an image that include the point_coords. - Args: - img_pos: [1, B, H, W, D] - point_coords [B, N, 3] - point_labels [B, N] - """ - - img_pos_, *_ = convert_data_type(img_pos, np.ndarray) - img_neg_, *_ = convert_data_type(img_neg, np.ndarray) - label = measure.label - lib = np - - features_pos, num_features = label(img_pos_, connectivity=3, return_num=True) - features_neg, num_features = label(img_neg_, connectivity=3, return_num=True) - - outs = np.zeros_like(img_pos_) - for bs in range(point_coords.shape[0]): - for i, p in enumerate(point_coords[bs]): - if point_labels[bs, i] in pos_val: - features = features_pos - elif point_labels[bs, i] in neg_val: - features = features_neg - else: - # if -1 padding point, skip - continue - for margin in range(margins): - x, y, z = p.round().int().tolist() - l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) - t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) - f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) - if (features[bs, 0, l:r, t:d, f:b] > 0).any(): - index = features[bs, 0, l:r, t:d, f:b].max() - outs[[bs]] += lib.isin(features[[bs]], index) - break - outs[outs > 1] = 1 - return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] - -def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False): + +def convert_points_to_disc( + image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False +): """ Convert a 3D point coordinates into image mask. The returned mask has the same spatial - size as `image_size` while the batch dimension is the same as point' batch dimension. + size as `image_size` while the batch dimension is the same as 'point' batch dimension. The point is converted to a mask ball with radius defined by `radius`. The output contains two channels each for negative (first channel) and positive points. + Args: - image_size: The output size of th + image_size: The output size of the converted mask. It should be a point: [b, N, 3] point_label: [b, N], 0 or 2 means negative points, 1 or 3 means postive points. radius: disc ball radius size disc: If true, use regular disc other other use gaussian. """ - if not torch.is_tensor(point): - point = torch.from_numpy(point) - masks = torch.zeros( - [point.shape[0], 2, image_size[0], image_size[1], image_size[2]], - device=point.device, - ) - _array = [torch.arange( - start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device - ) for i in range(3)] + masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device) + _array = [ + torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3) + ] coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0]) # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) @@ -1263,91 +1212,6 @@ def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False) masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) return masks -def sample_points_from_label( - labels, label_set=None, max_ppoint=1, max_npoint=0, device="cpu", use_center=False -): - """Sample points from labels. - Args: - labels: [1, 1, H, W, D] - label_set: local index, must match values in labels. - max_ppoint: maximum positive point samples. - max_npoint: maximum negative point samples. - device: returned tensor device. - use_center: whether to sample points from center. - Returns: - point: point coordinates of [B, N, 3]. - point_label: [B, N], always 0 for negative, 1 for positive. - """ - assert labels.shape[0] == 1, "only support batch size 1" - labels = labels[0, 0] - unique_labels = labels.unique().cpu().numpy().tolist() - _point = [] - _point_label = [] - for id in label_set: - if id in unique_labels: - plabels = labels == int(id) - nlabels = ~plabels - _plabels = get_largest_connected_component_mask(erode3d(plabels)) - plabelpoints = torch.nonzero(_plabels).to(device) - if len(plabelpoints) == 0: - plabelpoints = torch.nonzero(plabels).to(device) - nlabelpoints = torch.nonzero(nlabels).to(device) - Np = min(len(plabelpoints), max_ppoint) - Nn = min(len(nlabelpoints), max_npoint) - pad = max_ppoint + max_npoint - Np - Nn - if use_center: - pmean = plabelpoints.float().mean(0) - pdis = ((plabelpoints - pmean) ** 2).sum(-1) - _, sorted_indices = torch.sort(pdis) - else: - sorted_indices = list(range(len(plabelpoints))) - random.shuffle(sorted_indices) - _point.append( - torch.stack([plabelpoints[sorted_indices[i]] for i in Np] - + random.choices(nlabelpoints, k=Nn) - + [torch.tensor([0, 0, 0], device=device)] * pad - ) - ) - _point_label.append( - torch.tensor([1] * Np + [0] * Nn + [-1] * pad).to(device)) - else: - # pad the background labels - _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) - _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) - point = torch.stack(_point) - point_label = torch.stack(_point_label) - return point, point_label - -def erode3d(input_tensor, erosion=3): - # Define the structuring element - erosion = ensure_tuple_rep(erosion, 3) - structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to( - input_tensor.device - ) - - # Pad the input tensor to handle border pixels - input_padded = F.pad( - input_tensor.float().unsqueeze(0).unsqueeze(0), - ( - erosion[2] // 2, - erosion[2] // 2, - erosion[1] // 2, - erosion[1] // 2, - erosion[0] // 2, - erosion[0] // 2, - ), - mode="constant", - value=1.0, - ) - - # Apply erosion operation - output = F.conv3d(input_padded, structuring_element, padding=0) - - # Set output values based on the minimum value within the structuring element - output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) - - return output.squeeze(0).squeeze(0) - def remove_small_objects( img: NdarrayTensor, From 681e875627b20bd0aa3087a70dbae03253b3cb3e Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Thu, 8 Aug 2024 10:25:23 +0800 Subject: [PATCH 09/19] Update monai/transforms/utils.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/transforms/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 2551f6897c..3ff9a6894e 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1200,8 +1200,7 @@ def convert_points_to_disc( # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) - for b in range(point.shape[0]): - for n in range(point.shape[1]): + for b, n in np.ndindex(*point.shape[:2]): point_bn = unsqueeze_right(point[b, n], 6) if point_label[b, n] > -1: channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 From e0695520ac09c786725c3bd889256ba17d867e2b Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 8 Aug 2024 14:01:03 +0800 Subject: [PATCH 10/19] update more functions and change morphological path Signed-off-by: Yiheng Wang --- monai/transforms/__init__.py | 1 + monai/transforms/utils.py | 75 ++++++++++++++++++- .../utils_morphological_ops.py} | 2 + tests/test_morphological_ops.py | 2 +- tests/test_vista3d_utils.py | 64 ++++++++++++++++ 5 files changed, 139 insertions(+), 5 deletions(-) rename monai/{apps/generation/maisi/utils/morphological_ops.py => transforms/utils_morphological_ops.py} (99%) create mode 100644 tests/test_vista3d_utils.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ef1da2d855..9548443768 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -688,6 +688,7 @@ weighted_patch_samples, zero_margins, ) +from .utils_morphological_ops import dilate, erode from .utils_pytorch_numpy_unification import ( allclose, any_np_pt, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 2551f6897c..8238c59b2d 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -31,6 +31,7 @@ from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform, Transform, apply_transform +from monai.transforms.utils_morphological_ops import erode from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, ascontiguousarray, @@ -1186,10 +1187,10 @@ def convert_points_to_disc( contains two channels each for negative (first channel) and positive points. Args: - image_size: The output size of the converted mask. It should be a - point: [b, N, 3] - point_label: [b, N], 0 or 2 means negative points, 1 or 3 means postive points. - radius: disc ball radius size + image_size: The output size of the converted mask. It should be a 3D tuple. + point: [B, N, 3], 3D point coordinates. + point_label: [B, N], 0 or 2 means negative points, 1 or 3 means postive points. + radius: disc ball radius size. disc: If true, use regular disc other other use gaussian. """ masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device) @@ -1213,6 +1214,72 @@ def convert_points_to_disc( return masks +def sample_points_from_label( + labels: Tensor, + label_set: Sequence[int], + max_ppoint: int = 1, + max_npoint: int = 0, + device: torch.device | str | None = "cpu", + use_center: bool = False, +): + """Sample points from labels. + + Args: + labels: [1, 1, H, W, D] + label_set: local index, must match values in labels. + max_ppoint: maximum positive point samples. + max_npoint: maximum negative point samples. + device: returned tensor device. + use_center: whether to sample points from center. + + Returns: + point: point coordinates of [B, N, 3]. B equals to the length of label_set. + point_label: [B, N], always 0 for negative, 1 for positive. + """ + if not labels.shape[0] == 1: + raise ValueError("labels must have batch size 1.") + + labels = labels[0, 0] + unique_labels = labels.unique().cpu().numpy().tolist() + _point = [] + _point_label = [] + for id in label_set: + if id in unique_labels: + plabels = labels == int(id) + nlabels = ~plabels + _plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0]) + plabelpoints = torch.nonzero(_plabels).to(device) + if len(plabelpoints) == 0: + plabelpoints = torch.nonzero(plabels).to(device) + nlabelpoints = torch.nonzero(nlabels).to(device) + num_p = min(len(plabelpoints), max_ppoint) + num_n = min(len(nlabelpoints), max_npoint) + pad = max_ppoint + max_npoint - num_p - num_n + if use_center: + pmean = plabelpoints.float().mean(0) + pdis = ((plabelpoints - pmean) ** 2).sum(-1) + _, sorted_indices = torch.sort(pdis) + else: + sorted_indices = list(range(len(plabelpoints))) + random.shuffle(sorted_indices) + _point.append( + torch.stack( + [plabelpoints[sorted_indices[i]] for i in range(num_p)] + + random.choices(nlabelpoints, k=num_n) + + [torch.tensor([0, 0, 0], device=device)] * pad + ) + ) + _point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device)) + else: + # pad the background labels + _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) + _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) + point = torch.stack(_point) + point_label = torch.stack(_point_label) + + return point, point_label + + def remove_small_objects( img: NdarrayTensor, min_size: int = 64, diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/transforms/utils_morphological_ops.py similarity index 99% rename from monai/apps/generation/maisi/utils/morphological_ops.py rename to monai/transforms/utils_morphological_ops.py index 14786d60a2..b3134c1865 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/transforms/utils_morphological_ops.py @@ -20,6 +20,8 @@ from monai.config import NdarrayOrTensor from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep +__all__ = ["erode", "dilate"] + def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: """ diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index 6f29415759..422e8c4b9d 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t +from monai.transforms.utils_morphological_ops import dilate, erode, get_morphological_filter_result_t from tests.utils import TEST_NDARRAYS, assert_allclose TESTS_SHAPE = [] diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py new file mode 100644 index 0000000000..9e3bcef539 --- /dev/null +++ b/tests/test_vista3d_utils.py @@ -0,0 +1,64 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.utils import convert_points_to_disc, sample_points_from_label + +TESTS_SAMPLE_POINTS_FROM_LABEL = [] +for use_center in [True, False]: + labels = torch.zeros(1, 1, 32, 32, 32) + labels[0, 0, 5:10, 5:10, 5:10] = 1 + labels[0, 0, 10:15, 10:15, 10:15] = 3 + labels[0, 0, 20:25, 20:25, 20:25] = 5 + TESTS_SAMPLE_POINTS_FROM_LABEL.append( + [{"labels": labels, "label_set": (1, 3, 5), "use_center": use_center}, (3, 1, 3), (3, 1)] + ) + +TEST_CONVERT_POINTS_TO_DISC = [] +for radius in [1, 2]: + for disc in [True, False]: + image_size = (32, 32, 32) + point = torch.randn(3, 1, 3) + point_label = torch.randint(0, 4, (3, 1)) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + + +class TestSamplePointsFromLabel(unittest.TestCase): + + @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL) + def test_shape(self, input_data, expected_point_shape, expected_point_label_shape): + point, point_label = sample_points_from_label(**input_data) + self.assertEqual(point.shape, expected_point_shape) + self.assertEqual(point_label.shape, expected_point_label_shape) + + +class TestConvertPointsToDisc(unittest.TestCase): + + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC) + def test_shape(self, input_data, expected_shape): + result = convert_points_to_disc(**input_data) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From dc004cbd947a452259eb2c88d868eb75bf30c442 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 8 Aug 2024 14:37:11 +0800 Subject: [PATCH 11/19] add get_largest_connected_component_mask_point Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 75 ++++++++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 298e566c45..3e65ceba3f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1177,6 +1177,62 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] +def get_largest_connected_component_mask_point( + img_pos: NdarrayTensor, + img_neg: NdarrayTensor, + point_coords: NdarrayTensor, + point_labels: NdarrayTensor, + pos_val: Sequence[int] = (1, 3), + neg_val: Sequence[int] = (0, 2), + margins: int = 3, +) -> NdarrayTensor: + """ + Gets the largest connected component mask of an image that include the point_coords. + # TODO: need author to provide more details about this function. Especially about each argument. + Args: + img_pos: [1, B, H, W, D] + img_neg: [1, B, H, W, D] + + point_coords [B, N, 3] + point_labels [B, N] + """ + if not has_measure: + raise RuntimeError("Skimage.measure required.") + + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) + label = measure.label + lib = np + + features_pos, _ = label(img_pos_, connectivity=3, return_num=True) + features_neg, _ = label(img_neg_, connectivity=3, return_num=True) + + pos_val = list(pos_val) + neg_val = list(neg_val) + + outs = np.zeros_like(img_pos_) + for bs in range(point_coords.shape[0]): + for i, p in enumerate(point_coords[bs]): + if point_labels[bs, i] in pos_val: + features = features_pos + elif point_labels[bs, i] in neg_val: + features = features_neg + else: + # if -1 padding point, skip + continue + for margin in range(margins): + x, y, z = p.round().int().tolist() + l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) + t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) + f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) + if (features[bs, 0, l:r, t:d, f:b] > 0).any(): + index = features[bs, 0, l:r, t:d, f:b].max() + outs[[bs]] += lib.isin(features[[bs]], index) + break + outs[outs > 1] = 1 + return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] + + def convert_points_to_disc( image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False ): @@ -1202,14 +1258,14 @@ def convert_points_to_disc( coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) for b, n in np.ndindex(*point.shape[:2]): - point_bn = unsqueeze_right(point[b, n], 6) - if point_label[b, n] > -1: - channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 - pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) - if disc: - masks[b, channel] += pow_diff.sum(0) < radius**2 - else: - masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) + point_bn = unsqueeze_right(point[b, n], 6) + if point_label[b, n] > -1: + channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 + pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) + if disc: + masks[b, channel] += pow_diff.sum(0) < radius**2 + else: + masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) return masks @@ -1257,7 +1313,8 @@ def sample_points_from_label( if use_center: pmean = plabelpoints.float().mean(0) pdis = ((plabelpoints - pmean) ** 2).sum(-1) - _, sorted_indices = torch.sort(pdis) + _, sorted_indices_tensor = torch.sort(pdis) + sorted_indices = sorted_indices_tensor.cpu().tolist() else: sorted_indices = list(range(len(plabelpoints))) random.shuffle(sorted_indices) From 503bdd7d88cbc24fe1fe34ebe09fb5c0e401c1d9 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 8 Aug 2024 17:07:22 +0800 Subject: [PATCH 12/19] add tests Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 34 +++++++++++++++++++--------------- tests/test_vista3d_utils.py | 28 +++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 3e65ceba3f..ec4bdc892b 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -107,6 +107,7 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", + "get_largest_connected_component_mask_point", "convert_points_to_disc", "remove_small_objects", "img_bounds", @@ -1178,31 +1179,34 @@ def get_largest_connected_component_mask( def get_largest_connected_component_mask_point( - img_pos: NdarrayTensor, - img_neg: NdarrayTensor, - point_coords: NdarrayTensor, - point_labels: NdarrayTensor, + img_pos: Tensor, + img_neg: Tensor, + point_coords: Tensor, + point_labels: Tensor, pos_val: Sequence[int] = (1, 3), neg_val: Sequence[int] = (0, 2), margins: int = 3, -) -> NdarrayTensor: +) -> Tensor: """ Gets the largest connected component mask of an image that include the point_coords. # TODO: need author to provide more details about this function. Especially about each argument. Args: - img_pos: [1, B, H, W, D] - img_neg: [1, B, H, W, D] + img_pos: [1, 1, H, W, D] + img_neg: [1, 1, H, W, D] - point_coords [B, N, 3] - point_labels [B, N] + point_coords [1, N, 3] + point_labels [1, N] """ - if not has_measure: - raise RuntimeError("Skimage.measure required.") + cucim_skimage, has_cucim = optional_import("cucim.skimage") - img_pos_, *_ = convert_data_type(img_pos, np.ndarray) - img_neg_, *_ = convert_data_type(img_neg, np.ndarray) - label = measure.label - lib = np + use_cp = has_cp and has_cucim and img_pos.device != torch.device("cpu") + if use_cp: + img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore + img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore + label = cucim_skimage.measure.label + lib = cp + else: + raise RuntimeError("Cucim.skimage and GPU device are required.") features_pos, _ = label(img_pos_, connectivity=3, return_num=True) features_neg, _ = label(img_neg_, connectivity=3, return_num=True) diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index 9e3bcef539..cde690b0f2 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -12,11 +12,22 @@ from __future__ import annotations import unittest +from unittest.case import skipUnless import torch from parameterized import parameterized -from monai.transforms.utils import convert_points_to_disc, sample_points_from_label +from monai.transforms.utils import ( + convert_points_to_disc, + get_largest_connected_component_mask_point, + sample_points_from_label, +) +from monai.utils.module import optional_import +from tests.utils import skip_if_no_cuda + +cp, has_cp = optional_import("cupy") +cucim_skimage, has_cucim = optional_import("cucim.skimage") + TESTS_SAMPLE_POINTS_FROM_LABEL = [] for use_center in [True, False]: @@ -60,5 +71,20 @@ def test_shape(self, input_data, expected_shape): self.assertEqual(result.shape, expected_shape) +@skipUnless(has_cp, "cupy required") +@skipUnless(has_cucim, "cucim required") +class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): + + @skip_if_no_cuda + def test_shape(self): + shape = (1, 1, 128, 128, 128) + img_pos = torch.randint(0, 2, shape).cuda() + img_neg = torch.randint(0, 2, shape).cuda() + point_coords = torch.randint(0, 32, (1, 1, 3)).cuda() + point_labels = torch.randint(0, 4, (1, 1)).cuda() + mask = get_largest_connected_component_mask_point(img_pos, img_neg, point_coords, point_labels) + self.assertEqual(mask.shape, shape) + + if __name__ == "__main__": unittest.main() From f34bc4242e64f1c4702577244ba659baab19b7a3 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 8 Aug 2024 17:38:03 +0800 Subject: [PATCH 13/19] adjust tests Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 3 +++ tests/min_tests.py | 1 + tests/test_vista3d_utils.py | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index ec4bdc892b..85c8112a65 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1298,6 +1298,9 @@ def sample_points_from_label( if not labels.shape[0] == 1: raise ValueError("labels must have batch size 1.") + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + labels = labels[0, 0] unique_labels = labels.unique().cpu().numpy().tolist() _point = [] diff --git a/tests/min_tests.py b/tests/min_tests.py index 3a143df84b..479c4c8dc2 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -209,6 +209,7 @@ def run_testsuit(): "test_zarr_avg_merger", "test_perceptual_loss", "test_ultrasound_confidence_map_transform", + "test_vista3d_utils", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index cde690b0f2..2256c61b85 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -22,11 +22,13 @@ get_largest_connected_component_mask_point, sample_points_from_label, ) +from monai.utils import min_version from monai.utils.module import optional_import from tests.utils import skip_if_no_cuda cp, has_cp = optional_import("cupy") cucim_skimage, has_cucim = optional_import("cucim.skimage") +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) TESTS_SAMPLE_POINTS_FROM_LABEL = [] @@ -54,6 +56,7 @@ ) +@skipUnless(has_measure or has_cucim, "skimage or cucim required") class TestSamplePointsFromLabel(unittest.TestCase): @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL) @@ -71,8 +74,7 @@ def test_shape(self, input_data, expected_shape): self.assertEqual(result.shape, expected_shape) -@skipUnless(has_cp, "cupy required") -@skipUnless(has_cucim, "cucim required") +@skipUnless(has_cucim and has_cp, "cucim and cupy required") class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): @skip_if_no_cuda From 496e0a9457a4a0880374277442c9d3ccf0209f62 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 8 Aug 2024 17:53:40 +0800 Subject: [PATCH 14/19] update doc Signed-off-by: Yiheng Wang --- docs/source/apps.rst | 8 -------- docs/source/transforms.rst | 3 +++ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index c6ba8c0b9a..7fa7b9e9ff 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -261,11 +261,3 @@ FastMRIReader .. autoclass:: monai.apps.nnunet.nnUNetV2Runner :members: - -`Generative AI` ---------------- - -`MAISI Utilities` -~~~~~~~~~~~~~~~~~ -.. automodule:: monai.apps.generation.maisi.utils.morphological_ops - :members: diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a359821679..637f0873f1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -2310,6 +2310,9 @@ Utilities .. automodule:: monai.transforms.utils_pytorch_numpy_unification :members: +.. automodule:: monai.transforms.utils_morphological_ops + :members: + By Categories ------------- .. toctree:: From 4a8489a22a204916833508e2927033eb2ebb8720 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 13:19:44 +0800 Subject: [PATCH 15/19] add tests Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 48 +++++++++++++++++++------------ tests/test_vista3d_utils.py | 56 +++++++++++++++++++++++++++++++------ 2 files changed, 77 insertions(+), 27 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 85c8112a65..8e39ae97bd 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1179,41 +1179,50 @@ def get_largest_connected_component_mask( def get_largest_connected_component_mask_point( - img_pos: Tensor, - img_neg: Tensor, - point_coords: Tensor, - point_labels: Tensor, + img_pos: NdarrayTensor, + img_neg: NdarrayTensor, + point_coords: NdarrayTensor, + point_labels: NdarrayTensor, pos_val: Sequence[int] = (1, 3), neg_val: Sequence[int] = (0, 2), margins: int = 3, -) -> Tensor: +) -> NdarrayTensor: """ - Gets the largest connected component mask of an image that include the point_coords. - # TODO: need author to provide more details about this function. Especially about each argument. - Args: - img_pos: [1, 1, H, W, D] - img_neg: [1, 1, H, W, D] + Gets the connected component of img_pos and img_neg that include the positive points and + negative points separately. The function is used for combining automatic results with interactive + results in VISTA3D. - point_coords [1, N, 3] - point_labels [1, N] + Args: + img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image. + img_neg: same format as img_pos but corresponds to negative points. + pos_val: positive point label values. + neg_val: negative point label values. + point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points. + point_labels: the label of each point, shape [B, N]. """ + cucim_skimage, has_cucim = optional_import("cucim.skimage") - use_cp = has_cp and has_cucim and img_pos.device != torch.device("cpu") + use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu") if use_cp: img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore label = cucim_skimage.measure.label lib = cp else: - raise RuntimeError("Cucim.skimage and GPU device are required.") + if not has_measure: + raise RuntimeError("skimage.measure required.") + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) + # for skimage.measure.label, the input must be bool type + if img_pos_.dtype != bool or img_neg_.dtype != bool: + raise ValueError("img_pos and img_neg must be bool type.") + label = measure.label + lib = np features_pos, _ = label(img_pos_, connectivity=3, return_num=True) features_neg, _ = label(img_neg_, connectivity=3, return_num=True) - pos_val = list(pos_val) - neg_val = list(neg_val) - outs = np.zeros_like(img_pos_) for bs in range(point_coords.shape[0]): for i, p in enumerate(point_coords[bs]): @@ -1225,7 +1234,10 @@ def get_largest_connected_component_mask_point( # if -1 padding point, skip continue for margin in range(margins): - x, y, z = p.round().int().tolist() + if isinstance(p, np.ndarray): + x, y, z = np.round(p).astype(int).tolist() + else: + x, y, z = p.round().int().tolist() l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index 2256c61b85..1c0cebb8ab 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -14,6 +14,7 @@ import unittest from unittest.case import skipUnless +import numpy as np import torch from parameterized import parameterized @@ -55,8 +56,40 @@ ] ) +TEST_LCC_MASK_POINT_TORCH = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 128, 32, 32) + TEST_LCC_MASK_POINT_TORCH.append( + [ + { + "img_pos": torch.randint(0, 2, shape, dtype=torch.bool), + "img_neg": torch.randint(0, 2, shape, dtype=torch.bool), + "point_coords": torch.randint(0, 10, (bs, num_points, 3)), + "point_labels": torch.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + +TEST_LCC_MASK_POINT_NP = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 32, 32, 64) + TEST_LCC_MASK_POINT_NP.append( + [ + { + "img_pos": np.random.randint(0, 2, shape, dtype=bool), + "img_neg": np.random.randint(0, 2, shape, dtype=bool), + "point_coords": np.random.randint(0, 5, (bs, num_points, 3)), + "point_labels": np.random.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + -@skipUnless(has_measure or has_cucim, "skimage or cucim required") +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") class TestSamplePointsFromLabel(unittest.TestCase): @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL) @@ -74,17 +107,22 @@ def test_shape(self, input_data, expected_shape): self.assertEqual(result.shape, expected_shape) -@skipUnless(has_cucim and has_cp, "cucim and cupy required") +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): @skip_if_no_cuda - def test_shape(self): - shape = (1, 1, 128, 128, 128) - img_pos = torch.randint(0, 2, shape).cuda() - img_neg = torch.randint(0, 2, shape).cuda() - point_coords = torch.randint(0, 32, (1, 1, 3)).cuda() - point_labels = torch.randint(0, 4, (1, 1)).cuda() - mask = get_largest_connected_component_mask_point(img_pos, img_neg, point_coords, point_labels) + @skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_TORCH) + def test_cp_shape(self, input_data, shape): + for key in input_data: + input_data[key] = input_data[key].cuda() + mask = get_largest_connected_component_mask_point(**input_data) + self.assertEqual(mask.shape, shape) + + @skipUnless(has_measure, "skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_NP) + def test_np_shape(self, input_data, shape): + mask = get_largest_connected_component_mask_point(**input_data) self.assertEqual(mask.shape, shape) From 06a0b84cfbc9a4b57202af7f639505a4314fc15f Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:19:20 +0800 Subject: [PATCH 16/19] Update monai/transforms/utils.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/transforms/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 8e39ae97bd..43b2bc2763 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1263,7 +1263,7 @@ def convert_points_to_disc( point: [B, N, 3], 3D point coordinates. point_label: [B, N], 0 or 2 means negative points, 1 or 3 means postive points. radius: disc ball radius size. - disc: If true, use regular disc other other use gaussian. + disc: If true, use regular disc, other use gaussian. """ masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device) _array = [ From b84302e07ae23510148004825284990dfbc46f6d Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 15:31:10 +0800 Subject: [PATCH 17/19] skip if quick Signed-off-by: Yiheng Wang --- tests/test_vista3d_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index 1c0cebb8ab..601a5156f5 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -25,7 +25,7 @@ ) from monai.utils import min_version from monai.utils.module import optional_import -from tests.utils import skip_if_no_cuda +from tests.utils import skip_if_no_cuda, skip_if_quick cp, has_cp = optional_import("cupy") cucim_skimage, has_cucim = optional_import("cucim.skimage") @@ -110,6 +110,7 @@ def test_shape(self, input_data, expected_shape): @skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): + @skip_if_quick @skip_if_no_cuda @skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required") @parameterized.expand(TEST_LCC_MASK_POINT_TORCH) From cacbe6c2665b8573fb862cbb38692711dc7c9226 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 16:14:27 +0800 Subject: [PATCH 18/19] use to device Signed-off-by: Yiheng Wang --- tests/test_vista3d_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index 601a5156f5..a940854d88 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -31,6 +31,8 @@ cucim_skimage, has_cucim = optional_import("cucim.skimage") measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + TESTS_SAMPLE_POINTS_FROM_LABEL = [] for use_center in [True, False]: @@ -116,7 +118,7 @@ class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): @parameterized.expand(TEST_LCC_MASK_POINT_TORCH) def test_cp_shape(self, input_data, shape): for key in input_data: - input_data[key] = input_data[key].cuda() + input_data[key] = input_data[key].to(device) mask = get_largest_connected_component_mask_point(**input_data) self.assertEqual(mask.shape, shape) From 7d15188a4b2fd15862566281a6c2e262f338862d Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 21:22:44 +0800 Subject: [PATCH 19/19] add .float before round Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 43b2bc2763..e32bf6fc48 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1237,7 +1237,7 @@ def get_largest_connected_component_mask_point( if isinstance(p, np.ndarray): x, y, z = np.round(p).astype(int).tolist() else: - x, y, z = p.round().int().tolist() + x, y, z = p.float().round().int().tolist() l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1])