Skip to content

Commit 1b170d8

Browse files
Add deepgrow dataset (#1581)
* Add deepgrow dataset Signed-off-by: YuanTingHsieh <yuantinghsieh@gmail.com> * Fix CI/CD issue Signed-off-by: YuanTingHsieh <yuantinghsieh@gmail.com> * Fix issues based on review Signed-off-by: YuanTingHsieh <yuantinghsieh@gmail.com>
1 parent 3f16c21 commit 1b170d8

File tree

3 files changed

+339
-0
lines changed

3 files changed

+339
-0
lines changed

monai/apps/deepgrow/dataset.py

+281
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
from typing import Dict, List
15+
16+
import numpy as np
17+
18+
from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
19+
from monai.utils import GridSampleMode
20+
21+
22+
def create_dataset(
    datalist,
    output_dir: str,
    dimension,
    pixdim,
    image_key: str = "image",
    label_key: str = "label",
    base_dir=None,
    limit: int = 0,
    relative_path: bool = False,
    transforms=None,
) -> List[Dict]:
    """
    Utility to pre-process an existing datalist and create the dataset list for Deepgrow training.

    The input datalist is normally a list of images and labels (3D volumes) that need
    pre-processing for the Deepgrow training pipeline.

    Args:
        datalist: A list of data dictionaries. Each entry should at least contain 'image_key': <image filename>.
            For example, typical input data can be a list of dictionaries::

                [{'image': <image filename>, 'label': <label filename>}]

        output_dir: target directory to store the training data for Deepgrow Training.
        pixdim: output voxel spacing.
        dimension: dimension for Deepgrow training. It can be 2 or 3.
        image_key: image key in input datalist. Defaults to 'image'.
        label_key: label key in input datalist. Defaults to 'label'.
        base_dir: base directory in case a relative path is used for the keys in datalist. Defaults to None.
        limit: limit number of inputs for pre-processing. Defaults to 0 (no limit).
        relative_path: output key values should be based on relative path. Defaults to False.
        transforms: explicit transforms to execute operations on input data.

    Raises:
        ValueError: When ``dimension`` is not one of [2, 3]
        ValueError: When ``datalist`` is Empty

    Returns:
        A new datalist that contains paths to the images/labels after pre-processing.

    Example::

        datalist = create_dataset(
            datalist=[{'image': 'img1.nii', 'label': 'label1.nii'}],
            base_dir=None,
            output_dir=output_2d,
            dimension=2,
            image_key='image',
            label_key='label',
            pixdim=(1.0, 1.0),
            limit=0,
            relative_path=True
        )

        print(datalist[0]["image"], datalist[0]["label"])
    """
    if dimension not in [2, 3]:
        raise ValueError("Dimension can be only 2 or 3 as Deepgrow supports only 2D/3D Training")

    if not datalist:
        raise ValueError("Input datalist is empty")

    transforms = _default_transforms(image_key, label_key, pixdim) if transforms is None else transforms
    new_datalist = []
    for idx, entry in enumerate(datalist):
        # honor the pre-processing limit, if one was requested
        if limit and idx >= limit:
            break

        image = entry[image_key]
        label = entry.get(label_key, None)
        if base_dir:
            image = os.path.join(base_dir, image)
            label = os.path.join(base_dir, label) if label else None

        # normalize to absolute paths before loading
        image = os.path.abspath(image)
        label = os.path.abspath(label) if label else None

        # lazy %-style args avoid formatting cost when logging is disabled
        logging.info("Image: %s; Label: %s", image, label)
        data = transforms({image_key: image, label_key: label})
        if dimension == 2:
            data = _save_data_2d(
                vol_idx=idx,
                vol_image=data[image_key],
                vol_label=data[label_key],
                dataset_dir=output_dir,
                relative_path=relative_path,
            )
        else:
            data = _save_data_3d(
                vol_idx=idx,
                vol_image=data[image_key],
                vol_label=data[label_key],
                dataset_dir=output_dir,
                relative_path=relative_path,
            )
        new_datalist.extend(data)
    return new_datalist
120+
121+
122+
def _default_transforms(image_key, label_key, pixdim):
    """Build the default pre-processing pipeline: load, channel-first, re-space, reorient to RAS."""
    if label_key is None:
        keys = [image_key]
        # image only: a single interpolation mode
        mode = [GridSampleMode.BILINEAR]
    else:
        keys = [image_key, label_key]
        # bilinear for the image, nearest for the (discrete) label
        mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST]
    return Compose(
        [
            LoadImaged(keys=keys),
            AsChannelFirstd(keys=keys),
            Spacingd(keys=keys, pixdim=pixdim, mode=mode),
            Orientationd(keys=keys, axcodes="RAS"),
        ]
    )
133+
134+
135+
def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
136+
data_list = []
137+
138+
if len(vol_image.shape) == 4:
139+
logging.info(
140+
"4D-Image, pick only first series; Image: {}; Label: {}".format(
141+
vol_image.shape, vol_label.shape if vol_label else None
142+
)
143+
)
144+
vol_image = vol_image[0]
145+
vol_image = np.moveaxis(vol_image, -1, 0)
146+
147+
image_count = 0
148+
label_count = 0
149+
unique_labels_count = 0
150+
for sid in range(vol_image.shape[0]):
151+
image = vol_image[sid, ...]
152+
label = vol_label[sid, ...] if vol_label is not None else None
153+
154+
if vol_label is not None and np.sum(label) == 0:
155+
continue
156+
157+
image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid)
158+
image_file = os.path.join(dataset_dir, "images", image_file_prefix)
159+
image_file += ".npy"
160+
161+
os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
162+
np.save(image_file, image)
163+
image_count += 1
164+
165+
# Test Data
166+
if vol_label is None:
167+
data_list.append(
168+
{
169+
"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file,
170+
}
171+
)
172+
continue
173+
174+
# For all Labels
175+
unique_labels = np.unique(label.flatten())
176+
unique_labels = unique_labels[unique_labels != 0]
177+
unique_labels_count = max(unique_labels_count, len(unique_labels))
178+
179+
for idx in unique_labels:
180+
label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
181+
label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
182+
label_file += ".npy"
183+
184+
os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
185+
curr_label = (label == idx).astype(np.float32)
186+
np.save(label_file, curr_label)
187+
188+
label_count += 1
189+
data_list.append(
190+
{
191+
"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file,
192+
"label": label_file.replace(dataset_dir + os.pathsep, "") if relative_path else label_file,
193+
"region": int(idx),
194+
}
195+
)
196+
197+
if unique_labels_count >= 20:
198+
logging.warning(f"Unique labels {unique_labels_count} exceeds 20. Please check if this is correct.")
199+
200+
logging.info(
201+
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
202+
vol_idx,
203+
vol_image.shape,
204+
image_count,
205+
vol_label.shape if vol_label is not None else None,
206+
label_count,
207+
unique_labels_count,
208+
)
209+
)
210+
return data_list
211+
212+
213+
def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
214+
data_list = []
215+
216+
if len(vol_image.shape) == 4:
217+
logging.info(
218+
"4D-Image, pick only first series; Image: {}; Label: {}".format(
219+
vol_image.shape, vol_label.shape if vol_label else None
220+
)
221+
)
222+
vol_image = vol_image[0]
223+
vol_image = np.moveaxis(vol_image, -1, 0)
224+
225+
image_count = 0
226+
label_count = 0
227+
unique_labels_count = 0
228+
229+
image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx)
230+
image_file = os.path.join(dataset_dir, "images", image_file_prefix)
231+
image_file += ".npy"
232+
233+
os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
234+
np.save(image_file, vol_image)
235+
image_count += 1
236+
237+
# Test Data
238+
if vol_label is None:
239+
data_list.append(
240+
{
241+
"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file,
242+
}
243+
)
244+
else:
245+
# For all Labels
246+
unique_labels = np.unique(vol_label.flatten())
247+
unique_labels = unique_labels[unique_labels != 0]
248+
unique_labels_count = max(unique_labels_count, len(unique_labels))
249+
250+
for idx in unique_labels:
251+
label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
252+
label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
253+
label_file += ".npy"
254+
255+
curr_label = (vol_label == idx).astype(np.float32)
256+
os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
257+
np.save(label_file, curr_label)
258+
259+
label_count += 1
260+
data_list.append(
261+
{
262+
"image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file,
263+
"label": label_file.replace(dataset_dir + os.pathsep, "") if relative_path else label_file,
264+
"region": int(idx),
265+
}
266+
)
267+
268+
if unique_labels_count >= 20:
269+
logging.warning(f"Unique labels {unique_labels_count} exceeds 20. Please check if this is correct.")
270+
271+
logging.info(
272+
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
273+
vol_idx,
274+
vol_image.shape,
275+
image_count,
276+
vol_label.shape if vol_label is not None else None,
277+
label_count,
278+
unique_labels_count,
279+
)
280+
)
281+
return data_list

tests/min_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def run_testsuit():
104104
"test_handler_metrics_saver_dist",
105105
"test_evenly_divisible_all_gather_dist",
106106
"test_handler_classification_saver_dist",
107+
"test_deepgrow_dataset",
107108
]
108109
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
109110

tests/test_deepgrow_dataset.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import os
13+
import tempfile
14+
import unittest
15+
16+
import nibabel as nib
17+
import numpy as np
18+
19+
from monai.apps.deepgrow.dataset import create_dataset
20+
21+
22+
class TestCreateDataset(unittest.TestCase):
    """Tests for ``monai.apps.deepgrow.dataset.create_dataset`` in 2D and 3D modes."""

    def _create_data(self, tempdir):
        """Write one random image and one sparse label NIfTI pair; return the datalist."""
        identity = np.eye(4)

        img_path = os.path.join(tempdir, "image1.nii.gz")
        img_array = np.random.randint(0, 2, size=(128, 128, 40))
        nib.save(nib.Nifti1Image(img_array, identity), img_path)

        # annotate four voxels spread over three z-slices (z in {0, 1, 2})
        lbl_array = np.zeros((128, 128, 40))
        lbl_array[0, 1, 0] = 1
        lbl_array[0, 1, 1] = 1
        lbl_array[0, 0, 2] = 1
        lbl_array[0, 1, 2] = 1
        lbl_path = os.path.join(tempdir, "label1.nii.gz")
        nib.save(nib.Nifti1Image(lbl_array, identity), lbl_path)

        return [{"image": img_path, "label": lbl_path}]

    def test_create_dataset_2d(self):
        with tempfile.TemporaryDirectory() as tempdir:
            datalist = self._create_data(tempdir)
            result = create_dataset(
                datalist=datalist,
                output_dir=os.path.join(tempdir, "2d"),
                dimension=2,
                pixdim=(1, 1),
            )
            # three annotated slices, each with one labelled region
            self.assertEqual(len(result), 3)
            self.assertEqual(result[0]["region"], 1)

    def test_create_dataset_3d(self):
        with tempfile.TemporaryDirectory() as tempdir:
            datalist = self._create_data(tempdir)
            result = create_dataset(
                datalist=datalist,
                output_dir=os.path.join(tempdir, "3d"),
                dimension=3,
                pixdim=(1, 1, 1),
            )
            # one volume with one labelled region
            self.assertEqual(len(result), 1)
            self.assertEqual(result[0]["region"], 1)
54+
55+
56+
# Allow running this test module directly (e.g. `python test_deepgrow_dataset.py`).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)