Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maisi morphological funcs #7893

Merged
merged 22 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,11 @@ FastMRIReader

.. autoclass:: monai.apps.nnunet.nnUNetV2Runner
:members:

`Generative AI`
---------------

`MAISI Utilities`
~~~~~~~~~~~~~~~~~
.. automodule:: monai.apps.generation.maisi.utils.morphological_ops
:members:
10 changes: 10 additions & 0 deletions monai/apps/generation/maisi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
170 changes: 170 additions & 0 deletions monai/apps/generation/maisi/utils/morphological_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Sequence

import torch
import torch.nn.functional as F
from torch import Tensor

from monai.config import NdarrayOrTensor
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep


def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
"""
Erode 2D/3D binary mask.
Args:
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
filter_size: erosion filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
eroded mask, same shape and data type as input.
Example:
.. code-block:: python
# define a naive mask
mask = torch.zeros(3,2,3,3,3)
mask[:,:,1,1,1] = 1.0
filter_size = 3
erode_result = erode(mask, filter_size) # expect torch.zeros(3,2,3,3,3)
dilate_result = dilate(mask, filter_size) # expect torch.ones(3,2,3,3,3)
"""
mask_t, *_ = convert_data_type(mask, torch.Tensor)
res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value)
res_mask: NdarrayOrTensor
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
return res_mask


def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor:
"""
Dilate 2D/3D binary mask.
Args:
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
filter_size: dilation filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
dilated mask, same shape and data type as input.
Example:
.. code-block:: python
# define a naive mask
mask = torch.zeros(3,2,3,3,3)
mask[:,:,1,1,1] = 1.0
filter_size = 3
erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)
dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)
"""
mask_t, *_ = convert_data_type(mask, torch.Tensor)
res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value)
res_mask: NdarrayOrTensor
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
return res_mask


def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor:
"""
Apply a morphological filter to a 2D/3D binary mask tensor.
Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: morphological filter size, has to be odd numbers.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input.
Return:
Tensor: Morphological filter result mask, same shape as input.
"""
spatial_dims = len(mask_t.shape) - 2
if spatial_dims not in [2, 3]:
raise ValueError(
f"spatial_dims must be either 2 or 3, "
f"got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}."
)

# Define the structuring element
filter_size = ensure_tuple_rep(filter_size, spatial_dims)
if any(size % 2 == 0 for size in filter_size):
raise ValueError(f"All dimensions in filter_size must be odd numbers, got {filter_size}.")

structuring_element = torch.ones((mask_t.shape[1], mask_t.shape[1]) + filter_size).to(mask_t.device)

# Pad the input tensor to handle border pixels
# Calculate padding size
pad_size = [size // 2 for size in filter_size for _ in range(2)]

input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value)

# Apply filter operation
conv_fn = F.conv2d if spatial_dims == 2 else F.conv3d
output = conv_fn(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...])

return output


def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
"""
Erode 2D/3D binary mask with data type as torch tensor.
Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: erosion filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
Tensor: eroded mask, same shape as input.
"""

output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)

# Set output values based on the minimum value within the structuring element
output = torch.where(torch.abs(output - 1.0) < 1e-7, 1.0, 0.0)

return output


def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
"""
Dilate 2D/3D binary mask with data type as torch tensor.
Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: dilation filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
Tensor: dilated mask, same shape as input.
"""
output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)

# Set output values based on the minimum value within the structuring element
output = torch.where(output > 0, 1.0, 0.0)

return output
102 changes: 102 additions & 0 deletions tests/test_morphological_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS_SHAPE = []
for p in TEST_NDARRAYS:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
mask = torch.zeros(1, 1, 5, 5, 5)
filter_size = 3
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 5, 5, 5]])
mask = torch.zeros(3, 2, 5, 5, 5)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [3, 2, 5, 5, 5]])
mask = torch.zeros(1, 1, 1, 1, 1)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]])
mask = torch.zeros(1, 1, 1, 1)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1]])

TESTS_VALUE_T = []
filter_size = 3
mask = torch.ones(3, 2, 3, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)])
mask = torch.zeros(3, 2, 3, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)])
mask = torch.ones(3, 2, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3)])
mask = torch.zeros(3, 2, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3)])

TESTS_VALUE = []
for p in TEST_NDARRAYS:
mask = torch.zeros(3, 2, 5, 5, 5)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))]
)
mask = torch.ones(1, 1, 3, 3, 3)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))]
)
mask = torch.ones(1, 2, 3, 3, 3)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))]
)
mask = torch.zeros(3, 2, 3, 3, 3)
mask[:, :, 1, 1, 1] = 1.0
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))]
)
mask = torch.zeros(3, 2, 3, 3)
mask[:, :, 1, 1] = 1.0
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))]
)


class TestMorph(unittest.TestCase):

@parameterized.expand(TESTS_SHAPE)
def test_shape(self, input_data, expected_result):
result1 = erode(input_data["mask"], input_data["filter_size"])
assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0)

@parameterized.expand(TESTS_VALUE_T)
def test_value_t(self, input_data, expected_result):
result1 = get_morphological_filter_result_t(
input_data["mask"], input_data["filter_size"], input_data["pad_value"]
)
assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0)

@parameterized.expand(TESTS_VALUE)
def test_value(self, input_data, expected_erode_result, expected_dilate_result):
result1 = erode(input_data["mask"], input_data["filter_size"])
assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0)
result2 = dilate(input_data["mask"], input_data["filter_size"])
assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0)


if __name__ == "__main__":
unittest.main()
Loading