From c4e23dfce4afbbe74e27c62086b3576ea0d2d020 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 8 Feb 2022 14:26:04 +0000 Subject: [PATCH] update saveimage and writer selector Signed-off-by: Wenqi Li --- docs/source/data.rst | 8 + monai/data/__init__.py | 11 +- monai/data/image_reader.py | 4 +- monai/data/image_writer.py | 79 ++++++++- monai/data/nifti_saver.py | 5 + monai/data/nifti_writer.py | 7 +- monai/data/png_saver.py | 6 +- monai/data/png_writer.py | 6 +- monai/transforms/io/array.py | 199 ++++++++++++---------- monai/transforms/io/dictionary.py | 97 +++++------ tests/min_tests.py | 2 - tests/test_handler_segmentation_saver.py | 8 +- tests/test_integration_segmentation_3d.py | 17 +- tests/test_nifti_saver.py | 111 ------------ tests/test_png_saver.py | 76 --------- tests/test_save_image.py | 12 +- tests/test_save_imaged.py | 12 +- 17 files changed, 319 insertions(+), 341 deletions(-) delete mode 100644 tests/test_nifti_saver.py delete mode 100644 tests/test_png_saver.py diff --git a/docs/source/data.rst b/docs/source/data.rst index f2377e29720..2bdf401c7f8 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -153,6 +153,14 @@ WSIReader Image writer ------------ +resolve_writer +~~~~~~~~~~~~~~ +.. autofunction:: resolve_writer + +register_writer +~~~~~~~~~~~~~~~ +.. autofunction:: register_writer + ImageWriter ~~~~~~~~~~~ .. autoclass:: ImageWriter diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 86630ae495b..bed194d2f44 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -35,7 +35,16 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader -from .image_writer import ImageWriter, ITKWriter, NibabelWriter, PILWriter, logger +from .image_writer import ( + SUPPORTED_WRITERS, + ImageWriter, + ITKWriter, + NibabelWriter, + PILWriter, + logger, + register_writer, + resolve_writer, +) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9f0e3f32cf6..0be7feb1e5e 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -18,12 +18,10 @@ from torch.utils.data._utils.collate import np_str_obj_array_pattern from monai.config import DtypeLike, KeysCollection, PathLike -from monai.data.utils import correct_nifti_header_if_necessary +from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg -from .utils import is_supported_format - if TYPE_CHECKING: import itk import nibabel as nib diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 375063e3976..62ffc6c0722 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union import numpy as np @@ -22,6 +22,7 @@ GridSampleMode, GridSamplePadMode, InterpolateMode, + OptionalImportError, convert_data_type, look_up_option, optional_import, @@ -41,7 +42,69 @@ PILImage, _ = optional_import("PIL.Image") -__all__ = ["ImageWriter", "ITKWriter", "NibabelWriter", "PILWriter", "logger"] +__all__ = [ + "ImageWriter", + "ITKWriter", + "NibabelWriter", + "PILWriter", + "SUPPORTED_WRITERS", + "register_writer", + "resolve_writer", + "logger", +] + +SUPPORTED_WRITERS: Dict = {} + + +def register_writer(ext_name, *im_writer): + """ + Register ``ImageWriter``, so that writing a file with filename extension ``ext_name`` + could be resolved to a tuple of potentially appropriate ``ImageWriter``. + The customised writers could be registered by: + + .. code-block:: python + + from monai.data import image_writer + # `MyWriter` must implement `ImageWriter` interface + image_writer.register_writer(".nii", MyWriter) + + Args: + ext_name: the filename extension of the image. + As an indexing key, it will be converted to a lower case string. + im_writer: one or multiple ImageWriter classes with high priority ones first. + """ + fmt = f"{ext_name}".lower() + existing = look_up_option(fmt, SUPPORTED_WRITERS, default=()) + all_writers = im_writer + existing + SUPPORTED_WRITERS[fmt] = all_writers + + +def resolve_writer(ext_name, error_if_not_found=True) -> Sequence: + """ + Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS`` + according to the filename extension key ``ext_name``. + + Args: + ext_name: the filename extension of the image. + As an indexing key it will be converted to a lower case string. + error_if_not_found: whether to raise an error if no suitable image writer is found. + if True , raise an ``OptionalImportError``, otherwise return an empty tuple. Default is ``True``. + """ + if not SUPPORTED_WRITERS: + init() + fmt = f"{ext_name}".lower() + avail_writers = [] + for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=SUPPORTED_WRITERS["*"]): + try: + _writer() # this triggers `monai.utils.module.require_pkg` to check the system availability + avail_writers.append(_writer) + except OptionalImportError: + pass + if not avail_writers and error_if_not_found: + raise OptionalImportError(f"No ImageWriter backend found for {fmt}.") + writer_tuple = ensure_tuple(avail_writers) + SUPPORTED_WRITERS[fmt] = writer_tuple + return writer_tuple class ImageWriter: @@ -716,3 +779,15 @@ def create_backend_obj( data = np.moveaxis(data, 0, 1) return PILImage.fromarray(data, mode=kwargs.pop("image_mode", None)) + + +def init(): + """ + Initialize the image writer modules according to the filename extension. + """ + for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"): + register_writer(ext, PILWriter) # TODO: test 16-bit + for ext in (".nii.gz", ".nii"): + register_writer(ext, NibabelWriter, ITKWriter) + register_writer(".nrrd", ITKWriter, NibabelWriter) + register_writer("*", ITKWriter, NibabelWriter) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index a5acdd032ed..3fdc0aa3e87 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -19,8 +19,10 @@ from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils import deprecated +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class NiftiSaver: """ Save the data as NIfTI file, it can support single data content or a batch of data. @@ -32,6 +34,9 @@ class NiftiSaver: Note: image should include channel dimension: [B],C,H,W,[D]. + .. deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. + """ def __init__( diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index b658121e49b..8a6172955f9 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -19,12 +19,13 @@ from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.transforms.utils_pytorch_numpy_unification import allclose -from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.utils import GridSampleMode, GridSamplePadMode, deprecated, optional_import from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") +@deprecated(since="0.8", msg_suffix="use monai.data.NibabelWriter instead.") def write_nifti( data: NdarrayOrTensor, file_name: str, @@ -98,6 +99,10 @@ def write_nifti( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. + + .. deprecated:: 0.8 + Use :py:meth:`monai.data.NibabelWriter` instead. + """ data, *_ = convert_data_type(data, np.ndarray) affine, *_ = convert_data_type(affine, np.ndarray) diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index a83a560e9fb..9a1ade0efa4 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -18,9 +18,10 @@ from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, look_up_option +from monai.utils import InterpolateMode, deprecated, look_up_option +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class PNGSaver: """ Save the data as png file, it can support single data content or a batch of data. @@ -30,6 +31,9 @@ class PNGSaver: where the input image name is extracted from the provided meta data dictionary. If no meta data provided, use index from 0 as the filename prefix. + .. deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. + """ def __init__( diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 7fcdb7fdb03..5d055369237 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -14,11 +14,12 @@ import numpy as np from monai.transforms.spatial.array import Resize -from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import +from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import Image, _ = optional_import("PIL", name="Image") +@deprecated(since="0.8", msg_suffix="use monai.data.PILWriter instead.") def write_png( data: np.ndarray, file_name: str, @@ -46,6 +47,9 @@ def write_png( Raises: ValueError: When ``scale`` is not one of [255, 65535]. + .. deprecated:: 0.8 + Use :py:meth:`monai.data.PILWriter` instead. + """ if not isinstance(data, np.ndarray): raise ValueError("input data must be numpy array.") diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 0b8b7ba156a..5b9fdbc3c35 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -24,9 +24,9 @@ import torch from monai.config import DtypeLike, PathLike +from monai.data import image_writer +from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.data.nifti_saver import NiftiSaver -from monai.data.png_saver import PNGSaver from monai.transforms.transform import Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key @@ -82,7 +82,7 @@ class LoadImage(Transform): - User-specified reader in the constructor of `LoadImage`. - Readers from the last to the first in the registered list. - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader). + (npz, npy -> NumpyReader), (DICOM file -> ITKReader). See also: @@ -112,7 +112,7 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise. - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. - In this case, it is therefore recommended to set the most appropriate reader as + In this case, it is therefore recommended setting the most appropriate reader as the last item of the `reader` parameter. """ @@ -227,69 +227,59 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option class SaveImage(Transform): """ - Save transformed data into files, support NIfTI and PNG formats. - It can work for both numpy array and PyTorch Tensor in both preprocessing transform - chain and postprocessing transform chain. - The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, - where the input image name is extracted from the provided meta data dictionary. - If no meta data provided, use index from 0 as the filename prefix. - It can also save a list of PyTorch Tensor or numpy array without `batch dim`. + Save the image (in the form of torch tensor or numpy ndarray) and metadata dictionary into files. - Note: image should be channel-first shape: [C,H,W,[D]]. + The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, + where the `input_image_name` is extracted from the provided metadata dictionary. + If no metadata provided, a running index starting from 0 will be used as the filename prefix. Args: output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. - output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. - mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_ext: output file extension name. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - it's used for NIfTI format only. - output_dtype: data type for saving data. Defaults to ``np.float32``. - it's used for NIfTI format only. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. - + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised image writer to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. """ def __init__( @@ -297,55 +287,90 @@ def __init__( output_dir: PathLike = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", + output_dtype: DtypeLike = np.float32, resample: bool = True, mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, dtype: DtypeLike = np.float64, - output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, + output_format: Optional[str] = None, + writer: Optional[image_writer.ImageWriter] = None, ) -> None: - self.saver: Union[NiftiSaver, PNGSaver] - if output_ext in {".nii.gz", ".nii"}: - self.saver = NiftiSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=GridSampleMode(mode), - padding_mode=padding_mode, - dtype=dtype, - output_dtype=output_dtype, - squeeze_end_dims=squeeze_end_dims, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - elif output_ext == ".png": - self.saver = PNGSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=InterpolateMode(mode), - scale=scale, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - else: - raise ValueError(f"unsupported output extension: {output_ext}.") + self.folder_layout = FolderLayout( + output_dir=output_dir, + postfix=output_postfix, + extension=output_ext, + parent=separate_folder, + makedirs=True, + data_root_dir=data_root_dir, + ) + + self.output_ext = output_ext.lower() + self.writers = image_writer.resolve_writer(output_format or self.output_ext) if writer is None else (writer,) + + _output_dtype = output_dtype + if self.output_ext == ".png" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + if self.output_ext == ".dcm" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + self.init_kwargs = {"output_dtype": _output_dtype, "scale": scale} + self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": 0} + self.meta_kwargs = {"resample": resample, "mode": mode, "padding_mode": padding_mode, "dtype": dtype} + self.write_kwargs = {"verbose": print_log} + self._data_index = 0 + + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + """ + Set the options for the underlying writer by updating `kwargs` dictionaries. + + The arguments correspond to the following usage: + + - `writer = ImageWriter(**init_kwargs)` + - `writer.set_data_array(array, **data_kwargs)` + - `writer.set_metadata(meta_data, **meta_kwargs)` + - `writer.write(filename, **write_kwargs)` + + """ + if init_kwargs is not None: + self.init_kwargs.update(init_kwargs) + if data_kwargs is not None: + self.data_kwargs.update(data_kwargs) + if meta_kwargs is not None: + self.meta_kwargs.update(meta_kwargs) + if write_kwargs is not None: + self.write_kwargs.update(write_kwargs) def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): """ Args: - img: target data content that save into file. + img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. meta_data: key-value pairs of meta_data corresponding to the data. - """ - self.saver.save(img, meta_data) + subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None + filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index) - return img + for writer_cls in self.writers: + try: + writer_obj = writer_cls(**self.init_kwargs) + writer_obj.set_data_array(data_array=img, **self.data_kwargs) + writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs) + writer_obj.write(filename, **self.write_kwargs) + except Exception as e: + logging.getLogger(self.__class__.__name__).exception(e, exc_info=True) + logging.getLogger(self.__class__.__name__).info( + f"{writer_cls.__class__.__name__}: unable to write {filename}." + ) + else: + self._data_index += 1 + return img + raise RuntimeError( + f"cannot find a suitable writer for {filename}.\n" + " Please install the reader libraries, see also the installation instructions:\n" + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" + f" The current registered: {self.writers}.\n" + ) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 071db4b5b2b..67c7bd05882 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -21,6 +21,7 @@ import numpy as np from monai.config import DtypeLike, KeysCollection +from monai.data import image_writer from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform @@ -150,68 +151,61 @@ class SaveImaged(MapTransform): Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - meta_keys: explicitly indicate the key of the corresponding meta data dictionary. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None and `key_{postfix}` was used to store the metadata in `LoadImaged`. - need the key to extract metadata to save images, default is `meta_dict`. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, affine, original_shape, etc. - if no corresponding metadata, set to `None`. + meta_keys: explicitly indicate the key of the corresponding metadata dictionary. + For example, for data with key `image`, the metadata by default is in `image_meta_dict`. + The metadata is a dictionary contains values such as filename, original_shape. + This argument can be a sequence of string, map to the `keys`. + If `None`, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if `meta_keys` is `None`, use `key_{meta_key_postfix}` to retrieve the metadict. output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. - mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are: - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - it's used for NIfTI format only. output_dtype: data type for saving data. Defaults to ``np.float32``. it's used for NIfTI format only. allow_missing_keys: don't raise exception if key is missing. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised image writer to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. """ @@ -234,11 +228,13 @@ def __init__( data_root_dir: str = "", separate_folder: bool = True, print_log: bool = True, + output_format: Optional[str] = None, + writer: Optional[image_writer.ImageWriter] = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self._saver = SaveImage( + self.saver = SaveImage( output_dir=output_dir, output_postfix=output_postfix, output_ext=output_ext, @@ -252,15 +248,20 @@ def __init__( data_root_dir=data_root_dir, separate_folder=separate_folder, print_log=print_log, + output_format=output_format, + writer=writer, ) + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs) + def __call__(self, data): d = dict(data) for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): if meta_key is None and meta_key_postfix is not None: meta_key = f"{key}_{meta_key_postfix}" meta_data = d[meta_key] if meta_key is not None else None - self._saver(img=d[key], meta_data=meta_data) + self.saver(img=d[key], meta_data=meta_data) return d diff --git a/tests/min_tests.py b/tests/min_tests.py index 090167c4b11..cfb70387034 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -110,7 +110,6 @@ def run_testsuit(): "test_mlp", "test_nifti_header_revise", "test_nifti_rw", - "test_nifti_saver", "test_occlusion_sensitivity", "test_orientation", "test_orientationd", @@ -120,7 +119,6 @@ def run_testsuit(): "test_pil_reader", "test_plot_2d_or_3d_image", "test_png_rw", - "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", "test_rand_rotate", diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 3632a98cfcb..ee6566f6cbb 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -39,7 +39,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ @@ -65,7 +67,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 718c9291fbe..5c273d0a461 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -21,7 +21,7 @@ from torch.utils.tensorboard import SummaryWriter import monai -from monai.data import NiftiSaver, create_test_image_3d, decollate_batch +from monai.data import create_test_image_3d, decollate_batch from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric from monai.networks import eval_mode @@ -34,6 +34,7 @@ LoadImaged, RandCropByPosNegLabeld, RandRotate90d, + SaveImage, ScaleIntensityd, Spacingd, ToTensor, @@ -213,17 +214,25 @@ def run_inference_test(root_dir, device="cuda:0"): with eval_mode(model): # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 - saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32) + saver = SaveImage( + output_dir=os.path.join(root_dir, "output"), + dtype=np.float32, + output_ext=".nii.gz", + output_postfix="seg", + mode="bilinear", + ) for val_data in val_loader: val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) # define sliding window size and batch size for windows inference sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - # decollate prediction into a list and execute post processing for every item + # decollate prediction into a list val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)] + val_meta = decollate_batch(val_data[PostFix.meta("img")]) # compute metrics dice_metric(y_pred=val_outputs, y=val_labels) - saver.save_batch(val_outputs, val_data[PostFix.meta("img")]) + for img, meta in zip(val_outputs, val_meta): # save a decollated batch of files + saver(img, meta) return dice_metric.aggregate().item() diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py deleted file mode 100644 index 6855a590412..00000000000 --- a/tests/test_nifti_saver.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest -from pathlib import Path - -import numpy as np -import torch - -from monai.data import NiftiSaver -from monai.transforms import LoadImage - - -class TestNiftiSaver(unittest.TestCase): - def test_saved_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".nii.gz") - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} - saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)], - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_3d_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(10, 10, 2)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_3d_no_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver( - output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False - ) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(10, 10, 2)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - img, _ = LoadImage("nibabelreader")(filepath) - self.assertEqual(img.shape, (1, 2, 2, 8)) - - def test_squeeze_end_dims(self): - with tempfile.TemporaryDirectory() as tempdir: - - for squeeze_end_dims in [False, True]: - - saver = NiftiSaver( - output_dir=tempdir, - output_postfix="", - output_ext=".nii.gz", - dtype=np.float32, - squeeze_end_dims=squeeze_end_dims, - ) - - fname = "testfile_squeeze" - meta_data = {"filename_or_obj": fname} - - # 2d image w channel - saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) - - im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) - self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) - self.assertTrue(meta["dim"][0] == im.ndim) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py deleted file mode 100644 index d8327186439..00000000000 --- a/tests/test_png_saver.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest -from pathlib import Path - -import torch - -from monai.data import PNGSaver - - -class TestPNGSaver(unittest.TestCase): - def test_saved_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_content_three_channel(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".png", scale=255) - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_content_spatial_size(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)], - "spatial_shape": [(4, 4) for i in range(8)], - } - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_specified_root(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver( - output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test" - ) - - meta_data = { - "filename_or_obj": [os.path.join("test", "testfile" + str(i), "image" + ".jpg") for i in range(8)] - } - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "image", "image" + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_save_image.py b/tests/test_save_image.py index d3671cf8309..7c703b5220a 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -13,6 +13,7 @@ import tempfile import unittest +import numpy as np import torch from parameterized import parameterized @@ -22,9 +23,18 @@ TEST_CASE_2 = [torch.randint(0, 255, (1, 2, 3, 4)), None, ".nii.gz", False] +TEST_CASE_3 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nrrd"}, ".nrrd", False] + +TEST_CASE_4 = [ + np.random.randint(0, 255, (3, 2, 4, 5), dtype=np.uint8), + {"filename_or_obj": "testfile0.dcm"}, + ".dcm", + False, +] + class TestSaveImage(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_saved_content(self, test_data, meta_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImage( diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index 6f0bb4c2ba2..a6988683e55 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -35,9 +35,19 @@ False, ] +TEST_CASE_3 = [ + { + "img": torch.randint(0, 255, (1, 2, 3, 4)), + PostFix.meta("img"): {"filename_or_obj": "testfile0.nrrd"}, + "patch_index": 6, + }, + ".nrrd", + False, +] + class TestSaveImaged(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_saved_content(self, test_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImaged(