Binary file added .DS_Store
Binary file not shown.
Binary file added lyscripts/.DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions lyscripts/data/clean.py
@@ -102,6 +102,8 @@ class that is later supposed to load the data.

if class_name == "MidlineBilateral":
diagnostic_data[("info", "tumor", "midline_extension")] = midline_extension_data
if class_name == "MidlineBilateraltime":
diagnostic_data[("info", "tumor", "midline_extension")] = midline_extension_data
elif class_name == "Unilateral":
diagnostic_data = diagnostic_data.drop(columns=["contra"], level=1)
diagnostic_data.columns = diagnostic_data.columns.droplevel(1)
2 changes: 1 addition & 1 deletion lyscripts/data/generate.py
@@ -96,7 +96,7 @@ def main(args: argparse.Namespace):
"""
params = load_yaml_params(args.params)
model = create_model_from_config(params)
ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric
ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + 1

if args.set_theta is not None:
with report.status("Assign given parameters to model..."):
2 changes: 1 addition & 1 deletion lyscripts/evaluate.py
@@ -156,7 +156,7 @@ def main(args: argparse.Namespace):

params = load_yaml_params(args.params)
model = create_model_from_config(params)
ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric
ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + 1
is_uni = isinstance(model, lymph.Unilateral)

data = load_data_for_model(args.data, header_rows=[0,1] if is_uni else [0,1,2])
29 changes: 28 additions & 1 deletion lyscripts/plot/corner.py
@@ -55,7 +55,7 @@ def _add_arguments(parser: argparse.ArgumentParser):


def get_param_labels(
model: Union[lymph.Unilateral, lymph.Bilateral, lymph.MidlineBilateral],
model: Union[lymph.Unilateral, lymph.Bilateral, lymph.MidlineBilateral, lymph.MidlineBilateraltime],
) -> List[str]:
"""Create labels from a `model`.

@@ -102,12 +102,39 @@ def get_param_labels(
"mixing $\\alpha$",
*trans_labels,
*binom_labels,
"midext_prob",
] if model.use_mixing else [
*base_ipsi_labels,
*base_contra_ext_labels,
*base_contra_noext_labels,
*trans_labels,
*binom_labels,
"midext_prob",
]

if isinstance(model, lymph.MidlineBilateraltime):
base_ipsi_labels = [f"i {e.start}->{e.end}" for e in model.ext.ipsi.base_edges]
base_contra_ext_labels = [
f"ce {e.start}->{e.end}" for e in model.ext.contra.base_edges
]
base_contra_noext_labels = [
f"cn {e.start}->{e.end}" for e in model.noext.contra.base_edges
]
trans_labels = [f"{e.start}->{e.end}" for e in model.ext.ipsi.trans_edges]
return [
*base_ipsi_labels,
*base_contra_noext_labels,
"mixing $\\alpha$",
*trans_labels,
*binom_labels,
"midext_trans",
] if model.use_mixing else [
*base_ipsi_labels,
*base_contra_ext_labels,
*base_contra_noext_labels,
*trans_labels,
*binom_labels,
"midext_trans",
]

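Side note: the extra "midext_prob" / "midext_trans" label added here appears to correspond to the one additional sampled parameter behind the `+ 1` in the `ndim` computation of generate.py and evaluate.py above. A minimal consistency sketch, assuming an already constructed model instance and that the helper is importable as the file path suggests:

    from lyscripts.plot.corner import get_param_labels

    # `model` is assumed to be an already constructed lymph model instance.
    labels = get_param_labels(model)
    ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + 1
    assert len(labels) == ndim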

216 changes: 181 additions & 35 deletions lyscripts/predict/prevalences.py
@@ -10,6 +10,7 @@
[`lynference`](https://github.com/rmnldwg/lynference) repository.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Generator, List, Optional

@@ -19,6 +20,7 @@
import pandas as pd
from rich.progress import track

from lyscripts.decorators import log_state
from lyscripts.predict.utils import complete_pattern
from lyscripts.utils import (
LymphModel,
@@ -31,6 +33,8 @@
report,
)

logger = logging.getLogger(__name__)


def _add_parser(
subparsers: argparse._SubParsersAction,
@@ -110,10 +114,17 @@ def get_match_idx(

def does_t_stage_match(data: pd.DataFrame, t_stage: str) -> pd.Index:
"""Return the indices of the `data` where the `t_stage` of the patients matches."""
if data.columns.nlevels == 2:
return data["info", "t_stage"] == t_stage
elif data.columns.nlevels == 3:
return data["info", "tumor", "t_stage"] == t_stage
if data.columns.nlevels == 3:
if t_stage=="early/late":
return data[("info","tumor", "t_stage")].isin(["early", "late"])
else:
return data["info", "tumor", "t_stage"] == t_stage

elif data.columns.nlevels == 2:
if t_stage=="early/late":
return data[("info", "t_stage")].isin(["early", "late"])
else:
return data["info", "t_stage"] == t_stage
else:
raise ValueError("Data has neither 2 nor 3 header rows")

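A minimal usage sketch of the combined-stage matching above, assuming the function is importable from lyscripts.predict.prevalences as the file path suggests; the two-level header data is made up:

    import pandas as pd

    from lyscripts.predict.prevalences import does_t_stage_match

    # Made-up unilateral-style data with a two-level column header.
    data = pd.DataFrame({
        ("info", "t_stage"): ["early", "late", "early"],
        ("ipsi", "II"): [True, False, True],
    })

    assert does_t_stage_match(data, "early").sum() == 2
    # the combined "early/late" pseudo-stage matches both groups
    assert does_t_stage_match(data, "early/late").all()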
@@ -147,6 +158,28 @@ def get_midline_ext_prob(data: pd.DataFrame, t_stage: str) -> float:
matching_data = eligible_data[has_matching_midline_ext]
return len(matching_data) / len(eligible_data)

def calculate_midline_ext_prob(diag_prob, midline_ext_prob_rates):
num_timesteps = len(diag_prob)
cumulative_probability = 0.0

for diagnosis_timestep in range(num_timesteps):
cumulative_probability_at_diagnosis = 1.0

for t in range(diagnosis_timestep):
cumulative_probability_at_diagnosis *= (1 - midline_ext_prob_rates[t])

cumulative_probability_at_diagnosis *= diag_prob[diagnosis_timestep]
cumulative_probability += cumulative_probability_at_diagnosis

return 1 - cumulative_probability

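With a constant per-timestep rate, which is how the helper above is called further down (via [given_params[-2]] * len(pmf)), the loop reduces to a closed form: (1 - p)**t is the probability that no extension has occurred before a diagnosis at time t, and marginalising over the diagnosis-time pmf gives 1 - sum_t diag_prob[t] * (1 - p)**t. A small sketch with made-up numbers, assuming the helper is importable as the file path suggests:

    import numpy as np

    from lyscripts.predict.prevalences import calculate_midline_ext_prob

    diag_prob = np.array([0.1, 0.2, 0.4, 0.2, 0.1])  # illustrative diagnosis-time pmf
    p = 0.15                                          # illustrative per-step extension rate

    looped = calculate_midline_ext_prob(diag_prob, [p] * len(diag_prob))
    closed_form = 1. - np.sum(diag_prob * (1. - p) ** np.arange(len(diag_prob)))
    assert np.isclose(looped, closed_form)
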
def get_early_prob(data: pd.DataFrame) -> float:
"""Get the prevalence of midline extension from `data` for `t_stage`."""

has_matching_t_stage = does_t_stage_match(data, "early")
matching_data = data[has_matching_t_stage]
return len(matching_data) / len(data)


def create_patient_row(
pattern: Dict[str, Dict[str, bool]],
@@ -165,24 +198,56 @@ def create_patient_row(
if make_unilateral:
flat_pattern = flatten({"prev": pattern["ipsi"]})
patient_row = pd.DataFrame(flat_pattern, index=[0])
patient_row["info", "t_stage"] = t_stage
return patient_row

flat_pattern = flatten({"prev": pattern})
patient_row = pd.DataFrame(flat_pattern, index=[0])
patient_row["info", "tumor", "t_stage"] = t_stage
if midline_ext is not None:
patient_row["info", "tumor", "midline_extension"] = midline_ext
return patient_row

with_midline_ext = patient_row.copy()
with_midline_ext["info", "tumor", "midline_extension"] = True
without_midline_ext = patient_row.copy()
without_midline_ext["info", "tumor", "midline_extension"] = False

return with_midline_ext.append(without_midline_ext).reset_index()
if t_stage != "early/late":
patient_row["info", "t_stage"] = t_stage
return patient_row
else:
early_tstage = patient_row.copy()
early_tstage["info", "t_stage"] = "early"
late_tstage = patient_row.copy()
late_tstage["info", "t_stage"] = "late"

return pd.concat([early_tstage, late_tstage], ignore_index=True)

elif t_stage != "early/late":
flat_pattern = flatten({"prev": pattern})
patient_row = pd.DataFrame(flat_pattern, index=[0])
patient_row["info", "tumor", "t_stage"] = t_stage
if midline_ext is not None:
patient_row["info", "tumor", "midline_extension"] = midline_ext
return patient_row

with_midline_ext = patient_row.copy()
with_midline_ext["info", "tumor", "midline_extension"] = True
without_midline_ext = patient_row.copy()
without_midline_ext["info", "tumor", "midline_extension"] = False

return pd.concat([with_midline_ext, without_midline_ext], ignore_index=True)

else:
flat_pattern = flatten({"prev": pattern})
patient_row = pd.DataFrame(flat_pattern, index=[0])
early_tstage = patient_row.copy()
early_tstage["info", "tumor", "t_stage"] = "early"
late_tstage = patient_row.copy()
late_tstage["info", "tumor", "t_stage"] = "late"
if midline_ext is not None:
early_tstage["info", "tumor", "midline_extension"] = midline_ext
late_tstage["info", "tumor", "midline_extension"] = midline_ext
return pd.concat([early_tstage, late_tstage], ignore_index=True)

early_with_midline_ext = early_tstage.copy()
early_with_midline_ext["info", "tumor", "midline_extension"] = True
early_without_midline_ext = early_tstage.copy()
early_without_midline_ext["info", "tumor", "midline_extension"] = False
late_with_midline_ext = late_tstage.copy()
late_with_midline_ext["info", "tumor", "midline_extension"] = True
late_without_midline_ext = late_tstage.copy()
late_without_midline_ext["info", "tumor", "midline_extension"] = False

return pd.concat([early_with_midline_ext, late_with_midline_ext, early_without_midline_ext, late_without_midline_ext], ignore_index=True)

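A minimal usage sketch of the branch above, with a made-up spread pattern and assuming the keyword arguments t_stage, midline_ext and make_unilateral match the (not fully shown) signature: for the combined "early/late" stage with unknown midline extension, four synthetic patient rows are expected, one per combination of t_stage and midline_extension.

    from lyscripts.predict.prevalences import create_patient_row

    # Made-up spread pattern; keyword names follow the usage inside the function body.
    pattern = {
        "ipsi": {"II": True, "III": False},
        "contra": {"II": False, "III": False},
    }
    rows = create_patient_row(
        pattern, t_stage="early/late", midline_ext=None, make_unilateral=False
    )
    assert len(rows) == 4
    assert set(rows["info", "tumor", "t_stage"]) == {"early", "late"}
    assert set(rows["info", "tumor", "midline_extension"]) == {True, False}
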
@log_state(logger=logger)
def compute_observed_prevalence(
pattern: Dict[str, Dict[str, bool]],
data: pd.DataFrame,
@@ -246,7 +311,9 @@ def compute_predicted_prevalence(
loaded_model: LymphModel,
given_params: np.ndarray,
midline_ext: bool,
t_stage: str,
midline_ext_prob: float = 0.3,
early_prob: float = 0.5
) -> float:
"""
Given a `loaded_model` with loaded patient data and modalities, compute the
@@ -263,23 +330,98 @@
loaded_model.check_and_assign(given_params)
if midline_ext is None:
# marginalize over patients with and without midline extension
prevalence = (
midline_ext_prob * loaded_model.ext.likelihood(log=False) +
(1. - midline_ext_prob) * loaded_model.noext.likelihood(log=False)
# Only correct with the new code for the time evolution of midline extension.
if t_stage=="early/late":
early_llhs = loaded_model.likelihood(log=False, t_stages=["early"], given_params=given_params, prevalence_calc=True)
late_llhs = loaded_model.likelihood(log=False, t_stages=["late"], given_params=given_params, prevalence_calc=True)
prevalence = (
early_prob * early_llhs[0] +
early_prob * early_llhs[1] +
(1 - early_prob) * late_llhs[0] +
(1 - early_prob) * late_llhs[1]
)
else:
llhs = loaded_model.likelihood(log=False, given_params=given_params, prevalence_calc=True)
prevalence = llhs[0] + llhs[1]

elif midline_ext:
prevalence = loaded_model.ext.likelihood(log=False)
if t_stage=="early/late":
midline_ext_prob_early = calculate_midline_ext_prob(
loaded_model.ext.ipsi.diag_time_dists['early'].pmf,
([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['early'].pmf))
)
midline_ext_prob_late = calculate_midline_ext_prob(
loaded_model.ext.ipsi.diag_time_dists['late'].pmf,
([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['late'].pmf))
)
prevalence = (
early_prob * loaded_model.likelihood(
log=False,
given_params=given_params,
t_stages=["early"],
prevalence_calc=True
)/midline_ext_prob_early +
(1 - early_prob) * loaded_model.likelihood(
log=False,
given_params=given_params,
t_stages=["late"],
prevalence_calc=True
)/midline_ext_prob_late
)
else:
midline_ext_prob = calculate_midline_ext_prob(
loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf,
([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf))
)
prevalence = loaded_model.likelihood(log=False, given_params=given_params, prevalence_calc=True)/midline_ext_prob
else:
prevalence = loaded_model.noext.likelihood(log=False)
if t_stage=="early/late":
midline_ext_prob_early = calculate_midline_ext_prob(
loaded_model.ext.ipsi.diag_time_dists['early'].pmf,
([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['early'].pmf))
)
midline_ext_prob_late = calculate_midline_ext_prob(
loaded_model.ext.ipsi.diag_time_dists['late'].pmf,
([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['late'].pmf))
)
prevalence = (
early_prob * loaded_model.likelihood(
log=False,
given_params=given_params,
t_stages=["early"],
prevalence_calc=True
)/(1-midline_ext_prob_early) +
(1 - early_prob) * loaded_model.likelihood(
log=False,
given_params=given_params,
t_stages=["late"],
prevalence_calc=True
)/(1-midline_ext_prob_late)
)
else:
midline_ext_prob = calculate_midline_ext_prob(
loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf,
([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf))
)
prevalence = loaded_model.likelihood(log=False, given_params=given_params, prevalence_calc=True)/(1-midline_ext_prob)
else:
prevalence = loaded_model.likelihood(
given_params=given_params,
log=False,
)

if t_stage=="early/late":
prevalence = early_prob * loaded_model.likelihood(
given_params=given_params,
log=False, t_stages=["early"]
) + (1-early_prob) * loaded_model.likelihood(
given_params=given_params,
log=False, t_stages=["late"]
)
else:
prevalence = loaded_model.likelihood(
given_params=given_params,
log=False
)
return prevalence

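The branches above can be read as a single Bayes-style normalisation, assuming (as the line `prevalence = llhs[0] + llhs[1]` suggests) that the likelihood with prevalence_calc=True yields the joint probabilities of the pattern with and without midline extension: a known extension status divides the matching joint term by its marginal probability from calculate_midline_ext_prob, while an unknown status sums the two joints. A hypothetical sketch of that scaffold, not part of the module:

    def _normalize_prevalence(joint_ext, joint_noext, p_ext, midline_ext):
        """Hypothetical helper mirroring the branching above (not part of the module)."""
        if midline_ext is None:
            # unknown extension status: marginalise over both cases
            return joint_ext + joint_noext
        if midline_ext:
            # condition on midline extension being present
            return joint_ext / p_ext
        # condition on midline extension being absent
        return joint_noext / (1. - p_ext)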

@log_state(logger=logger)
def generate_predicted_prevalences(
pattern: Dict[str, Dict[str, bool]],
model: LymphModel,
@@ -289,6 +431,7 @@
midline_ext_prob: float = 0.3,
modality_spsn: Optional[List[float]] = None,
invert: bool = False,
early_prob: float = 0.5,
**_kwargs,
) -> Generator[float, None, None]:
"""Compute the prevalence of a given `pattern` of lymphatic progression using a
@@ -322,6 +465,8 @@
given_params=sample,
midline_ext=midline_ext,
midline_ext_prob=midline_ext_prob,
t_stage=t_stage,
early_prob = early_prob,
)
yield (1. - prevalence) if invert else prevalence

@@ -353,12 +498,12 @@ def main(args: argparse.Namespace):
--params PARAMS Path to parameter file (default: ./params.yaml)
```
"""
params = load_yaml_params(args.params)
model = create_model_from_config(params)
samples = load_hdf5_samples(args.model)
params = load_yaml_params(args.params, logger=logger)
model = create_model_from_config(params, logger=logger)
samples = load_hdf5_samples(args.model, logger=logger)

header_rows = [0,1] if isinstance(model, lymph.Unilateral) else [0,1,2]
data = load_data_for_model(args.data, header_rows)
data = load_data_for_model(args.data, header_rows, logger=logger)

args.output.parent.mkdir(exist_ok=True)
num_prevalences = len(params["prevalences"])
@@ -368,6 +513,7 @@
model=model,
samples=samples[::args.thin],
midline_ext_prob=get_midline_ext_prob(data, scenario["t_stage"]),
early_prob=get_early_prob(data),
**scenario
)
prevs_progress = track(
@@ -396,7 +542,7 @@ def main(args: argparse.Namespace):
prevs_h5dset.attrs["num_match"] = num_match
prevs_h5dset.attrs["num_total"] = num_total

report.success(
logger.info(
f"Computed prevalences of {num_prevalences} scenarios stored at "
f"{args.output}"
)