Skip to content

Commit

Permalink
update based on comments
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Feb 10, 2022
1 parent 9292690 commit 67826e8
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 13 deletions.
5 changes: 3 additions & 2 deletions monai/data/folder_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class FolderLayout:
layout = FolderLayout(
output_dir="/test_run_1/",
postfix="seg",
extension=".nii",
extension="nii",
makedirs=False)
layout.filename(subject="Sub-A", idx="00", modality="T1")
# return value: "/test_run_1/Sub-A_seg_00_modality-T1.nii"
Expand Down Expand Up @@ -95,5 +95,6 @@ def filename(self, subject: PathLike = "subject", idx=None, **kwargs):
for k, v in kwargs.items():
full_name += f"_{k}-{v}"
if self.ext is not None:
full_name += f"{self.ext}"
ext = f"{self.ext}"
full_name += f"{ext}" if ext.startswith(".") else f".{ext}"
return full_name
25 changes: 15 additions & 10 deletions monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,28 @@
SUPPORTED_WRITERS: Dict = {}


def register_writer(ext_name, *im_writer):
def register_writer(ext_name, *im_writers):
"""
Register ``ImageWriter``, so that writing a file with filename extension ``ext_name``
could be resolved to a tuple of potentially appropriate ``ImageWriter``.
The customised writers could be registered by:
.. code-block:: python
from monai.data import image_writer
from monai.data import register_writer
# `MyWriter` must implement `ImageWriter` interface
image_writer.register_writer(".nii", MyWriter)
register_writer("nii", MyWriter)
Args:
ext_name: the filename extension of the image.
As an indexing key, it will be converted to a lower case string.
im_writer: one or multiple ImageWriter classes with high priority ones first.
im_writers: one or multiple ImageWriter classes with high priority ones first.
"""
fmt = f"{ext_name}".lower()
if fmt.startswith("."):
fmt = fmt[1:]
existing = look_up_option(fmt, SUPPORTED_WRITERS, default=())
all_writers = im_writer + existing
all_writers = im_writers + existing
SUPPORTED_WRITERS[fmt] = all_writers


Expand All @@ -93,8 +95,11 @@ def resolve_writer(ext_name, error_if_not_found=True) -> Sequence:
if not SUPPORTED_WRITERS:
init()
fmt = f"{ext_name}".lower()
if fmt.startswith("."):
fmt = fmt[1:]
avail_writers = []
for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=SUPPORTED_WRITERS["*"]):
default_writers = SUPPORTED_WRITERS.get("*", ())
for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers):
try:
_writer() # this triggers `monai.utils.module.require_pkg` to check the system availability
avail_writers.append(_writer)
Expand Down Expand Up @@ -788,9 +793,9 @@ def init():
"""
Initialize the image writer modules according to the filename extension.
"""
for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"):
for ext in ("png", "jpg", "jpeg", "bmp", "tiff", "tif"):
register_writer(ext, PILWriter) # TODO: test 16-bit
for ext in (".nii.gz", ".nii"):
for ext in ("nii.gz", "nii"):
register_writer(ext, NibabelWriter, ITKWriter)
register_writer(".nrrd", ITKWriter, NibabelWriter)
register_writer("*", ITKWriter, NibabelWriter)
register_writer("nrrd", ITKWriter, NibabelWriter)
register_writer("*", ITKWriter, NibabelWriter, ITKWriter)
2 changes: 1 addition & 1 deletion monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __init__(

def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
"""
Set the options for the underlying writer by updating `kwargs` dictionaries.
Set the options for the underlying writer by updating the `self.*_kwargs` dictionaries.
The arguments correspond to the following usage:
Expand Down
2 changes: 2 additions & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def run_testsuit():
"test_mlp",
"test_nifti_header_revise",
"test_nifti_rw",
"test_nifti_saver",
"test_occlusion_sensitivity",
"test_orientation",
"test_orientationd",
Expand All @@ -120,6 +121,7 @@ def run_testsuit():
"test_pil_reader",
"test_plot_2d_or_3d_image",
"test_png_rw",
"test_png_saver",
"test_prepare_batch_default",
"test_prepare_batch_extra_input",
"test_rand_rotate",
Expand Down
111 changes: 111 additions & 0 deletions tests/test_nifti_saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from pathlib import Path

import numpy as np
import torch

from monai.data import NiftiSaver
from monai.transforms import LoadImage


class TestNiftiSaver(unittest.TestCase):
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = NiftiSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".nii.gz")

meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]}
saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_saved_resize_content(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32)

meta_data = {
"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)],
"affine": [np.diag(np.ones(4)) * 5] * 8,
"original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
}
saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_saved_3d_resize_content(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32)

meta_data = {
"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
"spatial_shape": [(10, 10, 2)] * 8,
"affine": [np.diag(np.ones(4)) * 5] * 8,
"original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
}
saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_saved_3d_no_resize_content(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = NiftiSaver(
output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False
)

meta_data = {
"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)],
"spatial_shape": [(10, 10, 2)] * 8,
"affine": [np.diag(np.ones(4)) * 5] * 8,
"original_affine": [np.diag(np.ones(4)) * 1.0] * 8,
}
saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz")
img, _ = LoadImage("nibabelreader")(filepath)
self.assertEqual(img.shape, (1, 2, 2, 8))

def test_squeeze_end_dims(self):
with tempfile.TemporaryDirectory() as tempdir:

for squeeze_end_dims in [False, True]:

saver = NiftiSaver(
output_dir=tempdir,
output_postfix="",
output_ext=".nii.gz",
dtype=np.float32,
squeeze_end_dims=squeeze_end_dims,
)

fname = "testfile_squeeze"
meta_data = {"filename_or_obj": fname}

# 2d image w channel
saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data)

im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz"))
self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4)
self.assertTrue(meta["dim"][0] == im.ndim)


if __name__ == "__main__":
unittest.main()
76 changes: 76 additions & 0 deletions tests/test_png_saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from pathlib import Path

import torch

from monai.data import PNGSaver


class TestPNGSaver(unittest.TestCase):
def test_saved_content(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255)

meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]}
saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_saved_content_three_channel(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = PNGSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".png", scale=255)

meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]}
saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_saved_content_spatial_size(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255)

meta_data = {
"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)],
"spatial_shape": [(4, 4) for i in range(8)],
}
saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))

def test_saved_specified_root(self):
with tempfile.TemporaryDirectory() as tempdir:

saver = PNGSaver(
output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test"
)

meta_data = {
"filename_or_obj": [os.path.join("test", "testfile" + str(i), "image" + ".jpg") for i in range(8)]
}
saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data)
for i in range(8):
filepath = os.path.join("testfile" + str(i), "image", "image" + "_seg.png")
self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))


if __name__ == "__main__":
unittest.main()

0 comments on commit 67826e8

Please sign in to comment.