diff --git a/docs/source/data.rst b/docs/source/data.rst
index a5c3509fc9c..022f7877d1c 100644
--- a/docs/source/data.rst
+++ b/docs/source/data.rst
@@ -182,6 +182,10 @@ DistributedWeightedRandomSampler
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: monai.data.DistributedWeightedRandomSampler
 
+DatasetSummary
+~~~~~~~~~~~~~~
+.. autoclass:: monai.data.DatasetSummary
+
 Decathlon Datalist
 ~~~~~~~~~~~~~~~~~~
 .. autofunction:: monai.data.load_decathlon_datalist
diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index af42627f5f1..fca170335b3 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -23,6 +23,7 @@
     SmartCacheDataset,
     ZipDataset,
 )
+from .dataset_summary import DatasetSummary
 from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
 from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
 from .image_dataset import ImageDataset
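
With the re-export above in place, the new class becomes reachable from the package root. A quick smoke check (nothing beyond the import path added by this hunk is assumed):

from monai.data import DatasetSummary
from monai.data.dataset_summary import DatasetSummary as _DatasetSummary

# the package-level name and the module-level class are the same object
assert DatasetSummary is _DatasetSummary
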
diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py
new file mode 100644
index 00000000000..a8598eb6c88
--- /dev/null
+++ b/monai/data/dataset_summary.py
@@ -0,0 +1,182 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from itertools import chain
+from typing import List, Optional
+
+import numpy as np
+import torch
+
+from monai.data.dataloader import DataLoader
+from monai.data.dataset import Dataset
+
+
+class DatasetSummary:
+    """
+    This class provides a way to calculate a reasonable output voxel spacing according to
+    the input dataset. The computed values can be used to resample the input in 3d segmentation tasks
+    (for example, as the `pixdim` parameter of `monai.transforms.Spacingd`).
+    In addition, it supports computing the mean, std, min and max intensities of the input,
+    and these statistics are helpful for image normalization
+    (for example, as parameters of `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).
+
+    The calculation algorithm refers to:
+    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_.
+
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        image_key: Optional[str] = "image",
+        label_key: Optional[str] = "label",
+        meta_key_postfix: str = "meta_dict",
+        num_workers: int = 0,
+        **kwargs,
+    ):
+        """
+        Args:
+            dataset: dataset from which to load the data.
+            image_key: key name of images (default: ``image``).
+            label_key: key name of labels (default: ``label``).
+            meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from the dict,
+                the meta data is a dictionary object (default: ``meta_dict``).
+            num_workers: how many subprocesses to use for data loading.
+                ``0`` means that the data will be loaded in the main process (default: ``0``).
+            kwargs: other parameters for `DataLoader`, except `batch_size` (this class forces ``batch_size=1``).
+
+        """
+
+        self.data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, **kwargs)
+
+        self.image_key = image_key
+        self.label_key = label_key
+        # keep `meta_key` as None when no image key is given, so that `collect_meta_data`
+        # raises a clear ValueError instead of an AttributeError
+        self.meta_key = f"{image_key}_{meta_key_postfix}" if image_key else None
+        self.all_meta_data: List = []
+
+    def collect_meta_data(self):
+        """
+        This function is used to collect the meta data for all images of the dataset.
+        """
+        if not self.meta_key:
+            raise ValueError("To collect meta data for the dataset, `image_key` must be provided.")
+
+        for data in self.data_loader:
+            self.all_meta_data.append(data[self.meta_key])
+
+    def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
+        """
+        Calculate the target spacing according to all spacings.
+        If the target spacing is very anisotropic, decrease the spacing value of the maximum axis
+        according to `percentile`.
+        So far, this function only supports NIfTI images, which store spacings in their headers under the
+        key "pixdim". After loading with `monai.data.DataLoader`, "pixdim" is a `torch.Tensor` with size
+        `(batch_size, 8)`.
+
+        Args:
+            spacing_key: key of the spacing in the meta data (default: ``pixdim``).
+            anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
+            percentile: for an anisotropic target spacing, use this percentile of all spacings of the
+                anisotropic axis to replace that axis.
+
+        """
+        if len(self.all_meta_data) == 0:
+            self.collect_meta_data()
+        if spacing_key not in self.all_meta_data[0]:
+            raise ValueError("The provided `spacing_key` is not in `self.all_meta_data`.")
+
+        all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy()
+
+        target_spacing = np.median(all_spacings, axis=0)
+        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
+            largest_axis = np.argmax(target_spacing)
+            target_spacing[largest_axis] = np.percentile(all_spacings[:, largest_axis], percentile)
+
+        return tuple(target_spacing)
+
+    def calculate_statistics(self, foreground_threshold: int = 0):
+        """
+        This function is used to calculate the maximum, minimum, mean and standard deviation of the
+        intensities of the input dataset.
+
+        Args:
+            foreground_threshold: the threshold used to decide whether a voxel belongs to the foreground;
+                the statistics are computed over the selected foreground voxels only. Normally, `label > 0`
+                means the corresponding voxel belongs to the foreground, so to calculate the statistics
+                over whole images, set the threshold to ``-1`` (default: ``0``).
+
+        """
+        voxel_sum = torch.as_tensor(0.0)
+        voxel_square_sum = torch.as_tensor(0.0)
+        voxel_max, voxel_min = [], []
+        voxel_ct = 0
+
+        for data in self.data_loader:
+            if self.image_key and self.label_key:
+                image, label = data[self.image_key], data[self.label_key]
+            else:
+                image, label = data
+
+            voxel_max.append(image.max().item())
+            voxel_min.append(image.min().item())
+
+            image_foreground = image[torch.where(label > foreground_threshold)]
+            voxel_ct += len(image_foreground)
+            voxel_sum += image_foreground.sum()
+            voxel_square_sum += torch.square(image_foreground).sum()
+
+        self.data_max, self.data_min = max(voxel_max), min(voxel_min)
+        self.data_mean = (voxel_sum / voxel_ct).item()
+        # single-pass variance: Var(X) = E[X^2] - E[X]^2
+        self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean ** 2)).item()
+
+    def calculate_percentiles(
+        self,
+        foreground_threshold: int = 0,
+        sampling_flag: bool = True,
+        interval: int = 10,
+        min_percentile: float = 0.5,
+        max_percentile: float = 99.5,
+    ):
+        """
+        This function is used to calculate the percentiles (and the median) of the intensities of the
+        input dataset. To get the required values, all voxels need to be accumulated; to reduce the
+        memory used, this function can be set to accumulate only a subset of the voxels.
+
+        Args:
+            foreground_threshold: the threshold used to decide whether a voxel belongs to the foreground;
+                the statistics are computed over the selected foreground voxels only. Normally, `label > 0`
+                means the corresponding voxel belongs to the foreground, so to calculate the statistics
+                over whole images, set the threshold to ``-1`` (default: ``0``).
+            sampling_flag: whether to sample only a subset of the voxels (default: ``True``).
+            interval: the sampling interval for accumulating voxels (default: ``10``).
+            min_percentile: minimum percentile (default: ``0.5``).
+            max_percentile: maximum percentile (default: ``99.5``).
+
+        """
+        all_intensities = []
+        for data in self.data_loader:
+            if self.image_key and self.label_key:
+                image, label = data[self.image_key], data[self.label_key]
+            else:
+                image, label = data
+
+            intensities = image[torch.where(label > foreground_threshold)].tolist()
+            if sampling_flag:
+                intensities = intensities[::interval]
+            all_intensities.append(intensities)
+
+        all_intensities = list(chain(*all_intensities))
+        self.data_min_percentile, self.data_max_percentile = np.percentile(
+            all_intensities, [min_percentile, max_percentile]
+        )
+        self.data_median = np.median(all_intensities)
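
For orientation, a minimal usage sketch of the new class (not part of the patch): the file names are hypothetical, and `AddChanneld` is included only because `Spacingd` expects channel-first data. It follows the `pixdim` / normalization wiring the class docstring suggests:

from monai.data import Dataset, DatasetSummary
from monai.transforms import AddChanneld, Compose, LoadImaged, NormalizeIntensityd, Spacingd

# hypothetical NIfTI image/label pairs on disk
data_dicts = [{"image": f"img{i}.nii.gz", "label": f"seg{i}.nii.gz"} for i in range(5)]
dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))

summary = DatasetSummary(dataset, num_workers=2)
spacing = summary.get_target_spacing()  # e.g. (1.0, 1.0, 1.8) for anisotropic data
summary.calculate_statistics()          # populates data_mean / data_std / data_min / data_max

# feed the dataset-level values into a preprocessing pipeline
preprocess = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=summary.data_mean, divisor=summary.data_std),
    ]
)
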
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 1cd54f35d09..1f53569cd98 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -111,6 +111,7 @@ def run_testsuit():
         "test_handler_metrics_saver",
         "test_handler_metrics_saver_dist",
         "test_handler_classification_saver_dist",
+        "test_dataset_summary",
         "test_deepgrow_transforms",
         "test_deepgrow_interaction",
         "test_deepgrow_dataset",
diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py
new file mode 100644
index 00000000000..5307bc7e66f
--- /dev/null
+++ b/tests/test_dataset_summary.py
@@ -0,0 +1,90 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import os
+import tempfile
+import unittest
+
+import nibabel as nib
+import numpy as np
+
+from monai.data import Dataset, DatasetSummary, create_test_image_3d
+from monai.transforms import LoadImaged
+from monai.utils import set_determinism
+
+
+class TestDatasetSummary(unittest.TestCase):
+    def test_spacing_intensity(self):
+        set_determinism(seed=0)
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            for i in range(5):
+                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)
+                n = nib.Nifti1Image(im, np.eye(4))
+                nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
+                n = nib.Nifti1Image(seg, np.eye(4))
+                nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
+
+            train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
+            train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
+            data_dicts = [
+                {"image": image_name, "label": label_name}
+                for image_name, label_name in zip(train_images, train_labels)
+            ]
+
+            dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
+
+            calculator = DatasetSummary(dataset, num_workers=4)
+
+            target_spacing = calculator.get_target_spacing()
+            self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
+            calculator.calculate_statistics()
+            np.testing.assert_allclose(calculator.data_mean, 0.892599, rtol=1e-5, atol=1e-5)
+            np.testing.assert_allclose(calculator.data_std, 0.131731, rtol=1e-5, atol=1e-5)
+            calculator.calculate_percentiles(sampling_flag=True, interval=2)
+            self.assertEqual(calculator.data_max_percentile, 1.0)
+            np.testing.assert_allclose(calculator.data_min_percentile, 0.556411, rtol=1e-5, atol=1e-5)
+
+    def test_anisotropic_spacing(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+
+            pixdims = [
+                [1.0, 1.0, 5.0],
+                [1.0, 1.0, 4.0],
+                [1.0, 1.0, 4.5],
+                [1.0, 1.0, 2.0],
+                [1.0, 1.0, 1.0],
+            ]
+            for i in range(5):
+                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)
+                n = nib.Nifti1Image(im, np.eye(4))
+                n.header["pixdim"][1:4] = pixdims[i]
+                nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
+                n = nib.Nifti1Image(seg, np.eye(4))
+                n.header["pixdim"][1:4] = pixdims[i]
+                nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
+
+            train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
+            train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
+            data_dicts = [
+                {"image": image_name, "label": label_name}
+                for image_name, label_name in zip(train_images, train_labels)
+            ]
+
+            dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
+
+            calculator = DatasetSummary(dataset, num_workers=4)
+
+            target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
+            np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))
+
+
+if __name__ == "__main__":
+    unittest.main()
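
A quick check of the expectation in `test_anisotropic_spacing` (a standalone sketch, not part of the patch): the per-axis median gives (1.0, 1.0, 4.0), the 4.0/1.0 spacing ratio reaches `anisotropic_threshold=4.0`, so the z axis is replaced by the 20th percentile of its spacings:

import numpy as np

# z-axis spacings of the five synthetic volumes above
z_spacings = np.array([5.0, 4.0, 4.5, 2.0, 1.0])

median = np.median(z_spacings)  # 4.0 -> candidate target spacing (1.0, 1.0, 4.0)
assert median / 1.0 >= 4.0      # max/min spacing ratio triggers the anisotropic branch

# sorted spacings: [1.0, 2.0, 4.0, 4.5, 5.0]; the 20th percentile falls 0.8 of the
# way between 1.0 and 2.0 under linear interpolation, hence 1.8
print(np.percentile(z_spacings, 20.0))  # 1.8 -> target spacing (1.0, 1.0, 1.8)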