From be66ab2b45797df28c1f598d48b85f199a06c768 Mon Sep 17 00:00:00 2001 From: larstwi <101893533+larstwi@users.noreply.github.com> Date: Fri, 21 Apr 2023 14:37:45 +0200 Subject: [PATCH 1/7] added sampler for midline extension probability --- lyscripts/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lyscripts/evaluate.py b/lyscripts/evaluate.py index a5af641..b8cde93 100644 --- a/lyscripts/evaluate.py +++ b/lyscripts/evaluate.py @@ -156,7 +156,7 @@ def main(args: argparse.Namespace): params = load_yaml_params(args.params) model = create_model_from_config(params) - ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + 1 is_uni = isinstance(model, lymph.Unilateral) data = load_data_for_model(args.data, header_rows=[0,1] if is_uni else [0,1,2]) From 245ba6babcc2466fe502cc03105f42136134e38a Mon Sep 17 00:00:00 2001 From: larstwi <101893533+larstwi@users.noreply.github.com> Date: Fri, 21 Apr 2023 14:39:40 +0200 Subject: [PATCH 2/7] added sampler for midline extension probability and label to corner plot --- lyscripts/data/generate.py | 2 +- lyscripts/plot/corner.py | 2 ++ lyscripts/sample.py | 2 +- tests/_sample.py | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lyscripts/data/generate.py b/lyscripts/data/generate.py index dac613e..b384574 100644 --- a/lyscripts/data/generate.py +++ b/lyscripts/data/generate.py @@ -96,7 +96,7 @@ def main(args: argparse.Namespace): """ params = load_yaml_params(args.params) model = create_model_from_config(params) - ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + 1 if args.set_theta is not None: with report.status("Assign given parameters to model..."): diff --git a/lyscripts/plot/corner.py b/lyscripts/plot/corner.py index d54140c..6b02f28 100644 --- a/lyscripts/plot/corner.py +++ b/lyscripts/plot/corner.py @@ -102,12 +102,14 @@ def get_param_labels( "mixing $\\alpha$", *trans_labels, *binom_labels, + "midext_prob", ] if model.use_mixing else [ *base_ipsi_labels, *base_contra_ext_labels, *base_contra_noext_labels, *trans_labels, *binom_labels, + "midext_prob", ] diff --git a/lyscripts/sample.py b/lyscripts/sample.py index fea65e7..b4618c6 100644 --- a/lyscripts/sample.py +++ b/lyscripts/sample.py @@ -418,7 +418,7 @@ def main(args: argparse.Namespace): ), ) MODEL.patient_data = inference_data - ndim = len(MODEL.spread_probs) + MODEL.diag_time_dists.num_parametric + ndim = len(MODEL.spread_probs) + MODEL.diag_time_dists.num_parametric + 1 nwalkers = ndim * params["sampling"]["walkers_per_dim"] thin_by = params["sampling"]["thin_by"] report.success( diff --git a/tests/_sample.py b/tests/_sample.py index 7f6b116..8883509 100644 --- a/tests/_sample.py +++ b/tests/_sample.py @@ -66,7 +66,7 @@ def test_sampling( ): """Test the basic sampling function.""" model.patient_data = data - ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + ndim = len(model.spread_probs) + model.diag_time_dists.num_parametric + 1 nwalker = ndim * params["sampling"]["walkers_per_dim"] info = run_mcmc_with_burnin( From cdba6a974c5589a8a25bc1d4d4b8d3815b020c43 Mon Sep 17 00:00:00 2001 From: larstwi <101893533+larstwi@users.noreply.github.com> Date: Tue, 25 Apr 2023 01:20:33 +0200 Subject: [PATCH 3/7] added dimension for midext parameter and adjusted corner plot --- .DS_Store | Bin 0 -> 6148 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..cf0519a8bcd9b29c02c09b3ce4a4931282865d40 GIT binary patch literal 6148 zcmeHK%TB{U3>-rwin#R1alZhGKUh`e3-|%lRtiXwf=GMdmRoeheB7Zfuo~+I=I*fK-AxOGhW*)K`a^|*1*w`5t=xa=v0XjL!8cd ziMkp%IyxN^!-vG3B_ Date: Tue, 25 Apr 2023 15:43:44 +0200 Subject: [PATCH 4/7] new class MidlineBilateraltime for time evolution of Midline extension --- lyscripts/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lyscripts/utils.py b/lyscripts/utils.py index 44b6e57..5359429 100644 --- a/lyscripts/utils.py +++ b/lyscripts/utils.py @@ -370,7 +370,7 @@ def model_from_config( graph_params: Dict[str, Any], model_params: Dict[str, Any], modalities_params: Optional[Dict[str, Any]] = None, -) -> Union[lymph.Unilateral, lymph.Bilateral, lymph.MidlineBilateral]: +) -> Union[lymph.Unilateral, lymph.Bilateral, lymph.MidlineBilateral, lymph.MidlineBilateraltime]: """Create a model instance as defined by some YAML params.""" graph = graph_from_config(graph_params) @@ -448,7 +448,7 @@ def get_lnls(model: LymphModel) -> List[str]: return [lnl.name for lnl in model.lnls] if isinstance(model, lymph.Bilateral): return [lnl.name for lnl in model.ipsi.lnls] - if isinstance(model, lymph.MidlineBilateral): + if isinstance(model, (lymph.MidlineBilateral, lymph.MidlineBilateraltime)): return [lnl.name for lnl in model.ext.ipsi.lnls] raise TypeError(f"Model cannot be of type {type(model)}") From 649e548e0bcfad8fb4f4476c9388bf3625005644 Mon Sep 17 00:00:00 2001 From: larstwi <101893533+larstwi@users.noreply.github.com> Date: Tue, 25 Apr 2023 15:44:39 +0200 Subject: [PATCH 5/7] new class MidlineBilateraltime for time evolution of Midline extension --- lyscripts/data/clean.py | 2 ++ lyscripts/plot/corner.py | 27 ++++++++++++++++++++++++++- lyscripts/predict/prevalences.py | 3 ++- lyscripts/predict/risks.py | 2 +- tests/predict/prevalences_test.py | 4 +++- 5 files changed, 34 insertions(+), 4 deletions(-) diff --git a/lyscripts/data/clean.py b/lyscripts/data/clean.py index a2020fa..32f2122 100644 --- a/lyscripts/data/clean.py +++ b/lyscripts/data/clean.py @@ -102,6 +102,8 @@ class that is later supposed to load the data. if class_name == "MidlineBilateral": diagnostic_data[("info", "tumor", "midline_extension")] = midline_extension_data + if class_name == "MidlineBilateraltime": + diagnostic_data[("info", "tumor", "midline_extension")] = midline_extension_data elif class_name == "Unilateral": diagnostic_data = diagnostic_data.drop(columns=["contra"], level=1) diagnostic_data.columns = diagnostic_data.columns.droplevel(1) diff --git a/lyscripts/plot/corner.py b/lyscripts/plot/corner.py index 6b02f28..a5df07c 100644 --- a/lyscripts/plot/corner.py +++ b/lyscripts/plot/corner.py @@ -55,7 +55,7 @@ def _add_arguments(parser: argparse.ArgumentParser): def get_param_labels( - model: Union[lymph.Unilateral, lymph.Bilateral, lymph.MidlineBilateral], + model: Union[lymph.Unilateral, lymph.Bilateral, lymph.MidlineBilateral, lymph.MidlineBilateraltime], ) -> List[str]: """Create labels from a `model`. @@ -111,6 +111,31 @@ def get_param_labels( *binom_labels, "midext_prob", ] + + if isinstance(model, lymph.MidlineBilateraltime): + base_ipsi_labels = [f"i {e.start}->{e.end}" for e in model.ext.ipsi.base_edges] + base_contra_ext_labels = [ + f"ce {e.start}->{e.end}" for e in model.ext.contra.base_edges + ] + base_contra_noext_labels = [ + f"cn {e.start}->{e.end}" for e in model.noext.contra.base_edges + ] + trans_labels = [f"{e.start}->{e.end}" for e in model.ext.ipsi.trans_edges] + return [ + *base_ipsi_labels, + *base_contra_noext_labels, + "mixing $\\alpha$", + *trans_labels, + *binom_labels, + "midext_trans", + ] if model.use_mixing else [ + *base_ipsi_labels, + *base_contra_ext_labels, + *base_contra_noext_labels, + *trans_labels, + *binom_labels, + "midext_trans", + ] def main(args: argparse.Namespace): diff --git a/lyscripts/predict/prevalences.py b/lyscripts/predict/prevalences.py index 5df58f0..907fa8e 100644 --- a/lyscripts/predict/prevalences.py +++ b/lyscripts/predict/prevalences.py @@ -259,7 +259,7 @@ def compute_predicted_prevalence( If `midline_ext` is set to `None`, the prevalence is marginalized over both cases, assuming the provided `midline_ext_prob`. """ - if isinstance(loaded_model, lymph.MidlineBilateral): + if isinstance(loaded_model, (lymph.MidlineBilateral, lymph.MidlineBilateraltime): loaded_model.check_and_assign(given_params) if midline_ext is None: # marginalize over patients with and without midline extension @@ -271,6 +271,7 @@ def compute_predicted_prevalence( prevalence = loaded_model.ext.likelihood(log=False) else: prevalence = loaded_model.noext.likelihood(log=False) + else: prevalence = loaded_model.likelihood( given_params=given_params, diff --git a/lyscripts/predict/risks.py b/lyscripts/predict/risks.py index ec354ed..ea9bae9 100644 --- a/lyscripts/predict/risks.py +++ b/lyscripts/predict/risks.py @@ -118,7 +118,7 @@ def predicted_risk( ) yield 1. - risk if invert else risk - elif isinstance(model, (lymph.Bilateral, lymph.MidlineBilateral)): + elif isinstance(model, (lymph.Bilateral, lymph.MidlineBilateral, lymph.MidlineBilateraltime)): given_diagnosis = {"risk": given_diagnosis} for sample in samples: diff --git a/tests/predict/prevalences_test.py b/tests/predict/prevalences_test.py index 8278d5e..423eac4 100644 --- a/tests/predict/prevalences_test.py +++ b/tests/predict/prevalences_test.py @@ -24,15 +24,17 @@ def test_get_lnls(): uni_model = lymph.Unilateral(graph) bi_model = lymph.Bilateral(graph) mid_model = lymph.MidlineBilateral(graph) + midtime_model = lymph.MidlineBilateraltime(graph) uni_lnls = get_lnls(uni_model) bi_lnls = get_lnls(bi_model) mid_lnls = get_lnls(mid_model) + midtime_lnls = get_lnls(midtime_model) assert uni_lnls == lnls, "Did not extract LNLs correctly from unilateral model" assert bi_lnls == lnls, "Did not extract LNLs correctly from bilateral model" assert mid_lnls == lnls, "Did not extract LNLs correctly from midline model" - + assert midtime_lnls == lnls, "Did not extract LNLs correctly from midline time evolution model" def test_get_match_idx(): """Test if the pattern dictionaries & pandas data are compared correctly.""" From aa348ef4796ca573f0e3540817e12ad6b66ae8cc Mon Sep 17 00:00:00 2001 From: larstwi <101893533+larstwi@users.noreply.github.com> Date: Tue, 25 Apr 2023 16:07:07 +0200 Subject: [PATCH 6/7] bug fixes --- lyscripts/.DS_Store | Bin 0 -> 6148 bytes lyscripts/predict/prevalences.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 lyscripts/.DS_Store diff --git a/lyscripts/.DS_Store b/lyscripts/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..7fdec4623181a1dd222bf6941e176628c56528c2 GIT binary patch literal 6148 zcmeH~JqiLr422WjLa^D=avBfd4F=H@cmaR56fD$!j_%73f~&QNyg>3zG82}4#m+`V zbbTLIBE5*r;6_lA49C} z?O@4sHQ9pEE}FxK=9AT?7??)8Xh8zg>R_M(RA8jQH1gi=|1JE}{6A`8N(HFEpDCc- zX17`6rSfildp)b~vuf)G2mLt0%TE9jyNVZZH|!T%fHm2IsKEFm;4&~!fv+m?01d Date: Thu, 26 Oct 2023 04:45:50 +0200 Subject: [PATCH 7/7] implemented prevalence prediction for midext time evolution model --- lyscripts/predict/prevalences.py | 219 +++++++++++++++++++++++++------ 1 file changed, 182 insertions(+), 37 deletions(-) diff --git a/lyscripts/predict/prevalences.py b/lyscripts/predict/prevalences.py index 04bc1f4..26591c9 100644 --- a/lyscripts/predict/prevalences.py +++ b/lyscripts/predict/prevalences.py @@ -10,6 +10,7 @@ [`lynference`](https://github.com/rmnldwg/lynference) repository. """ import argparse +import logging from pathlib import Path from typing import Dict, Generator, List, Optional @@ -19,6 +20,7 @@ import pandas as pd from rich.progress import track +from lyscripts.decorators import log_state from lyscripts.predict.utils import complete_pattern from lyscripts.utils import ( LymphModel, @@ -31,6 +33,8 @@ report, ) +logger = logging.getLogger(__name__) + def _add_parser( subparsers: argparse._SubParsersAction, @@ -110,10 +114,17 @@ def get_match_idx( def does_t_stage_match(data: pd.DataFrame, t_stage: str) -> pd.Index: """Return the indices of the `data` where the `t_stage` of the patients matches.""" - if data.columns.nlevels == 2: - return data["info", "t_stage"] == t_stage - elif data.columns.nlevels == 3: - return data["info", "tumor", "t_stage"] == t_stage + if data.columns.nlevels == 3: + if t_stage=="early/late": + return data[("info","tumor", "t_stage")].isin(["early", "late"]) + else: + return data["info", "tumor", "t_stage"] == t_stage + + elif data.columns.nlevels == 2: + if t_stage=="early/late": + return data[("info", "t_stage")].isin(["early", "late"]) + else: + return data["info", "t_stage"] == t_stage else: raise ValueError("Data has neither 2 nor 3 header rows") @@ -147,6 +158,28 @@ def get_midline_ext_prob(data: pd.DataFrame, t_stage: str) -> float: matching_data = eligible_data[has_matching_midline_ext] return len(matching_data) / len(eligible_data) +def calculate_midline_ext_prob(diag_prob, midline_ext_prob_rates): + num_timesteps = len(diag_prob) + cumulative_probability = 0.0 + + for diagnosis_timestep in range(num_timesteps): + cumulative_probability_at_diagnosis = 1.0 + + for t in range(diagnosis_timestep): + cumulative_probability_at_diagnosis *= (1 - midline_ext_prob_rates[t]) + + cumulative_probability_at_diagnosis *= diag_prob[diagnosis_timestep] + cumulative_probability += cumulative_probability_at_diagnosis + + return 1 - cumulative_probability + +def get_early_prob(data: pd.DataFrame) -> float: + """Get the prevalence of midline extension from `data` for `t_stage`.""" + + has_matching_t_stage = does_t_stage_match(data, "early") + matching_data = data[has_matching_t_stage] + return len(matching_data) / len(data) + def create_patient_row( pattern: Dict[str, Dict[str, bool]], @@ -165,24 +198,56 @@ def create_patient_row( if make_unilateral: flat_pattern = flatten({"prev": pattern["ipsi"]}) patient_row = pd.DataFrame(flat_pattern, index=[0]) - patient_row["info", "t_stage"] = t_stage - return patient_row - - flat_pattern = flatten({"prev": pattern}) - patient_row = pd.DataFrame(flat_pattern, index=[0]) - patient_row["info", "tumor", "t_stage"] = t_stage - if midline_ext is not None: - patient_row["info", "tumor", "midline_extension"] = midline_ext - return patient_row - - with_midline_ext = patient_row.copy() - with_midline_ext["info", "tumor", "midline_extension"] = True - without_midline_ext = patient_row.copy() - without_midline_ext["info", "tumor", "midline_extension"] = False - - return with_midline_ext.append(without_midline_ext).reset_index() + if t_stage != "early/late": + patient_row["info", "t_stage"] = t_stage + return patient_row + else: + early_tstage = patient_row.copy() + early_tstage["info", "t_stage"] = "early" + late_tstage = patient_row.copy() + late_tstage["info", "t_stage"] = "late" + return pd.concat([early_tstage, late_tstage], ignore_index=True) + elif t_stage != "early/late": + flat_pattern = flatten({"prev": pattern}) + patient_row = pd.DataFrame(flat_pattern, index=[0]) + patient_row["info", "tumor", "t_stage"] = t_stage + if midline_ext is not None: + patient_row["info", "tumor", "midline_extension"] = midline_ext + return patient_row + + with_midline_ext = patient_row.copy() + with_midline_ext["info", "tumor", "midline_extension"] = True + without_midline_ext = patient_row.copy() + without_midline_ext["info", "tumor", "midline_extension"] = False + + return pd.concat([with_midline_ext, without_midline_ext], ignore_index=True) + + else: + flat_pattern = flatten({"prev": pattern}) + patient_row = pd.DataFrame(flat_pattern, index=[0]) + early_tstage = patient_row.copy() + early_tstage["info", "tumor", "t_stage"] = "early" + late_tstage = patient_row.copy() + late_tstage["info", "tumor", "t_stage"] = "late" + if midline_ext is not None: + early_tstage["info", "tumor", "midline_extension"] = midline_ext + late_tstage["info", "tumor", "midline_extension"] = midline_ext + return pd.concat([early_tstage, late_tstage], ignore_index=True) + + early_with_midline_ext = early_tstage.copy() + early_with_midline_ext["info", "tumor", "midline_extension"] = True + early_without_midline_ext = early_tstage.copy() + early_without_midline_ext["info", "tumor", "midline_extension"] = False + late_with_midline_ext = late_tstage.copy() + late_with_midline_ext["info", "tumor", "midline_extension"] = True + late_without_midline_ext = late_tstage.copy() + late_without_midline_ext["info", "tumor", "midline_extension"] = False + + return pd.concat([early_with_midline_ext, late_with_midline_ext, early_without_midline_ext, late_without_midline_ext], ignore_index=True) + +@log_state(logger=logger) def compute_observed_prevalence( pattern: Dict[str, Dict[str, bool]], data: pd.DataFrame, @@ -246,7 +311,9 @@ def compute_predicted_prevalence( loaded_model: LymphModel, given_params: np.ndarray, midline_ext: bool, + t_stage: str, midline_ext_prob: float = 0.3, + early_prob: float = 0.5 ) -> float: """ Given a `loaded_model` with loaded patient data and modalities, compute the @@ -259,28 +326,102 @@ def compute_predicted_prevalence( If `midline_ext` is set to `None`, the prevalence is marginalized over both cases, assuming the provided `midline_ext_prob`. """ - if isinstance(loaded_model, (lymph.MidlineBilateral, lymph.MidlineBilateraltime)): + if isinstance(loaded_model, lymph.MidlineBilateral): loaded_model.check_and_assign(given_params) if midline_ext is None: # marginalize over patients with and without midline extension - prevalence = ( - midline_ext_prob * loaded_model.ext.likelihood(log=False) + - (1. - midline_ext_prob) * loaded_model.noext.likelihood(log=False) + #only correct with new code of time evolution over midline extension + if t_stage=="early/late": + early_llhs = loaded_model.likelihood(log=False, t_stages=["early"], given_params=given_params, prevalence_calc=True) + late_llhs = loaded_model.likelihood(log=False, t_stages=["late"], given_params=given_params, prevalence_calc=True) + prevalence = ( + early_prob * early_llhs[0] + + early_prob * early_llhs[1] + + (1 - early_prob) * late_llhs[0] + + (1 - early_prob) * late_llhs[1] ) + else: + llhs = loaded_model.likelihood(log=False, given_params=given_params, prevalence_calc=True) + prevalence = llhs[0] + llhs[1] + elif midline_ext: - prevalence = loaded_model.ext.likelihood(log=False) + if t_stage=="early/late": + midline_ext_prob_early = calculate_midline_ext_prob( + loaded_model.ext.ipsi.diag_time_dists['early'].pmf, + ([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['early'].pmf)) + ) + midline_ext_prob_late = calculate_midline_ext_prob( + loaded_model.ext.ipsi.diag_time_dists['late'].pmf, + ([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['late'].pmf)) + ) + prevalence = ( + early_prob * loaded_model.likelihood( + log=False, + given_params=given_params, + t_stages=["early"], + prevalence_calc=True + )/midline_ext_prob_early + + (1 - early_prob) * loaded_model.likelihood( + log=False, + given_params=given_params, + t_stages=["late"], + prevalence_calc=True + )/midline_ext_prob_late + ) + else: + midline_ext_prob = calculate_midline_ext_prob( + loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf, + ([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf)) + ) + prevalence = loaded_model.likelihood(log=False, given_params=given_params, prevalence_calc=True)/midline_ext_prob else: - prevalence = loaded_model.noext.likelihood(log=False) - + if t_stage=="early/late": + midline_ext_prob_early = calculate_midline_ext_prob( + loaded_model.ext.ipsi.diag_time_dists['early'].pmf, + ([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['early'].pmf)) + ) + midline_ext_prob_late = calculate_midline_ext_prob( + loaded_model.ext.ipsi.diag_time_dists['late'].pmf, + ([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists['late'].pmf)) + ) + prevalence = ( + early_prob * loaded_model.likelihood( + log=False, + given_params=given_params, + t_stages=["early"], + prevalence_calc=True + )/(1-midline_ext_prob_early) + + (1 - early_prob) * loaded_model.likelihood( + log=False, + given_params=given_params, + t_stages=["late"], + prevalence_calc=True + )/(1-midline_ext_prob_late) + ) + else: + midline_ext_prob = calculate_midline_ext_prob( + loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf, + ([given_params[-2]] * len(loaded_model.ext.ipsi.diag_time_dists[t_stage].pmf)) + ) + prevalence = loaded_model.likelihood(log=False, given_params=given_params, prevalence_calc=True)/(1-midline_ext_prob) else: - prevalence = loaded_model.likelihood( - given_params=given_params, - log=False, - ) - + if t_stage=="early/late": + prevalence = early_prob * loaded_model.likelihood( + given_params=given_params, + log=False, t_stages=["early"] + ) + (1-early_prob) * loaded_model.likelihood( + given_params=given_params, + log=False, t_stages=["late"] + ) + else: + prevalence = loaded_model.likelihood( + given_params=given_params, + log=False + ) return prevalence +@log_state(logger=logger) def generate_predicted_prevalences( pattern: Dict[str, Dict[str, bool]], model: LymphModel, @@ -290,6 +431,7 @@ def generate_predicted_prevalences( midline_ext_prob: float = 0.3, modality_spsn: Optional[List[float]] = None, invert: bool = False, + early_prob: float = 0.5, **_kwargs, ) -> Generator[float, None, None]: """Compute the prevalence of a given `pattern` of lymphatic progression using a @@ -323,6 +465,8 @@ def generate_predicted_prevalences( given_params=sample, midline_ext=midline_ext, midline_ext_prob=midline_ext_prob, + t_stage=t_stage, + early_prob = early_prob, ) yield (1. - prevalence) if invert else prevalence @@ -354,12 +498,12 @@ def main(args: argparse.Namespace): --params PARAMS Path to parameter file (default: ./params.yaml) ``` """ - params = load_yaml_params(args.params) - model = create_model_from_config(params) - samples = load_hdf5_samples(args.model) + params = load_yaml_params(args.params, logger=logger) + model = create_model_from_config(params, logger=logger) + samples = load_hdf5_samples(args.model, logger=logger) header_rows = [0,1] if isinstance(model, lymph.Unilateral) else [0,1,2] - data = load_data_for_model(args.data, header_rows) + data = load_data_for_model(args.data, header_rows, logger=logger) args.output.parent.mkdir(exist_ok=True) num_prevalences = len(params["prevalences"]) @@ -369,6 +513,7 @@ def main(args: argparse.Namespace): model=model, samples=samples[::args.thin], midline_ext_prob=get_midline_ext_prob(data, scenario["t_stage"]), + early_prob=get_early_prob(data), **scenario ) prevs_progress = track( @@ -397,7 +542,7 @@ def main(args: argparse.Namespace): prevs_h5dset.attrs["num_match"] = num_match prevs_h5dset.attrs["num_total"] = num_total - report.success( + logger.info( f"Computed prevalences of {num_prevalences} scenarios stored at " f"{args.output}" )