Skip to content

Commit

Permalink
update channel_dim None case
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Mar 15, 2023
1 parent 2a4cb8d commit b0ca392
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
9 changes: 6 additions & 3 deletions monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,12 @@ def set_data_array(
spatial_ndim=kwargs.pop("spatial_ndim", 3),
contiguous=kwargs.pop("contiguous", True),
)
self.channel_dim = (
channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None
) # channel dim is at the end
if channel_dim is None:
self.channel_dim = -1 # after convert_to_channel_last the last dim is the channel
else:
self.channel_dim = (
channel_dim if self.data_obj is not None and len(self.data_obj.shape) >= _r else None
) # channel dim is at the end

def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options):
"""
Expand Down
11 changes: 11 additions & 0 deletions tests/test_itk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ def test_rgb(self):
np.testing.assert_allclose(output.shape, (5, 5, 3))
np.testing.assert_allclose(output[1, 1], (5, 5, 4))

def test_no_channel(self):
with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, "testing.nii.gz")
writer = ITKWriter(output_dtype=np.uint8)
writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=None)
writer.write(fname)

output = np.asarray(itk.imread(fname))
np.testing.assert_allclose(output.shape, (4, 3, 4))
np.testing.assert_allclose(output[1, 1], (20, 21, 22, 23))


if __name__ == "__main__":
unittest.main()

0 comments on commit b0ca392

Please sign in to comment.