-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes # . ### Description Maisi morphological funcs ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Can-Zhao <canz@nvidia.com> Signed-off-by: Can Zhao <69829124+Can-Zhao@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
15d0771
commit 410109a
Showing
4 changed files
with
290 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Sequence | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import Tensor | ||
|
||
from monai.config import NdarrayOrTensor | ||
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep | ||
|
||
|
||
def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: | ||
""" | ||
Erode 2D/3D binary mask. | ||
Args: | ||
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. | ||
filter_size: erosion filter size, has to be odd numbers, default to be 3. | ||
pad_value: the filled value for padding. We need to pad the input before filtering | ||
to keep the output with the same size as input. Usually use default value | ||
and not changed. | ||
Return: | ||
eroded mask, same shape and data type as input. | ||
Example: | ||
.. code-block:: python | ||
# define a naive mask | ||
mask = torch.zeros(3,2,3,3,3) | ||
mask[:,:,1,1,1] = 1.0 | ||
filter_size = 3 | ||
erode_result = erode(mask, filter_size) # expect torch.zeros(3,2,3,3,3) | ||
dilate_result = dilate(mask, filter_size) # expect torch.ones(3,2,3,3,3) | ||
""" | ||
mask_t, *_ = convert_data_type(mask, torch.Tensor) | ||
res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value) | ||
res_mask: NdarrayOrTensor | ||
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask) | ||
return res_mask | ||
|
||
|
||
def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor: | ||
""" | ||
Dilate 2D/3D binary mask. | ||
Args: | ||
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray. | ||
filter_size: dilation filter size, has to be odd numbers, default to be 3. | ||
pad_value: the filled value for padding. We need to pad the input before filtering | ||
to keep the output with the same size as input. Usually use default value | ||
and not changed. | ||
Return: | ||
dilated mask, same shape and data type as input. | ||
Example: | ||
.. code-block:: python | ||
# define a naive mask | ||
mask = torch.zeros(3,2,3,3,3) | ||
mask[:,:,1,1,1] = 1.0 | ||
filter_size = 3 | ||
erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3) | ||
dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3) | ||
""" | ||
mask_t, *_ = convert_data_type(mask, torch.Tensor) | ||
res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value) | ||
res_mask: NdarrayOrTensor | ||
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask) | ||
return res_mask | ||
|
||
|
||
def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor: | ||
""" | ||
Apply a morphological filter to a 2D/3D binary mask tensor. | ||
Args: | ||
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. | ||
filter_size: morphological filter size, has to be odd numbers. | ||
pad_value: the filled value for padding. We need to pad the input before filtering | ||
to keep the output with the same size as input. | ||
Return: | ||
Tensor: Morphological filter result mask, same shape as input. | ||
""" | ||
spatial_dims = len(mask_t.shape) - 2 | ||
if spatial_dims not in [2, 3]: | ||
raise ValueError( | ||
f"spatial_dims must be either 2 or 3, " | ||
f"got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}." | ||
) | ||
|
||
# Define the structuring element | ||
filter_size = ensure_tuple_rep(filter_size, spatial_dims) | ||
if any(size % 2 == 0 for size in filter_size): | ||
raise ValueError(f"All dimensions in filter_size must be odd numbers, got {filter_size}.") | ||
|
||
structuring_element = torch.ones((mask_t.shape[1], mask_t.shape[1]) + filter_size).to(mask_t.device) | ||
|
||
# Pad the input tensor to handle border pixels | ||
# Calculate padding size | ||
pad_size = [size // 2 for size in filter_size for _ in range(2)] | ||
|
||
input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value) | ||
|
||
# Apply filter operation | ||
conv_fn = F.conv2d if spatial_dims == 2 else F.conv3d | ||
output = conv_fn(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...]) | ||
|
||
return output | ||
|
||
|
||
def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor: | ||
""" | ||
Erode 2D/3D binary mask with data type as torch tensor. | ||
Args: | ||
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. | ||
filter_size: erosion filter size, has to be odd numbers, default to be 3. | ||
pad_value: the filled value for padding. We need to pad the input before filtering | ||
to keep the output with the same size as input. Usually use default value | ||
and not changed. | ||
Return: | ||
Tensor: eroded mask, same shape as input. | ||
""" | ||
|
||
output = get_morphological_filter_result_t(mask_t, filter_size, pad_value) | ||
|
||
# Set output values based on the minimum value within the structuring element | ||
output = torch.where(torch.abs(output - 1.0) < 1e-7, 1.0, 0.0) | ||
|
||
return output | ||
|
||
|
||
def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor: | ||
""" | ||
Dilate 2D/3D binary mask with data type as torch tensor. | ||
Args: | ||
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor. | ||
filter_size: dilation filter size, has to be odd numbers, default to be 3. | ||
pad_value: the filled value for padding. We need to pad the input before filtering | ||
to keep the output with the same size as input. Usually use default value | ||
and not changed. | ||
Return: | ||
Tensor: dilated mask, same shape as input. | ||
""" | ||
output = get_morphological_filter_result_t(mask_t, filter_size, pad_value) | ||
|
||
# Set output values based on the minimum value within the structuring element | ||
output = torch.where(output > 0, 1.0, 0.0) | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import torch | ||
from parameterized import parameterized | ||
|
||
from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t | ||
from tests.utils import TEST_NDARRAYS, assert_allclose | ||
|
||
TESTS_SHAPE = [] | ||
for p in TEST_NDARRAYS: | ||
mask = torch.zeros(1, 1, 5, 5, 5) | ||
filter_size = 3 | ||
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 5, 5, 5]]) | ||
mask = torch.zeros(3, 2, 5, 5, 5) | ||
filter_size = 5 | ||
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [3, 2, 5, 5, 5]]) | ||
mask = torch.zeros(1, 1, 1, 1, 1) | ||
filter_size = 5 | ||
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]]) | ||
mask = torch.zeros(1, 1, 1, 1) | ||
filter_size = 5 | ||
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1]]) | ||
|
||
TESTS_VALUE_T = [] | ||
filter_size = 3 | ||
mask = torch.ones(3, 2, 3, 3, 3) | ||
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)]) | ||
mask = torch.zeros(3, 2, 3, 3, 3) | ||
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)]) | ||
mask = torch.ones(3, 2, 3, 3) | ||
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3)]) | ||
mask = torch.zeros(3, 2, 3, 3) | ||
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3)]) | ||
|
||
TESTS_VALUE = [] | ||
for p in TEST_NDARRAYS: | ||
mask = torch.zeros(3, 2, 5, 5, 5) | ||
filter_size = 3 | ||
TESTS_VALUE.append( | ||
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))] | ||
) | ||
mask = torch.ones(1, 1, 3, 3, 3) | ||
filter_size = 3 | ||
TESTS_VALUE.append( | ||
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))] | ||
) | ||
mask = torch.ones(1, 2, 3, 3, 3) | ||
filter_size = 3 | ||
TESTS_VALUE.append( | ||
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))] | ||
) | ||
mask = torch.zeros(3, 2, 3, 3, 3) | ||
mask[:, :, 1, 1, 1] = 1.0 | ||
filter_size = 3 | ||
TESTS_VALUE.append( | ||
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))] | ||
) | ||
mask = torch.zeros(3, 2, 3, 3) | ||
mask[:, :, 1, 1] = 1.0 | ||
filter_size = 3 | ||
TESTS_VALUE.append( | ||
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))] | ||
) | ||
|
||
|
||
class TestMorph(unittest.TestCase): | ||
|
||
@parameterized.expand(TESTS_SHAPE) | ||
def test_shape(self, input_data, expected_result): | ||
result1 = erode(input_data["mask"], input_data["filter_size"]) | ||
assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0) | ||
|
||
@parameterized.expand(TESTS_VALUE_T) | ||
def test_value_t(self, input_data, expected_result): | ||
result1 = get_morphological_filter_result_t( | ||
input_data["mask"], input_data["filter_size"], input_data["pad_value"] | ||
) | ||
assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0) | ||
|
||
@parameterized.expand(TESTS_VALUE) | ||
def test_value(self, input_data, expected_erode_result, expected_dilate_result): | ||
result1 = erode(input_data["mask"], input_data["filter_size"]) | ||
assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0) | ||
result2 = dilate(input_data["mask"], input_data["filter_size"]) | ||
assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |