diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index e710f68..080f8f0 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -696,36 +696,64 @@ def risk(
         )
 
 
-    def generate_dataset(
+    def draw_patients(
         self,
-        num_patients: int,
-        stage_dist: dict[str, float],
+        num: int,
+        stage_dist: Iterable[float],
+        rng: np.random.Generator | None = None,
+        seed: int = 42,
+        **_kwargs,
     ) -> pd.DataFrame:
-        """Generate/sample a pandas :class:`DataFrame` from the defined network.
+        """Draw ``num`` random patients from the parametrized model.
 
-        Args:
-            num_patients: Number of patients to generate.
-            stage_dist: Probability to find a patient in a certain T-stage.
+        The ``stage_dist`` holds one probability per T-stage in ``diag_time_dists``.
+        If it does not sum to one, a warning is issued and it is renormalized. When
+        no ``rng`` is given, a new generator is created from ``seed``.
+
+        See Also:
+            :py:meth:`lymph.diagnose_times.Distribution.draw_diag_times`
+                Method to draw diagnose times from a distribution.
+            :py:meth:`lymph.models.Unilateral.draw_diagnoses`
+                Method to draw individual diagnoses from a unilateral model.
+            :py:meth:`lymph.models.Unilateral.draw_patients`
+                The unilateral method to draw a synthetic dataset.
         """
-        # TODO: check if this still works
-        drawn_t_stages, drawn_diag_times = self.diag_time_dists.draw(
-            dist=stage_dist, size=num_patients
+        if rng is None:
+            rng = np.random.default_rng(seed)
+
+        # materialize once: ``stage_dist`` may be a one-shot iterable, and exact
+        # float equality would warn on harmless rounding (e.g. ten times 0.1)
+        stage_dist = np.asarray(list(stage_dist), dtype=float)
+        if not np.isclose(stage_dist.sum(), 1.):
+            warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
+            stage_dist = stage_dist / stage_dist.sum()
+
+        drawn_t_stages = rng.choice(
+            a=list(self.diag_time_dists.keys()),
+            p=stage_dist,
+            size=num,
         )
+        drawn_diag_times = [
+            self.diag_time_dists[t_stage].draw_diag_times(rng=rng)
+            for t_stage in drawn_t_stages
+        ]
 
-        drawn_obs_ipsi = self.ipsi.draw_diagnoses(drawn_diag_times)
-        drawn_obs_contra = self.contra.draw_diagnoses(drawn_diag_times)
+        drawn_obs_ipsi = self.ipsi.draw_diagnoses(drawn_diag_times, rng=rng)
+        drawn_obs_contra = self.contra.draw_diagnoses(drawn_diag_times, rng=rng)
         drawn_obs = np.concatenate([drawn_obs_ipsi, drawn_obs_contra], axis=1)
 
-        # construct MultiIndex for dataset from stored modalities
+        # construct MultiIndex with "ipsi" and "contra" at top level to allow
+        # concatenation of the two separate drawn diagnoses
         sides = ["ipsi", "contra"]
-        modalities = list(self.modalities.keys())
-        lnl_names = [lnl.name for lnl in self.ipsi.graph._lnls]
-        multi_cols = pd.MultiIndex.from_product([sides, modalities, lnl_names])
+        modality_names = list(self.modalities.keys())
+        lnl_names = list(self.ipsi.graph.lnls.keys())
+        multi_cols = pd.MultiIndex.from_product([sides, modality_names, lnl_names])
 
-        # create DataFrame
+        # reorder the column levels and thus also the individual columns to match the
+        # LyProX format without mixing up the data
         dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
         dataset = dataset.reorder_levels(order=[1, 0, 2], axis="columns")
         dataset = dataset.sort_index(axis="columns", level=0)
-        dataset[('info', 'tumor', 't_stage')] = drawn_t_stages
+        dataset[('tumor', '1', 't_stage')] = drawn_t_stages
 
         return dataset
diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py
index 2323e5f..de21f9b 100644
--- a/lymph/models/unilateral.py
+++ b/lymph/models/unilateral.py
@@ -975,7 +975,7 @@ def draw_patients(
         seed: int = 42,
         **_kwargs,
     ) -> pd.DataFrame:
-        """Draw a ``num`` random patients from the model.
+        """Draw ``num`` random patients from the model.
 
         For this, a ``stage_dist``, i.e., a distribution over the T-stages, needs to
         be defined. This must be an iterable of probabilities with as many elements as
@@ -983,6 +983,14 @@ def draw_patients(
 
         A random number generator can be provided as ``rng``. If ``None``, a new one is
         initialized with the given ``seed`` (or ``42``, by default).
+
+        See Also:
+            :py:meth:`lymph.diagnose_times.Distribution.draw_diag_times`
+                Method to draw diagnose times from a distribution.
+            :py:meth:`lymph.models.Unilateral.draw_diagnoses`
+                Method to draw individual diagnoses.
+            :py:meth:`lymph.models.Bilateral.draw_patients`
+                The corresponding bilateral method.
         """
         if rng is None:
             rng = np.random.default_rng(seed)
@@ -1001,14 +1009,12 @@ def draw_patients(
             for t_stage in drawn_t_stages
         ]
 
-        drawn_obs = self.draw_diagnoses(drawn_diag_times)
+        drawn_obs = self.draw_diagnoses(drawn_diag_times, rng=rng)
 
-        # construct MultiIndex for dataset from stored modalities
         modality_names = list(self.modalities.keys())
         lnl_names = list(self.graph.lnls.keys())
         multi_cols = pd.MultiIndex.from_product([modality_names, ["ipsi"], lnl_names])
 
-        # create DataFrame
         dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
         dataset[("tumor", "1", "t_stage")] = drawn_t_stages
 