Skip to content

Commit

Permalink
Track applied operations in image filter (Project-MONAI#7395)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7394

### Description

When ImageFilter is in the transformation sequence it didn't pass the
applied_operations.
Now it is passed when present.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: axel.vlaminck <axel.vlaminck@gmail.com>
Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw>
  • Loading branch information
vlaminckaxel authored and Yu0610 committed Apr 11, 2024
1 parent b28f8b5 commit 790afa5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
11 changes: 8 additions & 3 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int |
self.filter_size = filter_size
self.additional_args_for_filter = kwargs

def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor:
def __call__(
self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None
) -> NdarrayOrTensor:
"""
Args:
img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]
meta_dict: An optional dictionary with metadata
applied_operations: An optional list of operations that have been applied to the data
Returns:
A MetaTensor with the same shape as `img` and identical metadata
"""
if isinstance(img, MetaTensor):
meta_dict = img.meta
applied_operations = img.applied_operations

img_, prev_type, device = convert_data_type(img, torch.Tensor)
ndim = img_.ndim - 1 # assumes channel first format

Expand All @@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr
self.filter = ApplyFilter(self.filter)

img_ = self._apply_filter(img_)
if meta_dict:
img_ = MetaTensor(img_, meta=meta_dict)
if meta_dict is not None or applied_operations is not None:
img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations)
else:
img_, *_ = convert_data_type(img_, prev_type, device)
return img_
Expand Down
16 changes: 16 additions & 0 deletions tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from parameterized import parameterized

from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd

Expand Down Expand Up @@ -115,6 +116,21 @@ def test_call_3d(self, filter_name):
out_tensor = filter(SAMPLE_IMAGE_3D)
self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:])

def test_pass_applied_operations(self):
"Test that applied operations are passed through"
applied_operations = ["op1", "op2"]
image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations)
filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)
out_tensor = filter(image)
self.assertEqual(out_tensor.applied_operations, applied_operations)

def test_pass_empty_metadata_dict(self):
"Test that applied operations are passed through"
image = MetaTensor(SAMPLE_IMAGE_2D, meta={})
filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)
out_tensor = filter(image)
self.assertTrue(isinstance(out_tensor, MetaTensor))


class TestImageFilterDict(unittest.TestCase):
@parameterized.expand(SUPPORTED_FILTERS)
Expand Down

0 comments on commit 790afa5

Please sign in to comment.