@@ -165,10 +165,11 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform):
165
165
if output_slice is not None :
166
166
component_erosion_steps = original_dataset .get ('component_erosion_steps' , 0 )
167
167
dilation_amount = 1 + component_erosion_steps
168
- dilated_output_slices = tuple ([slice (s .start - dilation_amount , s .stop + dilation_amount , s .step ) for s in output_slice ])
168
+ dilated_output_slices = tuple (slice (s .start - dilation_amount , s .stop + dilation_amount , s .step ) for s in output_slice )
169
+ de_dilation_slices = (Ellipsis ,) + tuple (slice (dilation_amount , - dilation_amount ) for _ in output_slice )
169
170
components , affinities , mask = get_outputs (original_dataset , dilated_output_slices )
170
171
mask_threshold = float (original_dataset .get ('mask_threshold' , 0 ))
171
- mask_fraction_of_this_batch = np .mean (mask )
172
+ mask_fraction_of_this_batch = np .mean (mask [ de_dilation_slices ] )
172
173
good_enough = mask_fraction_of_this_batch > mask_threshold
173
174
if not good_enough :
174
175
return None
@@ -186,7 +187,6 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform):
186
187
affinities = augmented_dilated_dataset ["label" ]
187
188
mask = augmented_dilated_dataset ["mask" ]
188
189
image = augmented_dilated_dataset ["data" ]
189
- de_dilation_slices = (Ellipsis ,) + tuple ([slice (dilation_amount , - dilation_amount ) for _ in output_slice ])
190
190
dataset_numpy ['components' ] = components [de_dilation_slices ]
191
191
dataset_numpy ['label' ] = affinities [de_dilation_slices ]
192
192
dataset_numpy ['mask' ] = mask [de_dilation_slices ]
0 commit comments