Skip to content

Commit

Permalink
updated HPV modules with more fitting naming
Browse files Browse the repository at this point in the history
  • Loading branch information
YoelPH committed Sep 25, 2024
1 parent df253dc commit 87a9b39
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 62 deletions.
2 changes: 1 addition & 1 deletion lymph/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from lymph.models.midline import Midline
from lymph.models.unilateral import Unilateral

__all__ = ["Unilateral", "HpvWrapper" "Bilateral", "Midline"]
__all__ = ["Unilateral", "HpvUnilateral" "Bilateral", "Midline"]
64 changes: 3 additions & 61 deletions lymph/models/hpv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
import warnings
from collections.abc import Iterable
from typing import Any, Literal

import numpy as np
Expand All @@ -16,7 +15,7 @@
logger = logging.getLogger(__name__)


class HpvWrapper(
class HpvUnilateral(
diagnosis_times.Composite,
modalities.Composite,
types.Model,
Expand Down Expand Up @@ -95,7 +94,7 @@ def _init_models(
self.base_2_key = list(self.HPV.graph.tumors.keys())[0] + "toII"

@classmethod
def binary(cls, *args, **kwargs) -> HpvWrapper:
def binary(cls, *args, **kwargs) -> HpvUnilateral:
"""Initialize a binary bilateral model.
This is a convenience method that sets the ``allowed_states`` of the
Expand All @@ -107,7 +106,7 @@ def binary(cls, *args, **kwargs) -> HpvWrapper:
return cls(*args, uni_kwargs=uni_kwargs, **kwargs)

@classmethod
def trinary(cls, *args, **kwargs) -> HpvWrapper:
def trinary(cls, *args, **kwargs) -> HpvUnilateral:
"""Initialize a trinary bilateral model.
This is a convenience method that sets the ``allowed_states`` of the
Expand Down Expand Up @@ -478,60 +477,3 @@ def risk(
)

return self.marginalize(involvement, posterior_state_dist)

def draw_patients(
self,
num: int,
stage_dist: Iterable[float],
rng: np.random.Generator | None = None,
seed: int = 42,
**_kwargs,
) -> pd.DataFrame:
"""Draw ``num`` random patients from the parametrized model.
See Also
--------
:py:meth:`.diagnosis_times.Distribution.draw_diag_times`
Method to draw diagnosis times from a distribution.
:py:meth:`.Unilateral.draw_diagnosis`
Method to draw individual diagnosis from a unilateral model.
:py:meth:`.Unilateral.draw_patients`
The unilateral method to draw a synthetic dataset.
"""
if rng is None:
rng = np.random.default_rng(seed)

if sum(stage_dist) != 1.0:
warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
stage_dist = np.array(stage_dist) / sum(stage_dist)

drawn_t_stages = rng.choice(
a=self.t_stages,
p=stage_dist,
size=num,
)
drawn_diag_times = [
self.get_distribution(t_stage).draw_diag_times(rng=rng)
for t_stage in drawn_t_stages
]

drawn_obs_ipsi = self.ipsi.draw_diagnosis(drawn_diag_times, rng=rng)
drawn_obs_noHPV = self.contra.draw_diagnosis(drawn_diag_times, rng=rng)
drawn_obs = np.concatenate([drawn_obs_ipsi, drawn_obs_noHPV], axis=1)

# construct MultiIndex with "ipsi" and "contra" at top level to allow
# concatenation of the two separate drawn diagnosis
sides = ["ipsi", "contra"]
modality_names = list(self.get_all_modalities().keys())
lnl_names = list(self.ipsi.graph.lnls.keys())
multi_cols = pd.MultiIndex.from_product([sides, modality_names, lnl_names])

# reorder the column levels and thus also the individual columns to match the
# LyProX format without mixing up the data
dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
dataset = dataset.reorder_levels(order=[1, 0, 2], axis="columns")
dataset = dataset.sort_index(axis="columns", level=0)
dataset[("tumor", "1", "t_stage")] = drawn_t_stages

return dataset

0 comments on commit 87a9b39

Please sign in to comment.