Skip to content

Commit

Permalink
Maisi morphological funcs (#7893)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

Maisi morphological funcs

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Can-Zhao <canz@nvidia.com>
Signed-off-by: Can Zhao <69829124+Can-Zhao@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 2, 2024
1 parent 15d0771 commit 410109a
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,11 @@ FastMRIReader

.. autoclass:: monai.apps.nnunet.nnUNetV2Runner
:members:

`Generative AI`
---------------

`MAISI Utilities`
~~~~~~~~~~~~~~~~~
.. automodule:: monai.apps.generation.maisi.utils.morphological_ops
:members:
10 changes: 10 additions & 0 deletions monai/apps/generation/maisi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
170 changes: 170 additions & 0 deletions monai/apps/generation/maisi/utils/morphological_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Sequence

import torch
import torch.nn.functional as F
from torch import Tensor

from monai.config import NdarrayOrTensor
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep


def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
"""
Erode 2D/3D binary mask.
Args:
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
filter_size: erosion filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
eroded mask, same shape and data type as input.
Example:
.. code-block:: python
# define a naive mask
mask = torch.zeros(3,2,3,3,3)
mask[:,:,1,1,1] = 1.0
filter_size = 3
erode_result = erode(mask, filter_size) # expect torch.zeros(3,2,3,3,3)
dilate_result = dilate(mask, filter_size) # expect torch.ones(3,2,3,3,3)
"""
mask_t, *_ = convert_data_type(mask, torch.Tensor)
res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value)
res_mask: NdarrayOrTensor
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
return res_mask


def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor:
"""
Dilate 2D/3D binary mask.
Args:
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
filter_size: dilation filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
dilated mask, same shape and data type as input.
Example:
.. code-block:: python
# define a naive mask
mask = torch.zeros(3,2,3,3,3)
mask[:,:,1,1,1] = 1.0
filter_size = 3
erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)
dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)
"""
mask_t, *_ = convert_data_type(mask, torch.Tensor)
res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value)
res_mask: NdarrayOrTensor
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
return res_mask


def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor:
"""
Apply a morphological filter to a 2D/3D binary mask tensor.
Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: morphological filter size, has to be odd numbers.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input.
Return:
Tensor: Morphological filter result mask, same shape as input.
"""
spatial_dims = len(mask_t.shape) - 2
if spatial_dims not in [2, 3]:
raise ValueError(
f"spatial_dims must be either 2 or 3, "
f"got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}."
)

# Define the structuring element
filter_size = ensure_tuple_rep(filter_size, spatial_dims)
if any(size % 2 == 0 for size in filter_size):
raise ValueError(f"All dimensions in filter_size must be odd numbers, got {filter_size}.")

structuring_element = torch.ones((mask_t.shape[1], mask_t.shape[1]) + filter_size).to(mask_t.device)

# Pad the input tensor to handle border pixels
# Calculate padding size
pad_size = [size // 2 for size in filter_size for _ in range(2)]

input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value)

# Apply filter operation
conv_fn = F.conv2d if spatial_dims == 2 else F.conv3d
output = conv_fn(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...])

return output


def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
"""
Erode 2D/3D binary mask with data type as torch tensor.
Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: erosion filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
Tensor: eroded mask, same shape as input.
"""

output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)

# Set output values based on the minimum value within the structuring element
output = torch.where(torch.abs(output - 1.0) < 1e-7, 1.0, 0.0)

return output


def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
"""
Dilate 2D/3D binary mask with data type as torch tensor.
Args:
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
filter_size: dilation filter size, has to be odd numbers, default to be 3.
pad_value: the filled value for padding. We need to pad the input before filtering
to keep the output with the same size as input. Usually use default value
and not changed.
Return:
Tensor: dilated mask, same shape as input.
"""
output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)

# Set output values based on the minimum value within the structuring element
output = torch.where(output > 0, 1.0, 0.0)

return output
102 changes: 102 additions & 0 deletions tests/test_morphological_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS_SHAPE = []
for p in TEST_NDARRAYS:
mask = torch.zeros(1, 1, 5, 5, 5)
filter_size = 3
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 5, 5, 5]])
mask = torch.zeros(3, 2, 5, 5, 5)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [3, 2, 5, 5, 5]])
mask = torch.zeros(1, 1, 1, 1, 1)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]])
mask = torch.zeros(1, 1, 1, 1)
filter_size = 5
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1]])

TESTS_VALUE_T = []
filter_size = 3
mask = torch.ones(3, 2, 3, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)])
mask = torch.zeros(3, 2, 3, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)])
mask = torch.ones(3, 2, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3)])
mask = torch.zeros(3, 2, 3, 3)
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3)])

TESTS_VALUE = []
for p in TEST_NDARRAYS:
mask = torch.zeros(3, 2, 5, 5, 5)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))]
)
mask = torch.ones(1, 1, 3, 3, 3)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))]
)
mask = torch.ones(1, 2, 3, 3, 3)
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))]
)
mask = torch.zeros(3, 2, 3, 3, 3)
mask[:, :, 1, 1, 1] = 1.0
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))]
)
mask = torch.zeros(3, 2, 3, 3)
mask[:, :, 1, 1] = 1.0
filter_size = 3
TESTS_VALUE.append(
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))]
)


class TestMorph(unittest.TestCase):

@parameterized.expand(TESTS_SHAPE)
def test_shape(self, input_data, expected_result):
result1 = erode(input_data["mask"], input_data["filter_size"])
assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0)

@parameterized.expand(TESTS_VALUE_T)
def test_value_t(self, input_data, expected_result):
result1 = get_morphological_filter_result_t(
input_data["mask"], input_data["filter_size"], input_data["pad_value"]
)
assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0)

@parameterized.expand(TESTS_VALUE)
def test_value(self, input_data, expected_erode_result, expected_dilate_result):
result1 = erode(input_data["mask"], input_data["filter_size"])
assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0)
result2 = dilate(input_data["mask"], input_data["filter_size"])
assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0)


if __name__ == "__main__":
unittest.main()

0 comments on commit 410109a

Please sign in to comment.