Skip to content

Commit 064788f

Browse files
committed
Add deepgrow dataset
1 parent af1ffd6 commit 064788f

File tree

2 files changed

+326
-0
lines changed

2 files changed

+326
-0
lines changed

monai/apps/deepgrow/dataset.py

+271
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
from typing import Dict, List
15+
16+
import numpy as np
17+
18+
from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
19+
from monai.utils import GridSampleMode
20+
21+
22+
def create_dataset(
23+
datalist,
24+
output_dir: str,
25+
dimension,
26+
pixdim,
27+
image_key: str = "image",
28+
label_key: str = "label",
29+
base_dir: str = None,
30+
limit: int = 0,
31+
relative_path: bool = False,
32+
transforms=None,
33+
) -> List[Dict]:
34+
"""
35+
Utility to pre-process and create dataset list for Deepgrow training over on existing one.
36+
The input data list is normally a list of images and labels (3D volume) that needs pre-processing
37+
for Deepgrow training pipeline.
38+
39+
Args:
40+
datalist: A generic dataset with a length property which normally contains a list of data dictionary.
41+
For example, typical input data can be a list of dictionaries::
42+
43+
[{'image': 'img1.nii', 'label': 'label1.nii'}]
44+
45+
output_dir: target directory to store the training data for Deepgrow Training
46+
pixdim: output voxel spacing.
47+
dimension: dimension for Deepgrow training. It can be 2 or 3.
48+
image_key: image key in input datalist. Defaults to 'image'.
49+
label_key: label key in input datalist. Defaults to 'label'.
50+
base_dir: base directory in case related path is used for the keys in datalist. Defaults to None.
51+
limit: limit number of inputs for pre-processing. Defaults to 0 (no limit).
52+
relative_path: output keys values should be based on relative path. Defaults to False.
53+
transforms: explicit transforms to execute operations on input data.
54+
55+
Raises:
56+
ValueError: When ``dimension`` is not one of [2, 3]
57+
ValueError: When ``datalist`` is Empty
58+
59+
Returns:
60+
A new datalist that contains path to the images/labels after pre-processing.
61+
62+
Example::
63+
64+
datalist = create_dataset(
65+
datalist=[{'image': 'img1.nii', 'label': 'label1.nii'}],
66+
base_dir=None,
67+
output_dir=output_2d,
68+
dimension=2,
69+
image_key='image',
70+
label_key='label',
71+
pixdim=(1.0, 1.0),
72+
limit=0,
73+
relative_path=True
74+
)
75+
76+
print(datalist[0]["image"], datalist[0]["label"])
77+
"""
78+
79+
if dimension not in [2, 3]:
80+
raise ValueError("Dimension can be only 2 or 3 as Deepgrow supports only 2D/3D Training")
81+
82+
if not len(datalist):
83+
raise ValueError("Input datalist is empty")
84+
85+
transforms = _default_transforms(image_key, label_key, pixdim) if transforms is None else transforms
86+
new_datalist = []
87+
for idx in range(len(datalist)):
88+
if limit and idx >= limit:
89+
break
90+
91+
image = datalist[idx][image_key]
92+
label = datalist[idx].get(label_key, None)
93+
if base_dir:
94+
image = os.path.join(base_dir, image)
95+
label = os.path.join(base_dir, label) if label else None
96+
97+
image = os.path.abspath(image)
98+
label = os.path.abspath(label) if label else None
99+
100+
logging.info("Image: {}; Label: {}".format(image, label if label else None))
101+
data = transforms({image_key: image, label_key: label})
102+
if dimension == 2:
103+
data = _save_data_2d(
104+
vol_idx=idx,
105+
vol_image=data[image_key],
106+
vol_label=data[label_key],
107+
dataset_dir=output_dir,
108+
relative_path=relative_path,
109+
)
110+
else:
111+
data = _save_data_3d(
112+
vol_idx=idx,
113+
vol_image=data[image_key],
114+
vol_label=data[label_key],
115+
dataset_dir=output_dir,
116+
relative_path=relative_path,
117+
)
118+
new_datalist.extend(data)
119+
return new_datalist
120+
121+
122+
def _default_transforms(image_key, label_key, pixdim):
123+
keys = [image_key] if label_key is None else [image_key, label_key]
124+
mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR]
125+
return Compose(
126+
[
127+
LoadImaged(keys=keys),
128+
AsChannelFirstd(keys=keys),
129+
Spacingd(keys=keys, pixdim=pixdim, mode=mode),
130+
Orientationd(keys=keys, axcodes="RAS"),
131+
]
132+
)
133+
134+
135+
def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
136+
data_list = []
137+
138+
if len(vol_image.shape) == 4:
139+
logging.info(
140+
"4D-Image, pick only first series; Image: {}; Label: {}".format(
141+
vol_image.shape, vol_label.shape if vol_label else None
142+
)
143+
)
144+
vol_image = vol_image[0]
145+
vol_image = np.moveaxis(vol_image, -1, 0)
146+
147+
image_count = 0
148+
label_count = 0
149+
unique_labels_count = 0
150+
for sid in range(vol_image.shape[0]):
151+
image = vol_image[sid, ...]
152+
label = vol_label[sid, ...] if vol_label is not None else None
153+
154+
if vol_label is not None and np.sum(label) == 0:
155+
continue
156+
157+
image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid)
158+
image_file = os.path.join(dataset_dir, "images", image_file_prefix)
159+
image_file += ".npy"
160+
161+
os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
162+
np.save(image_file, image)
163+
image_count += 1
164+
165+
# Test Data
166+
if vol_label is None:
167+
data_list.append(
168+
{
169+
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
170+
}
171+
)
172+
continue
173+
174+
# For all Labels
175+
unique_labels = np.unique(label.flatten())
176+
unique_labels = unique_labels[unique_labels != 0]
177+
unique_labels_count = max(unique_labels_count, len(unique_labels))
178+
179+
for idx in unique_labels:
180+
label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
181+
label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
182+
label_file += ".npy"
183+
184+
os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
185+
curr_label = (label == idx).astype(np.float32)
186+
np.save(label_file, curr_label)
187+
188+
label_count += 1
189+
data_list.append(
190+
{
191+
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
192+
"label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file,
193+
"region": int(idx),
194+
}
195+
)
196+
197+
logging.info(
198+
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
199+
vol_idx,
200+
vol_image.shape,
201+
image_count,
202+
vol_label.shape if vol_label is not None else None,
203+
label_count,
204+
unique_labels_count,
205+
)
206+
)
207+
return data_list
208+
209+
210+
def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
211+
data_list = []
212+
213+
if len(vol_image.shape) == 4:
214+
logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape))
215+
vol_image = vol_image[0]
216+
vol_image = np.moveaxis(vol_image, -1, 0)
217+
218+
image_count = 0
219+
label_count = 0
220+
unique_labels_count = 0
221+
222+
image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx)
223+
image_file = os.path.join(dataset_dir, "images", image_file_prefix)
224+
image_file += ".npy"
225+
226+
os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
227+
np.save(image_file, vol_image)
228+
image_count += 1
229+
230+
# Test Data
231+
if vol_label is None:
232+
data_list.append(
233+
{
234+
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
235+
}
236+
)
237+
else:
238+
# For all Labels
239+
unique_labels = np.unique(vol_label.flatten())
240+
unique_labels = unique_labels[unique_labels != 0]
241+
unique_labels_count = max(unique_labels_count, len(unique_labels))
242+
243+
for idx in unique_labels:
244+
label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
245+
label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
246+
label_file += ".npy"
247+
248+
curr_label = (vol_label == idx).astype(np.float32)
249+
os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
250+
np.save(label_file, curr_label)
251+
252+
label_count += 1
253+
data_list.append(
254+
{
255+
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
256+
"label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file,
257+
"region": int(idx),
258+
}
259+
)
260+
261+
logging.info(
262+
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
263+
vol_idx,
264+
vol_image.shape,
265+
image_count,
266+
vol_label.shape if vol_label is not None else None,
267+
label_count,
268+
unique_labels_count,
269+
)
270+
)
271+
return data_list

tests/test_deepgrow_dataset.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import os
13+
import tempfile
14+
import unittest
15+
16+
import nibabel as nib
17+
import numpy as np
18+
19+
from monai.apps.deepgrow.dataset import create_dataset
20+
21+
22+
class TestCreateDataset(unittest.TestCase):
23+
def _create_data(self, tempdir):
24+
affine = np.eye(4)
25+
image = np.random.randint(0, 2, size=(128, 128, 40))
26+
image_file = os.path.join(tempdir, "image1.nii.gz")
27+
nib.save(nib.Nifti1Image(image, affine), image_file)
28+
29+
label = np.zeros((128, 128, 40))
30+
label[0][1][0] = 1
31+
label[0][1][1] = 1
32+
label[0][0][2] = 1
33+
label[0][1][2] = 1
34+
label_file = os.path.join(tempdir, "label1.nii.gz")
35+
nib.save(nib.Nifti1Image(label, affine), label_file)
36+
37+
return [{"image": image_file, "label": label_file}]
38+
39+
def test_create_dataset_2d(self):
40+
with tempfile.TemporaryDirectory() as tempdir:
41+
datalist = self._create_data(tempdir)
42+
output_dir = os.path.join(tempdir, "2d")
43+
deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=2, pixdim=(1, 1))
44+
assert len(deepgrow_datalist) == 3 and deepgrow_datalist[0]["region"] == 1
45+
46+
def test_create_dataset_3d(self):
47+
with tempfile.TemporaryDirectory() as tempdir:
48+
datalist = self._create_data(tempdir)
49+
output_dir = os.path.join(tempdir, "3d")
50+
deepgrow_datalist = create_dataset(datalist=datalist, output_dir=output_dir, dimension=3, pixdim=(1, 1, 1))
51+
assert len(deepgrow_datalist) == 1 and deepgrow_datalist[0]["region"] == 1
52+
53+
54+
if __name__ == "__main__":
55+
unittest.main()

0 commit comments

Comments
 (0)