Skip to content

Commit 457d7dc

Browse files
committed
Bug fix: Check mask threshold on de-dilated mask array
1 parent f618a99 commit 457d7dc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

data_io/dataset_reading.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,11 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform):
165165
if output_slice is not None:
166166
component_erosion_steps = original_dataset.get('component_erosion_steps', 0)
167167
dilation_amount = 1 + component_erosion_steps
168-
dilated_output_slices = tuple([slice(s.start - dilation_amount, s.stop + dilation_amount, s.step) for s in output_slice])
168+
dilated_output_slices = tuple(slice(s.start - dilation_amount, s.stop + dilation_amount, s.step) for s in output_slice)
169+
de_dilation_slices = (Ellipsis,) + tuple(slice(dilation_amount, -dilation_amount) for _ in output_slice)
169170
components, affinities, mask = get_outputs(original_dataset, dilated_output_slices)
170171
mask_threshold = float(original_dataset.get('mask_threshold', 0))
171-
mask_fraction_of_this_batch = np.mean(mask)
172+
mask_fraction_of_this_batch = np.mean(mask[de_dilation_slices])
172173
good_enough = mask_fraction_of_this_batch > mask_threshold
173174
if not good_enough:
174175
return None
@@ -186,7 +187,6 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform):
186187
affinities = augmented_dilated_dataset["label"]
187188
mask = augmented_dilated_dataset["mask"]
188189
image = augmented_dilated_dataset["data"]
189-
de_dilation_slices = (Ellipsis,) + tuple([slice(dilation_amount, -dilation_amount) for _ in output_slice])
190190
dataset_numpy['components'] = components[de_dilation_slices]
191191
dataset_numpy['label'] = affinities[de_dilation_slices]
192192
dataset_numpy['mask'] = mask[de_dilation_slices]

0 commit comments

Comments
 (0)