ENH: Combine 2 paps after postprocessing
daniellepace committed Jul 29, 2024
1 parent b1f84f0 commit 785ecf4
Showing 2 changed files with 41 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ml4h/arguments.py
@@ -386,7 +386,7 @@ def parse_args():

     # Arguments for explorations/infer_stats_from_segmented_regions
     parser.add_argument('--analyze_ground_truth', default=False, action='store_true', help='Whether or not to filter by images with ground truth segmentations, for comparison')
-    parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots. Must be in the same order as the output channel map.')
+    parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots. Must be in the same order as the output channel map. Use + to merge structures before postprocessing, and ++ to merge structures after postprocessing.')
     parser.add_argument('--erosion_radius', nargs='*', default=[], type=int, help='Radius of the unit disk structuring element for erosion preprocessing, optionally as a list per structure to analyze')
     parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
     parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the image has intensity above the threshold')
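
For illustration, a minimal sketch of how the '+' and '++' tokens are interpreted (the structure names below are hypothetical, not from this repository):

# Sketch only: made-up structure names, mirroring the token rules described above.
structures = ['pap1', 'pap2', 'pap1+pap2', 'pap1++pap2']

uni = [s for s in structures if '+' not in s]                   # plain channels
merged = [s for s in structures if '+' in s and '++' not in s]  # merged before postprocessing
merged_after = [s for s in structures if '++' in s]             # merged after postprocessing

print(uni)           # ['pap1', 'pap2']
print(merged)        # ['pap1+pap2']
print(merged_after)  # ['pap1++pap2']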
56 changes: 40 additions & 16 deletions ml4h/explorations.py
@@ -701,7 +701,7 @@ def infer_with_pixels(args):
logging.info(f"Wrote:{stats['count']} rows of inference. Last tensor:{tensor_paths[0]}")


def _compute_masked_stats(img, y):
def _compute_masked_stats(img, y, merged_after_channels):
nb_classes = y.shape[-1]
img = np.tile(img, nb_classes)
melt_shape = (img.shape[0], img.shape[1] * img.shape[2], img.shape[3])
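
The rest of this function is not shown in the diff, so how merged_after_channels is consumed is an assumption; one plausible reading, given the commit title, is that the postprocessed masks of the listed channels are OR-ed into an extra channel before the per-channel stats are taken:

import numpy as np

# Sketch under assumptions: the shapes of 'img' and 'y' are illustrative, and
# the OR-merge below is a guess at how '++' channels could be appended.
img = np.random.rand(64, 64)         # hypothetical grayscale image
y = np.random.rand(64, 64, 3) > 0.5  # hypothetical per-channel masks

merged_after_channels = [[0, 2]]     # e.g. merge postprocessed channels 0 and 2
for chans in merged_after_channels:
    combined = np.any(y[..., chans], axis=-1, keepdims=True)
    y = np.concatenate([y, combined], axis=-1)

# Per-channel masked medians: one value per channel, including the merged one
medians = [np.median(img[y[..., c]]) for c in range(y.shape[-1])]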
@@ -790,13 +790,12 @@ def _intensity_thresh_auto(
     return (bins[np.where(pred == 1)[0][-1]][0] + bins[np.where(pred == 0)[0][0]][0]) / 2

 def _scatter_plots_from_segmented_region_stats(
-    inference_tsv_true, inference_tsv_pred, structures_to_analyze,
-    output_folder, id, input_name, output_name,
+    inference_tsv_true, inference_tsv_pred, output_folder, id, input_name, output_name,
 ):
     df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
     df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)

-    results_to_plot = [f'{s}_median' for s in structures_to_analyze]
+    results_to_plot = [c for c in df_pred.columns if 'median' in c]
     for col in results_to_plot:
         for i in ['all', 'filter_outliers']:  # Two types of plots
             plot_data = pd.concat(
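
A small illustration of the new column discovery (column names hypothetical): because the plot columns are now read from the prediction TSV itself rather than from structures_to_analyze, the '+' and '++' composites are picked up without extra plumbing:

import pandas as pd

# Sketch only: hypothetical TSV columns.
df_pred = pd.DataFrame(columns=['sample_id', 'pap1_median', 'pap2_median', 'pap1++pap2_median', 'mri_date'])
results_to_plot = [c for c in df_pred.columns if 'median' in c]
print(results_to_plot)  # ['pap1_median', 'pap2_median', 'pap1++pap2_median']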
@@ -864,28 +863,50 @@ def infer_stats_from_segmented_regions(args):
     _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__)
     model, _, _, _ = make_multimodal_multitask_model(**args.__dict__)

-    # the user can use '+' to create a new channel that merges other channels
-    # those channels must be included alone in structures_to_analyze as well, and must be at the end of the list
-    merged_locs = [k for k in range(len(args.structures_to_analyze)) if '+' in args.structures_to_analyze[k]]
+    # the user can use '+' to create a new channel that merges other channels before postprocessing
+    # the user can use '++' to create a new channel that merges other channels after postprocessing
+    # the component channels must also be listed alone in structures_to_analyze, followed by the '+' entries, then the '++' entries
+    merged_locs = [
+        k for k in range(len(args.structures_to_analyze))
+        if '+' in args.structures_to_analyze[k] and '++' not in args.structures_to_analyze[k]
+    ]
+    merged_after_locs = [
+        k for k in range(len(args.structures_to_analyze))
+        if '++' in args.structures_to_analyze[k]
+    ]
     uni_locs = [k for k in range(len(args.structures_to_analyze)) if '+' not in args.structures_to_analyze[k]]
     assert((len(merged_locs) == 0) or (merged_locs[0] > uni_locs[-1]))
+    assert((len(merged_after_locs) == 0) or (merged_after_locs[0] > uni_locs[-1]))
     merged_structures = [args.structures_to_analyze[k] for k in merged_locs]
+    merged_after_structures = [args.structures_to_analyze[k] for k in merged_after_locs]
     uni_structures = [args.structures_to_analyze[k] for k in uni_locs]
     merged_channels = [k.split('+') for k in merged_structures]
+    merged_after_channels = [k.split('++') for k in merged_after_structures]
     for i in range(len(merged_channels)):
         for j in range(len(merged_channels[i])):
             merged_channels[i][j] = tm_out.channel_map[merged_channels[i][j]]
+    for i in range(len(merged_after_channels)):
+        for j in range(len(merged_after_channels[i])):
+            merged_after_channels[i][j] = tm_out.channel_map[merged_after_channels[i][j]]

     # structures have to be in the same order as the channel map
     good_channels = [tm_out.channel_map[k] for k in uni_structures]
     assert (good_channels == sorted(good_channels))
-    good_structures = [[k for k in tm_out.channel_map.keys() if tm_out.channel_map[k] == v][0] for v in good_channels] + merged_structures
+    title_structures = [[k for k in tm_out.channel_map.keys() if tm_out.channel_map[k] == v][0] for v in good_channels] \
+        + merged_structures + merged_after_structures
     nb_orig_channels = len(tm_out.channel_map)
     nb_out_channels = len(good_channels) + len(merged_structures)
     bad_channels = [k for k in range(nb_orig_channels) if k not in good_channels]
     for m in merged_channels:
         for c in m:
             assert(c in good_channels)
+    for m in merged_after_channels:
+        for c in m:
+            assert(c in good_channels)
+    # Remap the '++' channel indices to their positions after the bad channels are dropped
+    for i in range(len(merged_after_channels)):
+        for j in range(len(merged_after_channels[i])):
+            merged_after_channels[i][j] = good_channels.index(merged_after_channels[i][j])

     # Structuring element used for the erosion
     if len(args.erosion_radius) > 0:
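
The final remapping step above is subtle: the '++' members start as indices into the original channel map, but after the bad channels are dropped they must be addressed by their position among the kept channels. A toy example with made-up indices:

# Sketch only: made-up channel indices.
good_channels = [1, 3, 4]         # kept original-channel-map indices, sorted
merged_after_channels = [[3, 4]]  # '++' members, in original channel-map indices

remapped = [[good_channels.index(c) for c in m] for m in merged_after_channels]
print(remapped)  # [[1, 2]] -- positions within the kept channels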
@@ -933,11 +954,11 @@
         inference_writer_pred = csv.writer(inference_file_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)

         header = ['sample_id']
-        header += [f'{k}_mean' for k in good_structures]
-        header += [f'{k}_median' for k in good_structures]
-        header += [f'{k}_std' for k in good_structures]
-        header += [f'{k}_iqr' for k in good_structures]
-        header += [f'{k}_count' for k in good_structures]
+        header += [f'{k}_mean' for k in title_structures]
+        header += [f'{k}_median' for k in title_structures]
+        header += [f'{k}_std' for k in title_structures]
+        header += [f'{k}_iqr' for k in title_structures]
+        header += [f'{k}_count' for k in title_structures]
         header += ['mri_date']
         inference_writer_true.writerow(header)
         inference_writer_pred.writerow(header)
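
For a concrete picture of the header this writes (structure names hypothetical), note that columns are grouped by statistic rather than by structure, with the '+' and '++' composites appended after the plain channels:

# Sketch only: hypothetical title_structures, reproducing the header layout above.
title_structures = ['pap1', 'pap2', 'pap1+pap2', 'pap1++pap2']
header = ['sample_id']
for stat in ('mean', 'median', 'std', 'iqr', 'count'):
    header += [f'{k}_{stat}' for k in title_structures]
header += ['mri_date']
# header: sample_id, pap1_mean, ..., pap1++pap2_mean, pap1_median, ..., mri_date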
@@ -1000,7 +1021,11 @@ def postprocess_seg_and_write_stats(y, inference_writer):
             y[...,i] = binary_erosion(y[...,i], structures[i]).astype(y.dtype)
         assert(y.shape[-1] == nb_out_channels)

-        means, medians, stds, iqrs, counts = _compute_masked_stats(rescaled_img, y)
+        means, medians, stds, iqrs, counts = _compute_masked_stats(rescaled_img, y, merged_after_channels)
         csv_row = _get_csv_row(sample_id, means, medians, stds, iqrs, counts, date)
         inference_writer.writerow(csv_row)
         return y
@@ -1025,8 +1050,7 @@ def postprocess_seg_and_write_stats(y, inference_writer):
     # Scatter plots
     if args.analyze_ground_truth:
         _scatter_plots_from_segmented_region_stats(
-            inference_tsv_true, inference_tsv_pred, args.structures_to_analyze,
-            args.output_folder, args.id, tm_in.input_name(), tm_out.output_name(),
+            inference_tsv_true, inference_tsv_pred, args.output_folder, args.id, tm_in.input_name(), tm_out.output_name(),
         )

     # pngs
