Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maisi morphological funcs #7893

Merged
merged 22 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions monai/apps/generation/maisi/utils/morphological_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Sequence

import torch
import torch.nn.functional as F
from torch import Tensor

from monai.config import NdarrayOrTensor
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep


def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
"""
Erode 2D/3D binary mask.

Args:
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
filter_size: erosion filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed.

Return:
eroded mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.

Example:

.. code-block:: python

# define a naive mask
mask = torch.zeros(3,2,3,3,3)
mask[:,:,1,1,1] = 1.0
filter_size = 3
erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)
dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)

"""
mask_t, *_ = convert_data_type(mask, torch.Tensor)
res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value)
res_mask: NdarrayOrTensor
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
return res_mask


def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor:
"""
Dilate 2D/3D binary mask.

Args:
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
filter_size: dilation filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed.

Return:
dilated mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.

Example:

.. code-block:: python

# define a naive mask
mask = torch.zeros(3,2,3,3,3)
mask[:,:,1,1,1] = 1.0
filter_size = 3
erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)
dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)
"""
mask_t, *_ = convert_data_type(mask, torch.Tensor)
res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value)
res_mask: NdarrayOrTensor
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
return res_mask


def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor:
"""
Get morphological filter result for 2D/3D mask with data type as torch tensor.
Can-Zhao marked this conversation as resolved.
Show resolved Hide resolved

Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: morphological filter size, has to be odd numbers.
pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input.

Return:
Morphological filter result mask, [N,C,M,N] or [N,C,M,N,P] torch tensor
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
"""
spatial_dims = len(mask_t.shape) - 2
if spatial_dims not in [2, 3]:
raise ValueError(
f"spatial_dims must be either 2 or 3, yet got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}."
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
)

# Define the structuring element
filter_size = ensure_tuple_rep(filter_size, spatial_dims)
if any(size % 2 == 0 for size in filter_size):
raise ValueError(f"All dimensions in filter_size must be odd numbers, yet got {filter_size}.")
Can-Zhao marked this conversation as resolved.
Show resolved Hide resolved

filter_shape = [mask_t.shape[1], mask_t.shape[1]] + list(filter_size)
structuring_element = torch.ones(filter_shape).to(mask_t.device)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved

# Pad the input tensor to handle border pixels
# Calculate padding size
pad_size = []
for size in filter_size:
pad_size.extend([size // 2, size // 2])
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved

input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value)

# Apply filter operation
if spatial_dims == 2:
output = F.conv2d(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...])
if spatial_dims == 3:
output = F.conv3d(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...])
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved

return output


def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
"""
Erode 2D/3D binary mask with data type as torch tensor.

Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: erosion filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed.

Return:
eroded mask, [N,C,M,N] or [N,C,M,N,P] torch tensor
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
"""

output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)

# Set output values based on the minimum value within the structuring element
output = torch.where(output == 1.0, 1.0, 0.0)
Can-Zhao marked this conversation as resolved.
Show resolved Hide resolved

return output


def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
"""
Dilate 2D/3D binary mask with data type as torch tensor.

Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: dilation filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering to keep the output with the same size as input. Usually use default value and not changed.

Return:
dilated mask, [N,C,M,N] or [N,C,M,N,P] torch tensor
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
"""
output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)

# Set output values based on the minimum value within the structuring element
output = torch.where(output > 0, 1.0, 0.0)

return output
92 changes: 92 additions & 0 deletions tests/test_morphological_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.apps.generation.maisi.utils import morphological_ops
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS_SHAPE = []
for p in TEST_NDARRAYS:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
mask = torch.zeros(1, 1, 5, 5, 5)
filter_size = 3
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 5, 5, 5]])
mask = torch.zeros(3, 2, 5, 5, 5)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [3, 2, 5, 5, 5]])
mask = torch.zeros(1, 1, 1, 1, 1)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]])

TESTS_VALUE_T = []
mask = torch.ones(3, 2, 3, 3, 3)
filter_size = 3
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)])
mask = torch.zeros(3, 2, 3, 3, 3)
filter_size = 3
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)])

TESTS_VALUE = []
for p in TEST_NDARRAYS:
mask = torch.zeros(3, 2, 5, 5, 5)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))]
)
mask = torch.ones(1, 1, 3, 3, 3)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))]
)
mask = torch.ones(1, 2, 3, 3, 3)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))]
)
mask = torch.zeros(3, 2, 3, 3, 3)
mask[:, :, 1, 1, 1] = 1.0
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))]
)


class TestMorph(unittest.TestCase):

@parameterized.expand(TESTS_SHAPE)
def test_shape(self, input_data, expected_result):
result1 = morphological_ops.erode(input_data["mask"], input_data["filter_size"])
assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0)

@parameterized.expand(TESTS_VALUE_T)
def test_value_t(self, input_data, expected_result):
result1 = morphological_ops.get_morphological_filter_result_t(
input_data["mask"], input_data["filter_size"], input_data["pad_value"]
)
# result1 = morphological_ops.erode(input_data["mask"],input_data["filter_size"])
# assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0)
assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0)

@parameterized.expand(TESTS_VALUE)
def test_value(self, input_data, expected_erode_result, expected_dilate_result):
result1 = morphological_ops.erode(input_data["mask"], input_data["filter_size"])
assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0)
result2 = morphological_ops.dilate(input_data["mask"], input_data["filter_size"])
assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0)


if __name__ == "__main__":
unittest.main()
Loading