diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index b467a5ae991..94b31c164c5 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -414,9 +414,12 @@ def set_data_array( spatial_ndim=kwargs.pop("spatial_ndim", 3), contiguous=kwargs.pop("contiguous", True), ) - self.channel_dim = ( - channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None - ) # channel dim is at the end + if channel_dim is None: + self.channel_dim = -1 # after convert_to_channel_last the last dim is the channel + else: + self.channel_dim = ( + channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None + ) # channel dim is at the end def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options): """ diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index 869ec7b9472..931c74bb4be 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -52,6 +52,16 @@ def test_rgb(self): np.testing.assert_allclose(output.shape, (5, 5, 3)) np.testing.assert_allclose(output[1, 1], (5, 5, 4)) + def test_no_channel(self): + with tempfile.TemporaryDirectory() as tempdir: + fname = os.path.join(tempdir, "testing.nii.gz") + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=None) + writer.write(fname) + + output = np.asarray(itk.imread(fname)) + np.testing.assert_allclose(output.shape, (4, 3, 4)) + np.testing.assert_allclose(output[1, 1], (20, 21, 22, 23)) if __name__ == "__main__": unittest.main()