Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6136 6146 update the default writer flag #6147

Merged
merged 13 commits into from
Mar 16, 2023
22 changes: 15 additions & 7 deletions monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,14 @@ class ITKWriter(ImageWriter):
output_dtype: DtypeLike = None
channel_dim: int | None

def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool = True, **kwargs):
def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool | None = True, **kwargs):
"""
Args:
output_dtype: output data type.
affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
Set to ``True`` to be consistent with ``NibabelWriter``,
otherwise the affine matrix is assumed already in the ITK convention.
Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.
kwargs: keyword arguments passed to ``ImageWriter``.

The constructor will create ``self.output_dtype`` internally.
Expand Down Expand Up @@ -414,9 +415,12 @@ def set_data_array(
spatial_ndim=kwargs.pop("spatial_ndim", 3),
contiguous=kwargs.pop("contiguous", True),
)
self.channel_dim = (
channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None
) # channel dim is at the end
if channel_dim is None:
self.channel_dim = -1 # after convert_to_channel_last the last dim is the channel
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
else:
self.channel_dim = (
channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None
) # channel dim is at the end

def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options):
"""
Expand Down Expand Up @@ -478,7 +482,7 @@ def create_backend_obj(
channel_dim: int | None = 0,
affine: NdarrayOrTensor | None = None,
dtype: DtypeLike = np.float32,
affine_lps_to_ras: bool = True,
affine_lps_to_ras: bool | None = True,
**kwargs,
):
"""
Expand All @@ -492,14 +496,18 @@ def create_backend_obj(
affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
Set to ``True`` to be consistent with ``NibabelWriter``,
otherwise the affine matrix is assumed already in the ITK convention.
Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.
kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary.

see also:

- https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389

"""
if isinstance(data_array, MetaTensor) and data_array.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) != SpaceKeys.LPS:
affine_lps_to_ras = False # do the converting from LPS to RAS only if the space type is currently LPS.
if isinstance(data_array, MetaTensor) and affine_lps_to_ras is None:
affine_lps_to_ras = (
data_array.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) != SpaceKeys.LPS
) # do the converting from LPS to RAS only if the space type is currently LPS.
data_array = super().create_backend_obj(data_array)
_is_vec = channel_dim is not None
if _is_vec:
Expand Down
21 changes: 2 additions & 19 deletions monai/networks/layers/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,10 @@ class BilateralFilter(torch.autograd.Function):

Args:
input: input tensor.

spatial_sigma: the standard deviation of the spatial blur. Higher values can
hurt performance when not using the approximate method (see fast approx).

color_sigma: the standard deviation of the color blur. Lower values preserve
edges better whilst higher values tend to a simple gaussian spatial blur.

fast approx: This flag chooses between two implementations. The approximate method may
produce artifacts in some scenarios whereas the exact solution may be intolerably
slow for high spatial standard deviations.
Expand All @@ -49,6 +46,7 @@ class BilateralFilter(torch.autograd.Function):

@staticmethod
def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
"""autograd forward"""
ctx.ss = spatial_sigma
ctx.cs = color_sigma
ctx.fa = fast_approx
Expand All @@ -57,6 +55,7 @@ def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):

@staticmethod
def backward(ctx, grad_output):
"""autograd backward"""
spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa
grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx)
return grad_input, None, None, None
Expand All @@ -74,9 +73,7 @@ class PHLFilter(torch.autograd.Function):

Args:
input: input tensor to be filtered.

features: feature tensor used to filter the input.

sigmas: the standard deviations of each feature in the filter.

Returns:
Expand Down Expand Up @@ -112,13 +109,9 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):

Args:
input: input tensor to be filtered.

sigma x: trainable standard deviation of the spatial filter kernel in x direction.

sigma y: trainable standard deviation of the spatial filter kernel in y direction.

sigma z: trainable standard deviation of the spatial filter kernel in z direction.

color sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.

Expand Down Expand Up @@ -198,11 +191,9 @@ class TrainableBilateralFilter(torch.nn.Module):

Args:
input: input tensor to be filtered.

spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard
deviations of the spatial filter kernels. Tuple length must equal the number of
spatial input dimensions.

color_sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.

Expand Down Expand Up @@ -278,15 +269,10 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):

Args:
input: input tensor to be filtered.

guide: guidance image tensor to be used during filtering.

sigma x: trainable standard deviation of the spatial filter kernel in x direction.

sigma y: trainable standard deviation of the spatial filter kernel in y direction.

sigma z: trainable standard deviation of the spatial filter kernel in z direction.

color sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.

Expand Down Expand Up @@ -371,13 +357,10 @@ class TrainableJointBilateralFilter(torch.nn.Module):

Args:
input: input tensor to be filtered.

guide: guidance image tensor to be used during filtering.

spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard
deviations of the spatial filter kernels. Tuple length must equal the number of
spatial input dimensions.

color_sigma: trainable standard deviation of the intensity range kernel. This filter
parameter determines the degree of edge preservation.

Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ
self.meta_kwargs.update(meta_kwargs)
if write_kwargs is not None:
self.write_kwargs.update(write_kwargs)
return self

def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None):
"""
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def __init__(

def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs)
return self

def __call__(self, data):
d = dict(data)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_image_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):
filepath = f"testfile_{ndim}d"
saver = SaveImage(
output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer
)
).set_options(init_kwargs={"affine_lps_to_ras": True})
test_data = MetaTensor(
p(test_data), meta={"filename_or_obj": f"{filepath}{output_ext}", "spatial_shape": test_data.shape}
)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_itk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ def test_rgb(self):
np.testing.assert_allclose(output.shape, (5, 5, 3))
np.testing.assert_allclose(output[1, 1], (5, 5, 4))

def test_no_channel(self):
with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, "testing.nii.gz")
writer = ITKWriter(output_dtype=np.uint8)
writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=None)
writer.write(fname)

output = np.asarray(itk.imread(fname))
np.testing.assert_allclose(output.shape, (4, 3, 4))
np.testing.assert_allclose(output[1, 1], (20, 21, 22, 23))


if __name__ == "__main__":
unittest.main()