diff --git a/monai/data/folder_layout.py b/monai/data/folder_layout.py index d8ce162c278..9a9689a7c7b 100644 --- a/monai/data/folder_layout.py +++ b/monai/data/folder_layout.py @@ -29,7 +29,7 @@ class FolderLayout: layout = FolderLayout( output_dir="/test_run_1/", postfix="seg", - extension=".nii", + extension="nii", makedirs=False) layout.filename(subject="Sub-A", idx="00", modality="T1") # return value: "/test_run_1/Sub-A_seg_00_modality-T1.nii" @@ -95,5 +95,6 @@ def filename(self, subject: PathLike = "subject", idx=None, **kwargs): for k, v in kwargs.items(): full_name += f"_{k}-{v}" if self.ext is not None: - full_name += f"{self.ext}" + ext = f"{self.ext}" + full_name += f"{ext}" if ext.startswith(".") else f".{ext}" return full_name diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 6394e1c8a0e..074f4e22cf5 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -56,7 +56,7 @@ SUPPORTED_WRITERS: Dict = {} -def register_writer(ext_name, *im_writer): +def register_writer(ext_name, *im_writers): """ Register ``ImageWriter``, so that writing a file with filename extension ``ext_name`` could be resolved to a tuple of potentially appropriate ``ImageWriter``. @@ -64,18 +64,20 @@ def register_writer(ext_name, *im_writer): .. code-block:: python - from monai.data import image_writer + from monai.data import register_writer # `MyWriter` must implement `ImageWriter` interface - image_writer.register_writer(".nii", MyWriter) + register_writer("nii", MyWriter) Args: ext_name: the filename extension of the image. As an indexing key, it will be converted to a lower case string. - im_writer: one or multiple ImageWriter classes with high priority ones first. + im_writers: one or multiple ImageWriter classes with high priority ones first. """ fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] existing = look_up_option(fmt, SUPPORTED_WRITERS, default=()) - all_writers = im_writer + existing + all_writers = im_writers + existing SUPPORTED_WRITERS[fmt] = all_writers @@ -93,8 +95,11 @@ def resolve_writer(ext_name, error_if_not_found=True) -> Sequence: if not SUPPORTED_WRITERS: init() fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] avail_writers = [] - for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=SUPPORTED_WRITERS["*"]): + default_writers = SUPPORTED_WRITERS.get("*", ()) + for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers): try: _writer() # this triggers `monai.utils.module.require_pkg` to check the system availability avail_writers.append(_writer) @@ -788,9 +793,9 @@ def init(): """ Initialize the image writer modules according to the filename extension. """ - for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"): + for ext in ("png", "jpg", "jpeg", "bmp", "tiff", "tif"): register_writer(ext, PILWriter) # TODO: test 16-bit - for ext in (".nii.gz", ".nii"): + for ext in ("nii.gz", "nii"): register_writer(ext, NibabelWriter, ITKWriter) - register_writer(".nrrd", ITKWriter, NibabelWriter) - register_writer("*", ITKWriter, NibabelWriter) + register_writer("nrrd", ITKWriter, NibabelWriter) + register_writer("*", ITKWriter, NibabelWriter, ITKWriter) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4f4fbaf464e..46460292b00 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -326,7 +326,7 @@ def __init__( def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ - Set the options for the underlying writer by updating `kwargs` dictionaries. + Set the options for the underlying writer by updating the `self.*_kwargs` dictionaries. The arguments correspond to the following usage: diff --git a/tests/min_tests.py b/tests/min_tests.py index 192d390ded2..426650eb040 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -111,6 +111,7 @@ def run_testsuit(): "test_mlp", "test_nifti_header_revise", "test_nifti_rw", + "test_nifti_saver", "test_occlusion_sensitivity", "test_orientation", "test_orientationd", @@ -120,6 +121,7 @@ def run_testsuit(): "test_pil_reader", "test_plot_2d_or_3d_image", "test_png_rw", + "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", "test_rand_rotate", diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py new file mode 100644 index 00000000000..6855a590412 --- /dev/null +++ b/tests/test_nifti_saver.py @@ -0,0 +1,111 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import torch + +from monai.data import NiftiSaver +from monai.transforms import LoadImage + + +class TestNiftiSaver(unittest.TestCase): + def test_saved_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".nii.gz") + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} + saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)], + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_3d_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_3d_no_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver( + output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False + ) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + img, _ = LoadImage("nibabelreader")(filepath) + self.assertEqual(img.shape, (1, 2, 2, 8)) + + def test_squeeze_end_dims(self): + with tempfile.TemporaryDirectory() as tempdir: + + for squeeze_end_dims in [False, True]: + + saver = NiftiSaver( + output_dir=tempdir, + output_postfix="", + output_ext=".nii.gz", + dtype=np.float32, + squeeze_end_dims=squeeze_end_dims, + ) + + fname = "testfile_squeeze" + meta_data = {"filename_or_obj": fname} + + # 2d image w channel + saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) + + im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) + self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) + self.assertTrue(meta["dim"][0] == im.ndim) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py new file mode 100644 index 00000000000..d8327186439 --- /dev/null +++ b/tests/test_png_saver.py @@ -0,0 +1,76 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +import torch + +from monai.data import PNGSaver + + +class TestPNGSaver(unittest.TestCase): + def test_saved_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_content_three_channel(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".png", scale=255) + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} + saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_content_spatial_size(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)], + "spatial_shape": [(4, 4) for i in range(8)], + } + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_specified_root(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver( + output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test" + ) + + meta_data = { + "filename_or_obj": [os.path.join("test", "testfile" + str(i), "image" + ".jpg") for i in range(8)] + } + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "image", "image" + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + +if __name__ == "__main__": + unittest.main()