ECG Autoencoder PheWAS in the Model Zoo (#514)
* categorical and continuous composite TensorMaps, bump version, add ecg phewas to model zoo
lucidtronix committed Apr 10, 2023
1 parent a0bbf53 commit dfd6a23
Showing 31 changed files with 17,253 additions and 8,657 deletions.
1 change: 1 addition & 0 deletions .gitattributes
```diff
@@ -6,3 +6,4 @@ notebooks/**.ipynb filter=nbstripout
 *.linux filter=lfs diff=lfs merge=lfs -text
 *.osx filter=lfs diff=lfs merge=lfs -text
 *.genes filter=lfs diff=lfs merge=lfs -text
+model_zoo/ECG_PheWAS/*.h5 filter=lfs diff=lfs merge=lfs -text
```
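The new rule routes the bundled model weights through Git LFS, so clones carry lightweight pointers instead of multi-megabyte HDF5 files. A minimal loading sketch, assuming `git lfs pull` has materialized the weights and the checkpoint is a plain Keras HDF5 save; the file name is hypothetical, since only the `*.h5` pattern appears in the diff:

```python
# A minimal sketch, not the model zoo's documented entry point. Assumes
# `git lfs pull` has replaced the LFS pointer with real bytes; real ml4h
# checkpoints may also need custom_objects for custom layers and losses.
import tensorflow as tf

model = tf.keras.models.load_model(
    'model_zoo/ECG_PheWAS/ecg_autoencoder.h5',  # hypothetical file name
    compile=False,  # inference only; skip optimizer/loss restoration
)
model.summary()
```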
2 changes: 1 addition & 1 deletion .github/workflows/docker-publish.yml
```diff
@@ -16,7 +16,7 @@ on:
     #   - v*
 
   # Run tests for any PRs.
-  pull_request:
+  # pull_request:
 
 env:
   IMAGE_NAME: ml4h_terra
```
4 changes: 2 additions & 2 deletions docker/vm_boot_images/config/tensorflow-requirements.txt
```diff
@@ -34,8 +34,8 @@ plotnine
 vega
 ipycanvas>=0.7.0
 ipyannotations>=0.2.1
-torch
+torch==1.12.1
 opencv-python
 blosc
 boto3
-ml4ht==0.0.9
+ml4ht==0.0.10
```
8 changes: 3 additions & 5 deletions ml4h/defines.py
```diff
@@ -58,7 +58,6 @@ def __str__(self):
     'anterior_papillary': 9, 'LV_cavity': 10, 'LA_cavity': 11, 'body': 12,
 }
 LAX_2CH_HEART_LABELS = {
-    'aortic_arch': 1, 'left_pulmonary_artery_wall': 2, 'left_pulmonary_artery': 3,
     'LA_appendage': 4, 'LA_free_wall': 5, 'LV_posterior_wall': 6, 'LV_anterior_wall': 7, 'posterior_papillary': 8,
     'anterior_papillary': 9, 'LV_cavity': 10, 'LA_cavity': 11,
 }
@@ -88,10 +87,9 @@ def __str__(self):
     'RV_cavity': 5, 'thoracic_cavity': 6, 'liver': 7, 'stomach': 8, 'spleen': 9, 'kidney': 11, 'body': 10,
     'left_atrium': 12, 'right_atrium': 13, 'aorta': 14, 'pulmonary_artery': 15,
 }
-MRI_SAX_SEGMENTED_CHANNEL_MAP = {
-    'background': 0, 'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4,
-    'RV_cavity': 5, 'thoracic_cavity': 6, 'liver': 7, 'stomach': 8, 'spleen': 9, 'kidney': 11, 'body': 10,
-    'left_atrium': 12, 'right_atrium': 13, 'aorta': 14, 'pulmonary_artery': 15,
+SAX_HEART_LABELS = {
+    'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4,
+    'RV_cavity': 5, 'left_atrium': 12, 'right_atrium': 13,
 }
 MRI_AO_SEGMENTED_CHANNEL_MAP = {
     'ao_background': 0, 'ao_superior_vena_cava': 1, 'ao_pulmonary_artery': 2, 'ao_ascending_aortic_wall': 3, 'ao_ascending_aorta': 4,
```
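`SAX_HEART_LABELS` keeps only the cardiac structures from the full short-axis channel map, preserving their original channel indices. A sketch of the masking this enables, using a synthetic segmentation array:

```python
# A minimal sketch, assuming a segmentation given as an integer array of channel
# indices from the full SAX channel map; SAX_HEART_LABELS then selects just the
# cardiac structures. The input array here is synthetic stand-in data.
import numpy as np

SAX_HEART_LABELS = {
    'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4,
    'RV_cavity': 5, 'left_atrium': 12, 'right_atrium': 13,
}

segmentation = np.random.randint(0, 16, size=(224, 224))  # stand-in SAX frame
heart_mask = np.isin(segmentation, list(SAX_HEART_LABELS.values()))
print(f'heart pixels: {heart_mask.sum()} / {heart_mask.size}')
```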
8 changes: 7 additions & 1 deletion ml4h/explorations.py
```diff
@@ -1045,7 +1045,13 @@ def explore(args):
 
 def latent_space_dataframe(infer_hidden_tsv, explore_csv):
     df = pd.read_csv(explore_csv)
-    df['sample_id'] = pd.to_numeric(df['sample_id'], errors='coerce')
+    if 'sample_id' in df.columns:
+        id_col = 'sample_id'
+    elif 'fpath' in df.columns:
+        id_col = 'fpath'
+    else:
+        raise ValueError(f'Could not find a sample ID column in explore CSV:{explore_csv}')
+    df['sample_id'] = pd.to_numeric(df[id_col], errors='coerce')
     df2 = pd.read_csv(infer_hidden_tsv, sep='\t', engine='python')
     df2['sample_id'] = pd.to_numeric(df2['sample_id'], errors='coerce')
     latent_df = pd.merge(df, df2, on='sample_id', how='inner')
```
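The explore CSV can now identify samples by either a `sample_id` or an `fpath` column, with a clear error otherwise. A usage sketch with hypothetical file paths, assuming the function returns the merged frame (its tail is outside this hunk):

```python
# A usage sketch; both file paths are hypothetical, and we assume the function
# returns latent_df (the end of the function is not shown in this hunk).
from ml4h.explorations import latent_space_dataframe

latent_df = latent_space_dataframe(
    infer_hidden_tsv='hidden_inference_ecg_autoencoder.tsv',  # hypothetical TSV of latents
    explore_csv='explore_phenotypes.csv',                     # hypothetical explore output
)
print(latent_df.shape)  # one row per sample_id present in both files
```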
4 changes: 2 additions & 2 deletions ml4h/models/pretrained_blocks.py
```diff
@@ -80,8 +80,8 @@ def __init__(
         *,
         tensor_map: TensorMap,
         pretrain_trainable: bool,
-        base_model = "https://tfhub.dev/jeongukjae/roberta_en_cased_L-24_H-1024_A-16/1",
-        preprocess_model="https://tfhub.dev/jeongukjae/roberta_en_cased_preprocess/1",
+        base_model="https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3",
+        preprocess_model="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
         **kwargs,
     ):
         self.tensor_map = tensor_map
```
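The defaults move from a third-party RoBERTa-large mirror to Google's maintained BERT-base endpoints. A minimal compatibility sketch for the new preprocess/encoder pair; this is not ml4h's pretrained block itself, just a check that the two hub models fit together:

```python
# A minimal sketch of wiring the two new TF Hub endpoints together, outside of
# ml4h's own block classes.
import tensorflow as tf
import tensorflow_hub as hub

preprocess = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
encoder = hub.KerasLayer(
    'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',
    trainable=False,  # mirrors pretrain_trainable=False
)

encoder_inputs = preprocess(tf.constant(['Sinus rhythm with first degree AV block']))
outputs = encoder(encoder_inputs)
print(outputs['pooled_output'].shape)  # (1, 768) for BERT-base
```

Note the swap also shrinks the encoder from 24 layers and 1024 hidden units to 12 and 768, so anything sized to the old 1024-dimensional output needs to follow.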
8 changes: 7 additions & 1 deletion ml4h/plots.py
```diff
@@ -1032,6 +1032,9 @@ def plot_survivorship(
     :param title: Title for the plot
     :param prefix: Path prefix where plot will be saved
     :param days_window: Maximum days of follow up
+    :param dpi: Dots per inch of the figure
+    :param width: Width in inches of the figure
+    :param height: Height in inches of the figure
     """
     plt.figure(figsize=(width, height), dpi=dpi)
     days_sorted_index = np.argsort(days_follow_up)
@@ -1100,6 +1103,9 @@ def plot_survival(
     :param title: Title for the plot
     :param days_window: Maximum days of follow up
     :param prefix: Path prefix where plot will be saved
+    :param dpi: Dots per inch of the figure
+    :param width: Width in inches of the figure
+    :param height: Height in inches of the figure
     :return: Dictionary mapping metric names to their floating point values
     """
 
@@ -1109,7 +1115,7 @@ def plot_survival(
     plt.figure(figsize=(width, height), dpi=dpi)
 
     cumulative_sick = np.cumsum(np.sum(truth[:, intervals:], axis=0))
-    cumulative_censored = (truth.shape[0]-np.sum(truth[:, :intervals], axis=0))-cumulative_sick
+    cumulative_censored = (truth.shape[0]-np.sum(truth[:, :intervals], axis=0)) - cumulative_sick
     alive_per_step = np.sum(truth[:, :intervals], axis=0)
     sick_per_step = np.sum(truth[:, intervals:], axis=0)
     survivorship = np.cumprod(1 - (sick_per_step / alive_per_step))
```
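The `survivorship` line is a Kaplan–Meier-style cumulative product of per-interval survival fractions, with `truth` packing `[alive indicators | event indicators]` and `intervals` columns per half. A toy computation under that assumed layout:

```python
# A toy run of the survivorship arithmetic above, assuming the
# [alive-per-interval | event-per-interval] layout implied by the slicing.
import numpy as np

intervals = 3
truth = np.array([
    [1, 1, 1, 0, 0, 0],  # alive through all three intervals, no event
    [1, 1, 0, 0, 1, 0],  # event during the second interval
    [1, 0, 0, 0, 0, 0],  # censored after the first interval
])
alive_per_step = np.sum(truth[:, :intervals], axis=0)  # [3, 2, 1]
sick_per_step = np.sum(truth[:, intervals:], axis=0)   # [0, 1, 0]
survivorship = np.cumprod(1 - (sick_per_step / alive_per_step))
print(survivorship)  # [1.0, 0.5, 0.5]
```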
7 changes: 5 additions & 2 deletions ml4h/tensorize/dataflow/bigquery_ukb_queries.py
```diff
@@ -77,11 +77,14 @@ def write_tensor_from_sql(sampleid_to_rows, output_path, tensor_type):
             hd5.create_dataset('icd', (1,), data=JOIN_CHAR.join(icds), dtype=h5py.special_dtype(vlen=str))
         elif tensor_type == 'categorical':
             for row in rows:
-                hd5_dataset_name = dataset_name_from_meaning('categorical', [row['field'], row['meaning'], str(row['instance']), str(row['array_idx'])])
+                fields = [str(row['fieldid']), row['field'], row['meaning'],
+                          str(row['instance']), str(row['array_idx'])]
+                hd5_dataset_name = dataset_name_from_meaning('categorical', fields)
                 _write_float_or_warn(sample_id, row, hd5_dataset_name, hd5)
         elif tensor_type == 'continuous':
             for row in rows:
-                hd5_dataset_name = dataset_name_from_meaning('continuous', [str(row['fieldid']), row['field'], str(row['instance']), str(row['array_idx'])])
+                fields = [str(row['fieldid']), row['field'], str(row['instance']), str(row['array_idx'])]
+                hd5_dataset_name = dataset_name_from_meaning('continuous', fields)
                 _write_float_or_warn(sample_id, row, hd5_dataset_name, hd5)
         elif tensor_type in ['disease', 'phecode_disease']:
             for row in rows:
```
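Both branches now build an explicit `fields` list, and the categorical path gains the numeric field ID as its leading element, so categorical and continuous dataset names compose the same way. A stand-in illustration (the exact string `dataset_name_from_meaning` returns is not shown here):

```python
# A stand-in example of the new `fields` construction; the row dict mimics one
# BigQuery result row (UK Biobank field 21001 is body mass index).
row = {'fieldid': 21001, 'field': 'Body mass index (BMI)', 'instance': 0, 'array_idx': 0}
fields = [str(row['fieldid']), row['field'], str(row['instance']), str(row['array_idx'])]
print(fields)  # ['21001', 'Body mass index (BMI)', '0', '0']
```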
32 changes: 18 additions & 14 deletions ml4h/tensorize/tensor_writer_ukbb.py
```diff
@@ -786,11 +786,9 @@ def _write_ecg_bike_tensors(ecgs, xml_field, hd5, sample_id, stats):
     for ecg in ecgs:
         root = et.parse(ecg).getroot()
         date = datetime.datetime.strptime(_date_str_from_ecg(root), '%Y-%m-%d')
-        write_to_hd5 = partial(create_tensor_in_hd5, hd5=hd5, path_prefix='ukb_ecg_bike', stats=stats, date=date)
-        logging.info('Got ECG for sample:{} XML field:{}'.format(sample_id, xml_field))
-
         instance = ecg.split(JOIN_CHAR)[-2]
-        write_to_hd5(storage_type=StorageType.STRING, name='instance', value=instance)
+        write_to_hd5 = partial(create_tensor_in_hd5, hd5=hd5, path_prefix='ukb_ecg_bike', instance=instance, stats=stats, date=date)
+        logging.info(f'Got ECG for sample:{sample_id} XML field:{xml_field}')
 
         protocol = root.findall('./Protocol/Phase')[0].find('ProtocolName').text
         write_to_hd5(storage_type=StorageType.STRING, name='protocol', value=protocol)
```
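Folding `instance` into the `functools.partial` tags every subsequent `write_to_hd5` call with the visit instance, without threading the argument through each call. A small sketch of the pattern, with a toy writer standing in for `create_tensor_in_hd5`:

```python
# A small sketch of the functools.partial pattern used above; the writer here
# is a toy stand-in, not ml4h's create_tensor_in_hd5.
from functools import partial

def create_tensor(name, value, path_prefix, instance):
    print(f'{path_prefix}/{name}/instance_{instance} = {value}')

write = partial(create_tensor, path_prefix='ukb_ecg_bike', instance=2)
write(name='protocol', value='Bicycle ramp protocol')  # instance comes along for free
write(name='max_hr', value=[171.0])
```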
```diff
@@ -881,10 +879,13 @@ def _write_ecg_bike_tensors(ecgs, xml_field, hd5, sample_id, stats):
                 if field_val is False:
                     continue
                 trends[lead_field][i, lead_to_int[lead_num]] = field_val
-            trends['time'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("EntryTime/Minute").text) + int(trend_entry.find("EntryTime/Second").text)
-            trends['PhaseTime'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("PhaseTime/Minute").text) + int(trend_entry.find("PhaseTime/Second").text)
-            trends['PhaseName'][i] = phase_to_int[trend_entry.find('PhaseName').text]
-            trends['Artifact'][i] = float(trend_entry.find('Artifact').text.strip('%')) / 100  # Artifact is reported as a percentage
+            try:
+                trends['time'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("EntryTime/Minute").text) + int(trend_entry.find("EntryTime/Second").text)
+                trends['PhaseTime'][i] = SECONDS_PER_MINUTE * int(trend_entry.find("PhaseTime/Minute").text) + int(trend_entry.find("PhaseTime/Second").text)
+                trends['PhaseName'][i] = phase_to_int[trend_entry.find('PhaseName').text]
+                trends['Artifact'][i] = float(trend_entry.find('Artifact').text.strip('%')) / 100  # Artifact is reported as a percentage
+            except AttributeError as e:
+                stats['AttributeError on Trend Data'] += 1
 
         for field, trend_list in trends.items():
             write_to_hd5(name=f'trend_{str.lower(field)}', value=trend_list)
@@ -900,12 +901,15 @@ def _write_ecg_bike_tensors(ecgs, xml_field, hd5, sample_id, stats):
             write_to_hd5(name=f'{str.lower(phase_name)}_duration', value=[phase_duration])
 
         # HR stats
-        max_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxHeartRate')
-        resting_hr = _xml_path_to_float(root, './ExerciseMeasurements/RestingStats/RestHR')
-        max_pred_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxPredictedHR')
-        write_to_hd5(name='max_hr', value=[max_hr])
-        write_to_hd5(name='resting_hr', value=[resting_hr])
-        write_to_hd5(name='max_pred_hr', value=[max_pred_hr])
+        try:
+            max_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxHeartRate')
+            write_to_hd5(name='max_hr', value=[max_hr])
+            resting_hr = _xml_path_to_float(root, './ExerciseMeasurements/RestingStats/RestHR')
+            write_to_hd5(name='resting_hr', value=[resting_hr])
+            max_pred_hr = _xml_path_to_float(root, './ExerciseMeasurements/MaxPredictedHR')
+            write_to_hd5(name='max_pred_hr', value=[max_pred_hr])
+        except AttributeError as e:
+            stats['AttributeError on HR Stats'] += 1
 
 
 def _write_tensors_from_niftis(folder: str, hd5: h5py.File, field_id: str, stats: Counter):
```
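Both remaining hunks adopt the same defensive pattern: `Element.find` returns `None` for a missing XML node, the chained `.text` raises `AttributeError`, and the writer tallies the failure in the `stats` Counter instead of aborting the sample. A self-contained sketch:

```python
# A self-contained sketch of the defensive pattern adopted above: a missing XML
# node makes find(...) return None, the chained .text raises AttributeError,
# and the failure is counted instead of crashing the tensorization run.
from collections import Counter
import xml.etree.ElementTree as et

stats = Counter()
root = et.fromstring('<ExerciseMeasurements><MaxHeartRate>171</MaxHeartRate></ExerciseMeasurements>')
try:
    max_hr = float(root.find('MaxHeartRate').text)
    print('max_hr:', max_hr)
    resting_hr = float(root.find('RestingStats/RestHR').text)  # node absent -> AttributeError
    print('resting_hr:', resting_hr)
except AttributeError:
    stats['AttributeError on HR Stats'] += 1
print(stats)  # Counter({'AttributeError on HR Stats': 1})
```

Interleaving each write with its parse in the HR-stats hunk (rather than parsing all three values first) also means a late failure still preserves the earlier writes.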