diff --git a/lymph/models/__init__.py b/lymph/models/__init__.py index 84989cc..c76d968 100644 --- a/lymph/models/__init__.py +++ b/lymph/models/__init__.py @@ -5,4 +5,4 @@ from lymph.models.midline import Midline from lymph.models.unilateral import Unilateral -__all__ = ["Unilateral", "HpvWrapper" "Bilateral", "Midline"] +__all__ = ["Unilateral", "HpvUnilateral" "Bilateral", "Midline"] diff --git a/lymph/models/hpv.py b/lymph/models/hpv.py index 5e2490d..96de360 100644 --- a/lymph/models/hpv.py +++ b/lymph/models/hpv.py @@ -4,7 +4,6 @@ import logging import warnings -from collections.abc import Iterable from typing import Any, Literal import numpy as np @@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) -class HpvWrapper( +class HpvUnilateral( diagnosis_times.Composite, modalities.Composite, types.Model, @@ -95,7 +94,7 @@ def _init_models( self.base_2_key = list(self.HPV.graph.tumors.keys())[0] + "toII" @classmethod - def binary(cls, *args, **kwargs) -> HpvWrapper: + def binary(cls, *args, **kwargs) -> HpvUnilateral: """Initialize a binary bilateral model. This is a convenience method that sets the ``allowed_states`` of the @@ -107,7 +106,7 @@ def binary(cls, *args, **kwargs) -> HpvWrapper: return cls(*args, uni_kwargs=uni_kwargs, **kwargs) @classmethod - def trinary(cls, *args, **kwargs) -> HpvWrapper: + def trinary(cls, *args, **kwargs) -> HpvUnilateral: """Initialize a trinary bilateral model. This is a convenience method that sets the ``allowed_states`` of the @@ -478,60 +477,3 @@ def risk( ) return self.marginalize(involvement, posterior_state_dist) - - def draw_patients( - self, - num: int, - stage_dist: Iterable[float], - rng: np.random.Generator | None = None, - seed: int = 42, - **_kwargs, - ) -> pd.DataFrame: - """Draw ``num`` random patients from the parametrized model. - - See Also - -------- - :py:meth:`.diagnosis_times.Distribution.draw_diag_times` - Method to draw diagnosis times from a distribution. - :py:meth:`.Unilateral.draw_diagnosis` - Method to draw individual diagnosis from a unilateral model. - :py:meth:`.Unilateral.draw_patients` - The unilateral method to draw a synthetic dataset. - - """ - if rng is None: - rng = np.random.default_rng(seed) - - if sum(stage_dist) != 1.0: - warnings.warn("Sum of stage distribution is not 1. Renormalizing.") - stage_dist = np.array(stage_dist) / sum(stage_dist) - - drawn_t_stages = rng.choice( - a=self.t_stages, - p=stage_dist, - size=num, - ) - drawn_diag_times = [ - self.get_distribution(t_stage).draw_diag_times(rng=rng) - for t_stage in drawn_t_stages - ] - - drawn_obs_ipsi = self.ipsi.draw_diagnosis(drawn_diag_times, rng=rng) - drawn_obs_noHPV = self.contra.draw_diagnosis(drawn_diag_times, rng=rng) - drawn_obs = np.concatenate([drawn_obs_ipsi, drawn_obs_noHPV], axis=1) - - # construct MultiIndex with "ipsi" and "contra" at top level to allow - # concatenation of the two separate drawn diagnosis - sides = ["ipsi", "contra"] - modality_names = list(self.get_all_modalities().keys()) - lnl_names = list(self.ipsi.graph.lnls.keys()) - multi_cols = pd.MultiIndex.from_product([sides, modality_names, lnl_names]) - - # reorder the column levels and thus also the individual columns to match the - # LyProX format without mixing up the data - dataset = pd.DataFrame(drawn_obs, columns=multi_cols) - dataset = dataset.reorder_levels(order=[1, 0, 2], axis="columns") - dataset = dataset.sort_index(axis="columns", level=0) - dataset[("tumor", "1", "t_stage")] = drawn_t_stages - - return dataset