forked from Project-MONAI/MONAI
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2583 Add DatasetCalculator (Project-MONAI#2616)
* Add DatasetCalculator
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* update docstring
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* use multiprocessing
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* update to use dataset and other places
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* update to support array return
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* update with new testcases and change name
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* update min test
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* update unittest
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
* fix vstack error
  Signed-off-by: Yiheng Wang <vennw@nvidia.com>
- Loading branch information
1 parent
f26f115
commit 86e2a06
Showing
5 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Copyright 2020 - 2021 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from itertools import chain | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from monai.data.dataloader import DataLoader | ||
from monai.data.dataset import Dataset | ||
|
||
|
||
class DatasetSummary:
    """
    This class provides a way to calculate a reasonable output voxel spacing according to
    the input dataset. The achieved values can be used to resample the input in 3d segmentation tasks
    (like using as the `pixdim` parameter in `monai.transforms.Spacingd`).
    In addition, it also supports to count the mean, std, min and max intensities of the input,
    and these statistics are helpful for image normalization
    (like using in `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).

    The algorithm for calculation refers to:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.

    """

    def __init__(
        self,
        dataset: "Dataset",
        image_key: Optional[str] = "image",
        label_key: Optional[str] = "label",
        meta_key_postfix: str = "meta_dict",
        num_workers: int = 0,
        **kwargs,
    ):
        """
        Args:
            dataset: dataset from which to load the data.
            image_key: key name of images (default: ``image``).
            label_key: key name of labels (default: ``label``).
            meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict,
                the meta data is a dictionary object (default: ``meta_dict``).
            num_workers: how many subprocesses to use for data loading.
                ``0`` means that the data will be loaded in the main process (default: ``0``).
            kwargs: other parameters (except batch_size) for DataLoader (this class forces to use ``batch_size=1``).

        """
        self.data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, **kwargs)

        self.image_key = image_key
        self.label_key = label_key
        # Always define `meta_key` (``None`` when there is no image key) so that
        # `collect_meta_data` can raise its own ValueError instead of an AttributeError.
        self.meta_key: Optional[str] = f"{image_key}_{meta_key_postfix}" if image_key else None
        self.all_meta_data: List = []

    def _iter_foreground(self, foreground_threshold):
        """
        Yield ``(image, foreground_voxels)`` for every item produced by the data loader.

        Items are indexed with `image_key`/`label_key` when both are set, otherwise each item is
        unpacked as an ``(image, label)`` pair. `foreground_voxels` are the image values at the
        positions where ``label > foreground_threshold``.
        """
        for data in self.data_loader:
            if self.image_key and self.label_key:
                image, label = data[self.image_key], data[self.label_key]
            else:
                image, label = data
            yield image, image[torch.where(label > foreground_threshold)]

    def collect_meta_data(self):
        """
        This function is used to collect the meta data for all images of the dataset.
        Previously collected meta data is replaced, so calling it repeatedly does not create duplicates.

        Raises:
            ValueError: When `meta_key` is not available (no `image_key` was provided).

        """
        if not self.meta_key:
            raise ValueError("To collect meta data for the dataset, `meta_key` should exist.")

        self.all_meta_data = [data[self.meta_key] for data in self.data_loader]

    def get_target_spacing(
        self, spacing_key: str = "pixdim", anisotropic_threshold: float = 3, percentile: float = 10.0
    ):
        """
        Calculate the target spacing according to all spacings.
        If the target spacing is very anisotropic,
        decrease the spacing value of the maximum axis according to percentile.
        So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". After loading
        with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.

        Args:
            spacing_key: key of spacing in meta data (default: ``pixdim``).
            anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
            percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
                replace that axis.

        Returns:
            A tuple with the target spacing of the three spatial axes.

        Raises:
            ValueError: When `spacing_key` is not in the collected meta data.

        """
        if not self.all_meta_data:
            self.collect_meta_data()
        if spacing_key not in self.all_meta_data[0]:
            raise ValueError("The provided spacing_key is not in self.all_meta_data.")

        # columns 1..3 of every `pixdim` row are used as the spacing of the three spatial axes
        all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy()

        target_spacing = np.median(all_spacings, axis=0)
        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
            largest_axis = np.argmax(target_spacing)
            target_spacing[largest_axis] = np.percentile(all_spacings[:, largest_axis], percentile)

        return tuple(target_spacing)

    def calculate_statistics(self, foreground_threshold: int = 0):
        """
        This function is used to calculate the maximum, minimum, mean and standard deviation of intensities of
        the input dataset. The results are stored in `data_max`, `data_min`, `data_mean` and `data_std`.
        The maximum and minimum are taken over the whole images, the mean and standard deviation
        over the foreground voxels only.

        Args:
            foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter
                is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding
                voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set
                the threshold to ``-1`` (default: ``0``).

        Raises:
            ValueError: When no voxel is selected (empty dataset or no label value above the threshold).

        """
        voxel_sum = torch.as_tensor(0.0)
        voxel_square_sum = torch.as_tensor(0.0)
        voxel_max, voxel_min = [], []
        voxel_ct = 0

        for image, image_foreground in self._iter_foreground(foreground_threshold):
            voxel_max.append(image.max().item())
            voxel_min.append(image.min().item())

            # `numel` counts exactly the elements that are summed below
            voxel_ct += image_foreground.numel()
            voxel_sum += image_foreground.sum()
            voxel_square_sum += torch.square(image_foreground).sum()

        if voxel_ct == 0:
            raise ValueError("No foreground voxel is found, please check the dataset and `foreground_threshold`.")

        self.data_max, self.data_min = max(voxel_max), min(voxel_min)
        self.data_mean = (voxel_sum / voxel_ct).item()
        # E[x^2] - E[x]^2 can become slightly negative because of floating point round-off,
        # clamp it to avoid a NaN standard deviation.
        variance = voxel_square_sum / voxel_ct - self.data_mean ** 2
        self.data_std = torch.sqrt(torch.clamp(variance, min=0.0)).item()

    def calculate_percentiles(
        self,
        foreground_threshold: int = 0,
        sampling_flag: bool = True,
        interval: int = 10,
        min_percentile: float = 0.5,
        max_percentile: float = 99.5,
    ):
        """
        This function is used to calculate the percentiles of intensities (and median) of the input dataset. To get
        the required values, all voxels need to be accumulated. To reduce the memory used, this function can be set
        to accumulate only a part of the voxels. The results are stored in `data_min_percentile`,
        `data_max_percentile` and `data_median`.

        Args:
            foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter
                is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding
                voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set
                the threshold to ``-1`` (default: ``0``).
            sampling_flag: whether to sample only a part of the voxels (default: ``True``).
            interval: the sampling interval for accumulating voxels (default: ``10``).
            min_percentile: minimal percentile (default: ``0.5``).
            max_percentile: maximal percentile (default: ``99.5``).

        Raises:
            ValueError: When `interval` is not positive while sampling, or when no voxel is selected.

        """
        if sampling_flag and interval < 1:
            raise ValueError("`interval` should be a positive integer when `sampling_flag` is True.")

        sampled = []
        for _, image_foreground in self._iter_foreground(foreground_threshold):
            # sample on the tensor first, so that only the kept voxels are converted and stored
            if sampling_flag:
                image_foreground = image_foreground[::interval]
            sampled.append(image_foreground.detach().cpu().numpy().astype(np.float64))

        if not sampled or not any(item.size for item in sampled):
            raise ValueError("No foreground voxel is found, please check the dataset and `foreground_threshold`.")

        all_intensities = np.concatenate(sampled)
        self.data_min_percentile, self.data_max_percentile = np.percentile(
            all_intensities, [min_percentile, max_percentile]
        )
        self.data_median = np.median(all_intensities)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright 2020 - 2021 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import glob | ||
import os | ||
import tempfile | ||
import unittest | ||
|
||
import nibabel as nib | ||
import numpy as np | ||
|
||
from monai.data import Dataset, DatasetSummary, create_test_image_3d | ||
from monai.transforms import LoadImaged | ||
from monai.utils import set_determinism | ||
|
||
|
||
class TestDatasetSummary(unittest.TestCase):
    """Tests for ``DatasetSummary`` on small synthetic NIfTI datasets written to a temporary directory."""

    def _make_summary(self, tempdir, pixdims=None):
        """
        Write five synthetic image/segmentation NIfTI pairs into `tempdir` and return a
        ``DatasetSummary`` built on them.

        Args:
            tempdir: directory that receives the ``img{i}.nii.gz`` / ``seg{i}.nii.gz`` files.
            pixdims: optional list with one spacing triple per pair, written to ``pixdim[1:4]``
                of both headers; when ``None`` the headers are left untouched.
        """
        for i in range(5):
            im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)
            # the image is saved before the segmentation, for every pair
            for prefix, array in (("img", im), ("seg", seg)):
                n = nib.Nifti1Image(array, np.eye(4))
                if pixdims is not None:
                    n.header["pixdim"][1:4] = pixdims[i]
                nib.save(n, os.path.join(tempdir, f"{prefix}{i:d}.nii.gz"))

        train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
        train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
        data_dicts = [
            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
        ]

        dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
        return DatasetSummary(dataset, num_workers=4)

    def test_spacing_intensity(self):
        set_determinism(seed=0)
        # restore the non-deterministic state afterwards, so the fixed seed does not leak into other tests
        self.addCleanup(set_determinism, seed=None)

        with tempfile.TemporaryDirectory() as tempdir:
            calculator = self._make_summary(tempdir)

            target_spacing = calculator.get_target_spacing()
            self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
            calculator.calculate_statistics()
            np.testing.assert_allclose(calculator.data_mean, 0.892599, rtol=1e-5, atol=1e-5)
            np.testing.assert_allclose(calculator.data_std, 0.131731, rtol=1e-5, atol=1e-5)
            calculator.calculate_percentiles(sampling_flag=True, interval=2)
            self.assertEqual(calculator.data_max_percentile, 1.0)
            np.testing.assert_allclose(calculator.data_min_percentile, 0.556411, rtol=1e-5, atol=1e-5)

    def test_anisotropic_spacing(self):
        with tempfile.TemporaryDirectory() as tempdir:
            pixdims = [
                [1.0, 1.0, 5.0],
                [1.0, 1.0, 4.0],
                [1.0, 1.0, 4.5],
                [1.0, 1.0, 2.0],
                [1.0, 1.0, 1.0],
            ]
            calculator = self._make_summary(tempdir, pixdims=pixdims)

            # the median spacing (1, 1, 4) reaches the threshold, so the last axis is
            # replaced by the 20th percentile of its spacings
            target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
            np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))
|
||
|
||
# Run the test cases of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()