From 1e125ed9207a4e4dd7d47bf6f10d98b1e9ae475e Mon Sep 17 00:00:00 2001
From: rmnldwg <48687784+rmnldwg@users.noreply.github.com>
Date: Fri, 29 Dec 2023 10:30:11 +0100
Subject: [PATCH] fix(uni): update `draw_patients()` method

Update the method to generate synthetic data from the parametrized model
in the same LyProX format that is also required to load the data in the
first place.

Fixes: #65
---
Review notes (this section is ignored by `git am`):

* `draw_diagnoses()` created/accepted `rng` but still sampled through the
  global `np.random.choice`, so `rng`/`seed` had no effect on the drawn
  observations. It now samples via `rng.choice`.
* `draw_patients()` now forwards its `rng` to `draw_diagnoses()`, so one
  generator drives T-stages, diagnose times and observations.
* `bilateral.py` now hands ONE shared, unseeded generator to both sides.
  Without it, each `draw_diagnoses()` call would fall back to a fresh
  generator seeded with 42 and the ipsi/contra draws would be correlated.
  This adds one line, hence the adjusted hunk header and diffstat.
* Docstring grammar fix in `draw_patients()`.
* The `index` blob hashes below are those of the original commit.
* NOTE(review): `sum(stage_dist) != 1.` is an exact float comparison and may
  warn spuriously (renormalizing is harmless); left as in the original.

 lymph/diagnose_times.py    | 18 +++++++--
 lymph/models/bilateral.py  |  5 ++-
 lymph/models/unilateral.py | 75 ++++++++++++++++++++++----------------
 3 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/lymph/diagnose_times.py b/lymph/diagnose_times.py
index 892c9fa..7a4b572 100644
--- a/lymph/diagnose_times.py
+++ b/lymph/diagnose_times.py
@@ -231,9 +231,21 @@ def set_params(self, **kwargs) -> None:
             warnings.warn("Distribution is not updateable, skipping...")
 
 
-    def draw(self) -> np.ndarray:
-        """Draw sample of diagnose times from the PMF."""
-        return np.random.choice(a=self.support, p=self.distribution)
+    def draw_diag_times(
+        self,
+        num: int | None = None,
+        rng: np.random.Generator | None = None,
+        seed: int = 42,
+    ) -> np.ndarray:
+        """Draw ``num`` samples of diagnose times from the stored PMF.
+
+        A random number generator can be provided as ``rng``. If ``None``, a new one
+        is initialized with the given ``seed`` (or ``42``, by default).
+        """
+        if rng is None:
+            rng = np.random.default_rng(seed)
+
+        return rng.choice(a=self.support, p=self.distribution, size=num)
 
 
 class DistributionsUserDict(AbstractLookupDict):
diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index bb49cf6..e710f68 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -712,8 +712,9 @@ def generate_dataset(
             dist=stage_dist, size=num_patients
         )
 
-        drawn_obs_ipsi = self.ipsi._draw_patient_diagnoses(drawn_diag_times)
-        drawn_obs_contra = self.contra._draw_patient_diagnoses(drawn_diag_times)
+        rng = np.random.default_rng()
+        drawn_obs_ipsi = self.ipsi.draw_diagnoses(drawn_diag_times, rng=rng)
+        drawn_obs_contra = self.contra.draw_diagnoses(drawn_diag_times, rng=rng)
         drawn_obs = np.concatenate([drawn_obs_ipsi, drawn_obs_contra], axis=1)
 
         # construct MultiIndex for dataset from stored modalities
diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py
index e5ea940..2323e5f 100644
--- a/lymph/models/unilateral.py
+++ b/lymph/models/unilateral.py
@@ -945,58 +945,71 @@ def risk(
         return marginalize_over_states @ posterior_state_dist
 
 
-    def _draw_patient_diagnoses(
+    def draw_diagnoses(
         self,
         diag_times: list[int],
+        rng: np.random.Generator | None = None,
+        seed: int = 42,
     ) -> np.ndarray:
-        """Draw random possible observations for a list of T-stages and
-        diagnose times.
+        """Given some ``diag_times``, draw diagnoses for each LNL."""
+        if rng is None:
+            rng = np.random.default_rng(seed)
 
-        Args:
-            diag_times: List of diagnose times for each patient who's diagnose
-                is supposed to be drawn.
-        """
-        # use the drawn diagnose times to compute probabilities over states and
-        # diagnoses
-        per_time_state_probs = self.comp_dist_evolution()
-        per_patient_state_probs = per_time_state_probs[diag_times]
-        per_patient_obs_probs = per_patient_state_probs @ self.observation_matrix
-
-        # then, draw a diagnose from the possible ones
-        obs_idx = np.arange(len(self.obs_list))
+        state_probs_given_time = self.comp_dist_evolution()[diag_times]
+        obs_probs_given_time = state_probs_given_time @ self.observation_matrix
+
+        obs_indices = np.arange(len(self.obs_list))
         drawn_obs_idx = [
-            np.random.choice(obs_idx, p=obs_prob)
-            for obs_prob in per_patient_obs_probs
+            rng.choice(obs_indices, p=obs_prob)
+            for obs_prob in obs_probs_given_time
         ]
+
         return self.obs_list[drawn_obs_idx].astype(bool)
 
 
-    def generate_dataset(
+    def draw_patients(
         self,
-        num_patients: int,
-        stage_dist: dict[str, float],
+        num: int,
+        stage_dist: Iterable[float],
+        rng: np.random.Generator | None = None,
+        seed: int = 42,
         **_kwargs,
     ) -> pd.DataFrame:
-        """Generate/sample a pandas :class:`DataFrame` from the defined network
-        using the samples and diagnostic modalities that have been set.
+        """Draw ``num`` random patients from the model.
+
+        For this, a ``stage_dist``, i.e., a distribution over the T-stages, needs to
+        be defined. This must be an iterable of probabilities with as many elements as
+        there are defined T-stages in the model's :py:attr:`diag_time_dists` attribute.
 
-        Args:
-            num_patients: Number of patients to generate.
-            stage_dist: Probability to find a patient in a certain T-stage.
+        A random number generator can be provided as ``rng``. If ``None``, a new one
+        is initialized with the given ``seed`` (or ``42``, by default).
         """
-        drawn_t_stages, drawn_diag_times = self.diag_time_dists.draw(
-            prob_of_t_stage=stage_dist, size=num_patients
+        if rng is None:
+            rng = np.random.default_rng(seed)
+
+        if sum(stage_dist) != 1.:
+            warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
+            stage_dist = np.array(stage_dist) / sum(stage_dist)
+
+        drawn_t_stages = rng.choice(
+            a=list(self.diag_time_dists.keys()),
+            p=stage_dist,
+            size=num,
         )
+        drawn_diag_times = [
+            self.diag_time_dists[t_stage].draw_diag_times(rng=rng)
+            for t_stage in drawn_t_stages
+        ]
 
-        drawn_obs = self._draw_patient_diagnoses(drawn_diag_times)
+        drawn_obs = self.draw_diagnoses(drawn_diag_times, rng=rng)
 
         # construct MultiIndex for dataset from stored modalities
         modality_names = list(self.modalities.keys())
-        lnl_names = self.graph.lnls.keys()
-        multi_cols = pd.MultiIndex.from_product([modality_names, lnl_names])
+        lnl_names = list(self.graph.lnls.keys())
+        multi_cols = pd.MultiIndex.from_product([modality_names, ["ipsi"], lnl_names])
 
         # create DataFrame
         dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
-        dataset[('info', 't_stage')] = drawn_t_stages
+        dataset[("tumor", "1", "t_stage")] = drawn_t_stages
 
         return dataset