Dfp pm seg instance 3 #563

Merged 3 commits on May 17, 2024
ml4h/metrics.py: 2 changes (1 addition & 1 deletion)
@@ -10,7 +10,7 @@
from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error

-#from neurite.tf.losses import Dice
+from neurite.tf.losses import Dice

STRING_METRICS = [
'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae',
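Re-enabling this import makes neurite's Dice loss available to ml4h.metrics. For intuition, here is a minimal soft-Dice sketch in plain TensorFlow (illustrative only; this is not neurite's implementation, which the import above pulls in directly):

```python
import tensorflow as tf

def soft_dice_loss(y_true, y_pred, eps=1e-6):
    # Overlap is computed per label channel over the spatial axes;
    # a perfect segmentation gives dice == 1, so the loss is 1 - dice.
    spatial_axes = tuple(range(1, len(y_pred.shape) - 1))
    intersection = tf.reduce_sum(y_true * y_pred, axis=spatial_axes)
    denominator = tf.reduce_sum(y_true + y_pred, axis=spatial_axes)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - tf.reduce_mean(dice)
```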
ml4h/recipes.py: 60 changes (37 additions & 23 deletions)
@@ -34,9 +34,7 @@
from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription
from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients, concordance_index_censored
-from ml4h.plots import plot_dice, plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
-from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice
-from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
+from ml4h.plots import plot_dice, plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations
from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model
from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival
@@ -140,7 +138,7 @@ def run(args):

except Exception as e:
logging.exception(e)

if args.gcs_cloud_bucket is not None:
save_to_google_cloud(args)

@@ -348,10 +346,14 @@ def option_picker(sample_id, data_descriptions):
valid_ids = list(mrn_df[mrn_df.split == 'valid'].index)
test_ids = list(mrn_df[mrn_df.split == 'test'].index)

-train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg,
-    get_epoch=shuffle_get_epoch)
-valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg,
-    get_epoch=shuffle_get_epoch)
+train_dataset = SampleGetterIterableDataset(
+    sample_ids=list(train_ids), sample_getter=sg,
+    get_epoch=shuffle_get_epoch,
+)
+valid_dataset = SampleGetterIterableDataset(
+    sample_ids=list(valid_ids), sample_getter=sg,
+    get_epoch=shuffle_get_epoch,
+)

num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
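A worked example of the proportional split above (numbers are illustrative, not from this PR): with 96 training steps, 32 validation steps, and 8 workers, the loaders get 6 and 2 workers, and the `or` fallback keeps either count from truncating to zero whenever any workers were requested.

```python
training_steps, validation_steps, num_workers = 96, 32, 8
total_steps = training_steps + validation_steps
num_train_workers = int(training_steps / total_steps * num_workers) or (1 if num_workers else 0)   # -> 6
num_valid_workers = int(validation_steps / total_steps * num_workers) or (1 if num_workers else 0)  # -> 2
# Edge case: with num_workers = 1 both int(...) terms truncate to 0,
# and the `or` fallback assigns 1 worker to each loader instead.
```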
@@ -447,10 +449,14 @@ def option_picker(sample_id, data_descriptions):
valid_ids = list(mrn_df[mrn_df.split == 'valid'].index)
test_ids = list(mrn_df[mrn_df.split == 'test'].index)

-train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg,
-    get_epoch=shuffle_get_epoch)
-valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg,
-    get_epoch=shuffle_get_epoch)
+train_dataset = SampleGetterIterableDataset(
+    sample_ids=list(train_ids), sample_getter=sg,
+    get_epoch=shuffle_get_epoch,
+)
+valid_dataset = SampleGetterIterableDataset(
+    sample_ids=list(valid_ids), sample_getter=sg,
+    get_epoch=shuffle_get_epoch,
+)

num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0)
@@ -483,14 +489,16 @@ def option_picker(sample_id, data_descriptions):
output_data_descriptions=output_dds, # what we want a model to predict from the input data
option_picker=option_picker,
)
-test_dataset = SampleGetterIterableDataset(sample_ids=list(test_ids), sample_getter=test_sg,
-    get_epoch=shuffle_get_epoch)
+test_dataset = SampleGetterIterableDataset(
+    sample_ids=list(test_ids), sample_getter=test_sg,
+    get_epoch=shuffle_get_epoch,
+)

generate_test = TensorMapDataLoader2(
batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out,
dataset=test_dataset,
num_workers=num_train_workers,
-)
+)

y_trues = defaultdict(list)
y_preds = defaultdict(list)
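The two defaultdicts above collect per-output predictions and labels across test batches. A sketch of the intended pattern (the iteration protocol and the model call are assumptions, not shown in this diff):

```python
from collections import defaultdict

y_trues = defaultdict(list)
y_preds = defaultdict(list)
for batch_input, batch_output in generate_test:        # assumed iteration protocol
    predictions = model.predict(batch_input)           # assumed Keras-style model
    for otm, y_pred in zip(args.tensor_maps_out, predictions):
        y_preds[otm.name].append(y_pred)
        y_trues[otm.name].append(batch_output[otm.output_name()])
```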
@@ -518,8 +526,10 @@ def option_picker(sample_id, data_descriptions):
plot_survival(y_preds[otm.name], y_trues[otm.name], f'{otm.name.upper()} Model:{args.id}', otm.days_window)
elif otm.is_categorical():
plot_roc(y_preds[otm.name], y_trues[otm.name], otm.channel_map, f'{otm.name} ROC')
-plot_precision_recall_per_class(y_preds[otm.name], y_trues[otm.name], otm.channel_map,
-    f'{otm.name} Precision Recall')
+plot_precision_recall_per_class(
+    y_preds[otm.name], y_trues[otm.name], otm.channel_map,
+    f'{otm.name} Precision Recall',
+)
elif otm.is_continuous():
plot_scatter(y_preds[otm.name], y_trues[otm.name], f'{otm.name} Scatter')

@@ -561,7 +571,7 @@ def infer_from_dataloader(dataloader, model, tensor_maps_out, max_batches=125000
space_dict[f'{otm.name}_event'].append(str(sick[0]))
space_dict[f'{otm.name}_follow_up'].append(str(follow_up[0]))
for k in target:
-if k in ['MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous' ]:
+if k in ['MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous']:
space_dict[f'{k}'].append(target[k][b].numpy())
elif k in ['datetime']:
space_dict[f'{k}'].append(float_to_datetime(int(target[k][b].numpy())))
@@ -749,13 +759,17 @@ def infer_multimodal_multitask(args):
hd5_path = os.path.join(args.output_folder, args.id, 'inferred_hd5s', f'{sample_id}{TENSOR_EXT}')
os.makedirs(os.path.dirname(hd5_path), exist_ok=True)
with h5py.File(hd5_path, 'a') as hd5:
-hd5.create_dataset(f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]),
-    compression='gzip')
+hd5.create_dataset(
+    f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]),
+    compression='gzip',
+)
if otm.path_prefix == 'ukb_ecg_rest':
for lead in otm.channel_map:
-hd5.create_dataset(f'/ukb_ecg_rest/{lead}/instance_0',
-    data=otm.rescale(y[0, otm.channel_map[lead]]),
-    compression='gzip')
+hd5.create_dataset(
+    f'/ukb_ecg_rest/{lead}/instance_0',
+    data=otm.rescale(y[0, otm.channel_map[lead]]),
+    compression='gzip',
+)
inference_writer.writerow(csv_row)
tensor_paths_inferred.add(tensor_paths[0])
stats['count'] += 1
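The writes above use h5py's standard gzip-compressed datasets. A self-contained sketch of the same pattern, with placeholder paths and arrays (h5py creates intermediate groups automatically for absolute dataset paths):

```python
import h5py
import numpy as np

with h5py.File('example_inferred.hd5', 'a') as hd5:
    # Truth tensor for one output TensorMap (placeholder name and shape).
    hd5.create_dataset('my_output_truth', data=np.zeros((384, 384, 1)), compression='gzip')
    # One rescaled ECG lead written under its instance path (placeholder lead name and values).
    hd5.create_dataset('/ukb_ecg_rest/I/instance_0', data=np.zeros(5000), compression='gzip')
```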
ml4h/tensormap/ukb/mri.py: 26 changes (21 additions & 5 deletions)
@@ -2748,13 +2748,21 @@ def _pad_crop_single_channel(tm, hd5, dependents={}, key_prefix=None):
img,
)

-def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}):
-    if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
-        key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
-    elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
-        key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
+def _pad_crop_single_channel_t1map_b2_instance(tm, hd5, instance):
+    if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' in hd5:
+        key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}'
+    elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}' in hd5:
+        key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_{instance}'
     else:
         raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}')
+    return key_prefix
+
+def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}):
+    key_prefix = _pad_crop_single_channel_t1map_b2_instance(tm, hd5, 2)
     return _pad_crop_single_channel(tm, hd5, dependents, key_prefix)
+
+def _pad_crop_single_channel_t1map_b2_instance_3(tm, hd5, dependents={}):
+    key_prefix = _pad_crop_single_channel_t1map_b2_instance(tm, hd5, 3)
+    return _pad_crop_single_channel(tm, hd5, dependents, key_prefix)

t1map_b2 = TensorMap(
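The refactor above centralizes the two-key fallback lookup so the instance 2 and instance 3 TensorMaps share it. One possible further condensation (a sketch with a hypothetical helper name; it assumes tensor_from_file accepts any callable with the (tm, hd5, dependents) signature):

```python
def _t1map_b2_tensor_from_file(instance):
    # Factory closing over the UKB imaging instance number; returns a
    # callable equivalent to the hand-written wrappers above.
    def _tensor_from_file(tm, hd5, dependents={}):
        key_prefix = _pad_crop_single_channel_t1map_b2_instance(tm, hd5, instance)
        return _pad_crop_single_channel(tm, hd5, dependents, key_prefix)
    return _tensor_from_file

# e.g. tensor_from_file=_t1map_b2_tensor_from_file(3)
```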
@@ -2765,6 +2773,14 @@ def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}):
tensor_from_file=_pad_crop_single_channel_t1map_b2,
)

+t1map_b2_instance3 = TensorMap(
+    'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map',
+    shape=(384, 384, 1),
+    path_prefix='ukb_cardiac_mri',
+    normalization=Standardize(mean=455.81, std=609.50),
+    tensor_from_file=_pad_crop_single_channel_t1map_b2_instance_3,
+)

t1map_pancreas = TensorMap(
'shmolli_192i_pancreas_t1map',
shape=(288, 384, 1),
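Both b2 T1-map TensorMaps share the same normalization constants. Written out as a plain z-score (a sketch; it assumes ml4h's Standardize applies (x - mean) / std):

```python
import numpy as np

def standardize_t1map(img: np.ndarray) -> np.ndarray:
    # Cohort statistics shared by t1map_b2 and t1map_b2_instance3.
    mean, std = 455.81, 609.50
    return (img - mean) / std
```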