Dp kidney #549

Merged Jan 12, 2024 (21 commits)

Changes from all commits
4 changes: 3 additions & 1 deletion ml4h/arguments.py
@@ -263,7 +263,7 @@ def parse_args():
)
parser.add_argument('--balance_csvs', default=[], nargs='*', help='Balances batches with representation from sample IDs in this list of CSVs')
parser.add_argument('--optimizer', default='radam', type=str, help='Optimizer for model training')
parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2'], help='Adjusts learning rate during training.')
parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2', 'cosine_decay'], help='Adjusts learning rate during training.')
parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training')
parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training')
parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value')
@@ -389,6 +389,8 @@ def parse_args():
parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots')
parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing')
parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
parser.add_argument('--intensity_thresh_percentile', type=float, help='Threshold percentile for preprocessing, between 0 and 100 inclusive')
parser.add_argument('--intensity_thresh_k_means', nargs='*', default=[], type=int, help='Preprocessing using k-means, specified as two numbers: the number of clusters and the cluster index to keep')
parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the image has intensity above the threshold')
parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name')

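For readers skimming the flag changes, here is a minimal, self-contained sketch of how the new preprocessing flags parse from a command line. This is a stand-in parser, not ml4h.arguments.parse_args(), and the structure names passed below are purely illustrative.

# Minimal sketch of the new flags only; not the real ml4h argument parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--intensity_thresh', type=float)
parser.add_argument('--intensity_thresh_percentile', type=float)
parser.add_argument('--intensity_thresh_k_means', nargs='*', default=[], type=int)
parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[])
parser.add_argument('--intensity_thresh_out_structure')

args = parser.parse_args([
    '--intensity_thresh_k_means', '2', '1',        # two values: number of clusters, then a cluster index
    '--intensity_thresh_in_structures', 'kidney',  # structure names here are illustrative
    '--intensity_thresh_out_structure', 'body',
])
print(args.intensity_thresh_k_means)  # [2, 1]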
5 changes: 4 additions & 1 deletion ml4h/defines.py
@@ -100,7 +100,10 @@ def __str__(self):
'aortic_root': 7, 'ascending_aorta': 8, 'pulmonary_artery': 9, 'ascending_aortic_wall': 10, 'LVOT': 11,
}
MRI_LIVER_SEGMENTED_CHANNEL_MAP = {'background': 0, 'liver': 1, 'inferior_vena_cava': 2, 'abdominal_aorta': 3, 'body': 4}

MRI_PANCREAS_SEGMENTED_CHANNEL_MAP = {
'background': 0, 'body': 1, 'pancreas': 2, 'liver': 3, 'stomach': 4, 'spleen': 5,
'kidney': 6, 'bowel': 7, 'spine': 8, 'aorta':9, 'ivc': 10,
}

# TODO: These values should ultimately come from the coding table
CODING_VALUES_LESS_THAN_ONE = [-10, -1001]
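As a quick illustration (the structure choices below are hypothetical), the new map is an ordinary name-to-channel dictionary, which is how structure names supplied on the command line are resolved to label channels elsewhere in this PR.

# Illustrative lookup only; the dict mirrors the new MRI_PANCREAS_SEGMENTED_CHANNEL_MAP.
MRI_PANCREAS_SEGMENTED_CHANNEL_MAP = {
    'background': 0, 'body': 1, 'pancreas': 2, 'liver': 3, 'stomach': 4, 'spleen': 5,
    'kidney': 6, 'bowel': 7, 'spine': 8, 'aorta': 9, 'ivc': 10,
}
structures = ['kidney', 'spleen']  # hypothetical --intensity_thresh_in_structures values
channels = [MRI_PANCREAS_SEGMENTED_CHANNEL_MAP[name] for name in structures]
print(channels)  # [6, 5]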
39 changes: 27 additions & 12 deletions ml4h/explorations.py
@@ -20,6 +20,7 @@
import pandas as pd
import multiprocessing as mp
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from tensorflow.keras.models import Model


@@ -719,13 +720,26 @@ def _get_csv_row(sample_id, means, medians, stds, date):
csv_row = [sample_id] + res[0].astype('str').tolist() + [date]
return csv_row

def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig_channels):
def _thresh_labels_above(y, img, intensity_thresh, intensity_thresh_percentile, in_labels, out_label, nb_orig_channels):
y = np.argmax(y, axis=-1)[..., np.newaxis]
y[np.logical_and(img >= intensity_thresh, np.isin(y, in_labels))] = out_label
if intensity_thresh:
img_intensity_thresh = intensity_thresh
elif intensity_thresh_percentile:
img_intensity_thresh = np.percentile(img, intensity_thresh_percentile)
y[np.logical_and(img >= img_intensity_thresh, np.isin(y, in_labels))] = out_label
y = y[..., 0]
y = _to_categorical(y, nb_orig_channels)
return y
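A toy sketch of the percentile branch added above (this is not the ml4h helper itself, and the values are illustrative): when no absolute threshold is supplied, the cutoff is taken from the image's own intensity distribution.

import numpy as np

img = np.arange(100, dtype=float).reshape(10, 10)
cutoff = np.percentile(img, 90)   # 90th-percentile intensity of this particular image
to_relabel = img >= cutoff        # pixels that would be switched to out_label
print(cutoff, to_relabel.sum())   # 89.1 10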

def _intensity_thresh_k_means(y, img, intensity_thresh_k_means):
X = img[y==1][...,np.newaxis]
if X.size > 1:
kmeans = KMeans(n_clusters=intensity_thresh_k_means[0], random_state=0, n_init="auto").fit(X)
labels = kmeans.predict(img.flatten()[...,np.newaxis])
labels = np.reshape(labels, img.shape)
y[np.logical_and(labels==intensity_thresh_k_means[1], y==1)] = 0
return y
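The k-means helper clusters the intensities under a single-structure mask and then removes the pixels that fall in one chosen cluster. A toy sketch of the same idea follows; the array values and two-cluster choice are illustrative, the cluster indices assigned by KMeans are arbitrary (which is why the index to drop is user-specified), and n_init="auto" follows the PR code and needs a recent scikit-learn.

import numpy as np
from sklearn.cluster import KMeans

img = np.array([[0.10, 0.20, 0.90],
                [0.80, 0.15, 0.85]])
y = np.array([[1, 1, 1],
              [1, 0, 1]])  # binary mask of one structure

X = img[y == 1][..., np.newaxis]  # intensities under the mask, shape (n_masked, 1)
kmeans = KMeans(n_clusters=2, random_state=0, n_init="auto").fit(X)
labels = kmeans.predict(img.flatten()[..., np.newaxis]).reshape(img.shape)
y[np.logical_and(labels == 1, y == 1)] = 0  # drop masked pixels assigned to cluster 1
print(y)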

def _scatter_plots_from_segmented_region_stats(
inference_tsv_true, inference_tsv_pred, structures_to_analyze,
output_folder, id, input_name, output_name,
@@ -759,13 +773,9 @@ def _scatter_plots_from_segmented_region_stats(
title = col.replace('_', ' ')
ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
if i == 'all':
min_value = -50
max_value = 1300
elif i == 'filter_outliers':
min_value, max_value = plot_data.min(), plot_data.max()
min_value = min([min_value['true'], min_value['pred']]) - 100
max_value = min([max_value['true'], max_value['pred']]) + 100
min_value, max_value = plot_data.min(), plot_data.max()
min_value = min([min_value['true'], min_value['pred']]) - 100
max_value = min([max_value['true'], max_value['pred']]) + 100
ax.set_xlim([min_value, max_value])
ax.set_ylim([min_value, max_value])
res = stats.pearsonr(plot_data['true'], plot_data['pred'])
@@ -798,7 +808,6 @@ def infer_stats_from_segmented_regions(args):
assert(tm_in.shape[-1] == 1, 'no support here for stats on multiple input channels')

# don't filter datasets for ground truth segmentations if we want to run inference on everything
# TODO HELP - this isn't giving me all 56K anymore
if not args.analyze_ground_truth:
args.output_tensors = []
args.tensor_maps_out = []
@@ -820,6 +829,8 @@
# Setup for intensity thresholding
do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure
if do_intensity_thresh:
assert (not (args.intensity_thresh and args.intensity_thresh_percentile))
assert (not (args.intensity_thresh_k_means and len(args.intensity_thresh_in_structures) > 1))
intensity_thresh_in_channels = [tm_out.channel_map[k] for k in args.intensity_thresh_in_structures]
intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure]

@@ -870,19 +881,23 @@

if args.analyze_ground_truth:
if do_intensity_thresh:
y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_true = np.delete(y_true, bad_channels, axis=-1)
if args.erosion_radius > 0:
y_true = binary_erosion(y_true, structure).astype(y_true.dtype)
if args.intensity_thresh_k_means:
y_true = _intensity_thresh_k_means(y_true, img, args.intensity_thresh_k_means)
means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_channels)
csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date)
inference_writer_true.writerow(csv_row_true)

if do_intensity_thresh:
y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_pred = np.delete(y_pred, bad_channels, axis=-1)
if args.erosion_radius > 0:
y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype)
if args.intensity_thresh_k_means:
y_pred = _intensity_thresh_k_means(y_pred, img, args.intensity_thresh_k_means)
means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels)
csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date)
inference_writer_pred.writerow(csv_row_pred)
8 changes: 7 additions & 1 deletion ml4h/models/layer_wrappers.py
@@ -25,6 +25,7 @@
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, Average, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
from tensorflow.keras.regularizers import L1, L2


Tensor = tf.Tensor
@@ -52,9 +53,14 @@
# class name -> (dimension -> class)
'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
'dropout': defaultdict(lambda _: Dropout),
'l1': L1,
'l2': L2,
}
DENSE_REGULARIZATION_CLASSES = {
'dropout': Dropout, # TODO: add l1, l2
'dropout': Dropout,
'dropout': Dropout,
'l1': L1,
'l2': L2,
}


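For context on the new dictionary entries, the sketch below only shows what the tf.keras L1/L2 classes are and how a Keras layer consumes them; how ml4h threads the dictionary values into its own layers is not shown in this diff, and the 1e-4 rate is illustrative.

from tensorflow.keras.layers import Dense
from tensorflow.keras.regularizers import L1, L2

# A Keras layer accepts an L1 or L2 instance as a weight penalty.
dense_with_l2 = Dense(64, kernel_regularizer=L2(1e-4))
dense_with_l1 = Dense(64, kernel_regularizer=L1(1e-4))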
7 changes: 6 additions & 1 deletion ml4h/models/legacy_models.py
@@ -30,6 +30,7 @@
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
from tensorflow.keras.layers.experimental.preprocessing import RandomRotation, RandomZoom, RandomContrast
from tensorflow.keras.regularizers import L1, L2
import tensorflow_probability as tfp

from ml4h.metrics import get_metric_dict
@@ -79,9 +80,13 @@ class BottleneckType(Enum):
# class name -> (dimension -> class)
'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
'dropout': defaultdict(lambda _: Dropout),
'l1': L1,
'l2': L2,
}
DENSE_REGULARIZATION_CLASSES = {
'dropout': Dropout, # TODO: add l1, l2
'dropout': Dropout,
'l1': L1,
'l2': L2,
}


3 changes: 3 additions & 0 deletions ml4h/optimizers.py
@@ -6,6 +6,7 @@
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow_addons.optimizers import RectifiedAdam, TriangularCyclicalLearningRate, Triangular2CyclicalLearningRate
from tensorflow.keras.optimizers.schedules import CosineDecay

from ml4h.plots import plot_find_learning_rate
from ml4h.tensor_generators import TensorGenerator
@@ -40,6 +41,8 @@ def _get_learning_rate_schedule(learning_rate: float, learning_rate_schedule: st
initial_learning_rate=learning_rate / 5, maximal_learning_rate=learning_rate,
step_size=steps_per_epoch * 5,
)
if learning_rate_schedule == 'cosine_decay':
return CosineDecay(initial_learning_rate=learning_rate, decay_steps=steps_per_epoch)
else:
raise ValueError(f'Learning rate schedule "{learning_rate_schedule}" unknown.')

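A minimal sketch of what the new branch returns and how the schedule object is consumed, assuming a TF 2.x Keras optimizer; the step count and learning rate are placeholders for the values ml4h derives from the training generator.

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import CosineDecay

steps_per_epoch = 500  # placeholder value
schedule = CosineDecay(initial_learning_rate=1e-3, decay_steps=steps_per_epoch)
optimizer = Adam(learning_rate=schedule)  # Keras advances the schedule each training step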
62 changes: 32 additions & 30 deletions ml4h/recipes.py
@@ -220,37 +220,39 @@ def train_multimodal_multitask(args):
if merger:
merger.save(f'{args.output_folder}{args.id}/merger.h5')

test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
performance_metrics = _predict_and_evaluate(
model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
)
performance_metrics = {}
if args.test_steps > 0:
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
performance_metrics = _predict_and_evaluate(
model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
)

predictions_list = model.predict(test_data)
samples = min(args.test_steps * args.batch_size, 12)
out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
if len(args.tensor_maps_out) == 1:
predictions_list = [predictions_list]
predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')

for i, etm in enumerate(encoders):
embed = encoders[etm].predict(test_data[etm.input_name()])
if etm.output_name() in predictions_dict:
plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
for dtm in decoders:
reconstruction = decoders[dtm].predict(embed)
logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
if dtm.axes() > 1:
plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
else:
evaluate_predictions(
dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
)
predictions_list = model.predict(test_data)
samples = min(args.test_steps * args.batch_size, 12)
out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
if len(args.tensor_maps_out) == 1:
predictions_list = [predictions_list]
predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')

for i, etm in enumerate(encoders):
embed = encoders[etm].predict(test_data[etm.input_name()])
if etm.output_name() in predictions_dict:
plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
for dtm in decoders:
reconstruction = decoders[dtm].predict(embed)
logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
if dtm.axes() > 1:
plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
else:
evaluate_predictions(
dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
)
return performance_metrics


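The control-flow change above can be summarized with a small stand-in (function and metric names here are hypothetical, not the ml4h API): when test_steps is zero, no test batch is drawn and an empty metrics dict is returned, so training-only runs exit cleanly.

def train_then_maybe_evaluate(test_steps: int) -> dict:
    performance_metrics = {}
    if test_steps > 0:
        # placeholder for big_batch_from_minibatch_generator + _predict_and_evaluate
        performance_metrics = {'example_metric': 0.0}
    return performance_metrics

assert train_then_maybe_evaluate(0) == {}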
2 changes: 1 addition & 1 deletion ml4h/tensor_generators.py
@@ -95,7 +95,7 @@ def __init__(
:param paths: If weights is provided, paths should be a list of path lists the same length as weights
"""
self.augment = augment
self.paths = sum(paths) if isinstance(paths[0], list) else paths
self.paths = sum(paths) if (len(paths) > 0 and isinstance(paths[0], list)) else paths
self.run_on_main_thread = num_workers == 0
self.q = None
self.stats_q = None
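Why the extra length check matters, as a tiny illustration (the paths value is made up): with an empty paths list, the old expression evaluated paths[0] and raised IndexError before isinstance could run; the new short-circuit guard avoids that.

paths = []  # e.g., no path lists supplied
flat = sum(paths) if (len(paths) > 0 and isinstance(paths[0], list)) else paths
assert flat == []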