Skip to content

Commit

Permalink
6136 6146 update the default writer flag (#6147)
Browse files Browse the repository at this point in the history
Fixes #6136 
fixes #6146 


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Mar 16, 2023
1 parent 678b512 commit 66d0478
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 33 deletions.
24 changes: 16 additions & 8 deletions monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,14 @@ class ITKWriter(ImageWriter):
output_dtype: DtypeLike = None
channel_dim: int | None

def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool = True, **kwargs):
def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool | None = True, **kwargs):
"""
Args:
output_dtype: output data type.
affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
Set to ``True`` to be consistent with ``NibabelWriter``,
otherwise the affine matrix is assumed already in the ITK convention.
Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.
kwargs: keyword arguments passed to ``ImageWriter``.
The constructor will create ``self.output_dtype`` internally.
Expand All @@ -406,17 +407,20 @@ def set_data_array(
kwargs: keyword arguments passed to ``self.convert_to_channel_last``,
currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively.
"""
_r = len(data_array.shape)
n_chns = data_array.shape[channel_dim] if channel_dim is not None else 0
self.data_obj = self.convert_to_channel_last(
data=data_array,
channel_dim=channel_dim,
squeeze_end_dims=squeeze_end_dims,
spatial_ndim=kwargs.pop("spatial_ndim", 3),
contiguous=kwargs.pop("contiguous", True),
)
self.channel_dim = (
channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None
) # channel dim is at the end
self.channel_dim = -1 # in most cases, the data is set to channel last
if squeeze_end_dims and n_chns <= 1: # num_channel==1 squeezed
self.channel_dim = None
if not squeeze_end_dims and n_chns < 1: # originally no channel and convert_to_channel_last added a channel
self.channel_dim = None
self.data_obj = self.data_obj[..., 0]

def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options):
"""
Expand Down Expand Up @@ -478,7 +482,7 @@ def create_backend_obj(
channel_dim: int | None = 0,
affine: NdarrayOrTensor | None = None,
dtype: DtypeLike = np.float32,
affine_lps_to_ras: bool = True,
affine_lps_to_ras: bool | None = True,
**kwargs,
):
"""
Expand All @@ -492,14 +496,18 @@ def create_backend_obj(
affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
Set to ``True`` to be consistent with ``NibabelWriter``,
otherwise the affine matrix is assumed already in the ITK convention.
Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.
kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary.
see also:
- https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389
"""
if isinstance(data_array, MetaTensor) and data_array.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) != SpaceKeys.LPS:
affine_lps_to_ras = False # do the converting from LPS to RAS only if the space type is currently LPS.
if isinstance(data_array, MetaTensor) and affine_lps_to_ras is None:
affine_lps_to_ras = (
data_array.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) != SpaceKeys.LPS
) # do the converting from LPS to RAS only if the space type is currently LPS.
data_array = super().create_backend_obj(data_array)
_is_vec = channel_dim is not None
if _is_vec:
Expand Down
19 changes: 0 additions & 19 deletions monai/networks/layers/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,10 @@ class BilateralFilter(torch.autograd.Function):
Args:
input: input tensor.
spatial_sigma: the standard deviation of the spatial blur. Higher values can
hurt performance when not using the approximate method (see fast approx).
color_sigma: the standard deviation of the color blur. Lower values preserve
edges better whilst higher values tend to a simple gaussian spatial blur.
fast approx: This flag chooses between two implementations. The approximate method may
produce artifacts in some scenarios whereas the exact solution may be intolerably
slow for high spatial standard deviations.
Expand Down Expand Up @@ -76,9 +73,7 @@ class PHLFilter(torch.autograd.Function):
Args:
input: input tensor to be filtered.
features: feature tensor used to filter the input.
sigmas: the standard deviations of each feature in the filter.
Returns:
Expand Down Expand Up @@ -114,13 +109,9 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):
Args:
input: input tensor to be filtered.
sigma x: trainable standard deviation of the spatial filter kernel in x direction.
sigma y: trainable standard deviation of the spatial filter kernel in y direction.
sigma z: trainable standard deviation of the spatial filter kernel in z direction.
color sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.
Expand Down Expand Up @@ -200,11 +191,9 @@ class TrainableBilateralFilter(torch.nn.Module):
Args:
input: input tensor to be filtered.
spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard
deviations of the spatial filter kernels. Tuple length must equal the number of
spatial input dimensions.
color_sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.
Expand Down Expand Up @@ -280,15 +269,10 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):
Args:
input: input tensor to be filtered.
guide: guidance image tensor to be used during filtering.
sigma x: trainable standard deviation of the spatial filter kernel in x direction.
sigma y: trainable standard deviation of the spatial filter kernel in y direction.
sigma z: trainable standard deviation of the spatial filter kernel in z direction.
color sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.
Expand Down Expand Up @@ -373,13 +357,10 @@ class TrainableJointBilateralFilter(torch.nn.Module):
Args:
input: input tensor to be filtered.
guide: guidance image tensor to be used during filtering.
spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard
deviations of the spatial filter kernels. Tuple length must equal the number of
spatial input dimensions.
color_sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.
Expand Down
13 changes: 11 additions & 2 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
PydicomReader,
)
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import is_no_channel
from monai.transforms.transform import Transform
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSamplePadMode
Expand Down Expand Up @@ -440,6 +441,7 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ
self.meta_kwargs.update(meta_kwargs)
if write_kwargs is not None:
self.write_kwargs.update(write_kwargs)
return self

def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None):
"""
Expand All @@ -450,8 +452,15 @@ def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None
meta_data = img.meta if isinstance(img, MetaTensor) else meta_data
kw = self.fname_formatter(meta_data, self)
filename = self.folder_layout.filename(**kw)
if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape):
self.data_kwargs["channel_dim"] = None
if meta_data:
meta_spatial_shape = ensure_tuple(meta_data.get("spatial_shape", ()))
if len(meta_spatial_shape) >= len(img.shape):
self.data_kwargs["channel_dim"] = None
elif is_no_channel(self.data_kwargs.get("channel_dim")):
warnings.warn(
f"data shape {img.shape} (with spatial shape {meta_spatial_shape}) "
f"but SaveImage `channel_dim` is set to {self.data_kwargs.get('channel_dim')} no channel."
)

err = []
for writer_cls in self.writers:
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def __init__(

def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs)
return self

def __call__(self, data):
d = dict(data)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_auto3dseg_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@


@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 10, 0))
@SkipIfBeforePyTorchVersion((1, 13, 0))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleBuilder(unittest.TestCase):
def setUp(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_image_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
filepath = f"testfile_{ndim}d"
saver = SaveImage(
output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer
)
).set_options(init_kwargs={"affine_lps_to_ras": True})
test_data = MetaTensor(
p(test_data), meta={"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration_autorunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 9, 1))
@SkipIfBeforePyTorchVersion((1, 13, 0))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestAutoRunner(unittest.TestCase):
def setUp(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration_gpu_customization.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 9, 1))
@SkipIfBeforePyTorchVersion((1, 13, 0))
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleGpuCustomization(unittest.TestCase):
def setUp(self) -> None:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_itk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ def test_rgb(self):
np.testing.assert_allclose(output.shape, (5, 5, 3))
np.testing.assert_allclose(output[1, 1], (5, 5, 4))

def test_no_channel(self):
with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, "testing.nii.gz")
writer = ITKWriter(output_dtype=np.uint8)
writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=None)
writer.write(fname)

output = np.asarray(itk.imread(fname))
np.testing.assert_allclose(output.shape, (4, 4, 3))
np.testing.assert_allclose(output[1, 1], (5, 21, 37))


if __name__ == "__main__":
unittest.main()
56 changes: 56 additions & 0 deletions tests/testing_data/integration_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,62 @@
import numpy as np

EXPECTED_ANSWERS = [
{ # test answers for PyTorch 2.0
"integration_segmentation_3d": {
"losses": [
0.5430086106061935,
0.47010003924369814,
0.4453376233577728,
0.451901963353157,
0.4398456811904907,
0.43450237810611725,
],
"best_metric": 0.9329540133476257,
"infer_metric": 0.9330471754074097,
"output_sums": [
0.14212507078546172,
0.15199039602949577,
0.15133471939291526,
0.13967984811021827,
0.18831614355832332,
0.1694076821827231,
0.14663931509271658,
0.16788710637623733,
0.1569452710008219,
0.17907130698392254,
0.16244092698688475,
0.1679350345855819,
0.14437674754879065,
0.11355098478396568,
0.161660275855964,
0.20082478187698194,
0.17575491677668853,
0.0974593860605401,
0.19366775441539907,
0.20293016863409002,
0.19610441127101647,
0.20812173772459808,
0.16184212006067655,
0.13185211452732482,
0.14824716961304257,
0.14229818359602905,
0.23141282114085215,
0.1609268635938338,
0.14825300029123678,
0.10286266811772046,
0.11873484714087054,
0.1296615212510262,
0.11386621034856693,
0.15203351148564773,
0.16300823766585265,
0.1936726544485426,
0.2227251185536394,
0.18067789917505797,
0.19005874127683337,
0.07462121515702229,
],
}
},
{ # test answers for PyTorch 1.12.1
"integration_classification_2d": {
"losses": [0.776835828070428, 0.1615355300011149, 0.07492854832938523, 0.04591309238865877],
Expand Down

0 comments on commit 66d0478

Please sign in to comment.