Skip to content

Commit

Permalink
fix(bi): update bilateral data generation method
Browse files Browse the repository at this point in the history
Fixes: #65
  • Loading branch information
rmnldwg committed Dec 29, 2023
1 parent 770bc57 commit e80eba4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 22 deletions.
57 changes: 39 additions & 18 deletions lymph/models/bilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,36 +696,57 @@ def risk(
)


def generate_dataset(
def draw_patients(
self,
num_patients: int,
stage_dist: dict[str, float],
num: int,
stage_dist: Iterable[float],
rng: np.random.Generator | None = None,
seed: int = 42,
**_kwargs,
) -> pd.DataFrame:
"""Generate/sample a pandas :class:`DataFrame` from the defined network.
"""Draw ``num`` random patients from the parametrized model.
Args:
num_patients: Number of patients to generate.
stage_dist: Probability to find a patient in a certain T-stage.
See Also:
:py:meth:`lymph.diagnose_times.Distribution.draw_diag_times`
Method to draw diagnose times from a distribution.
:py:meth:`lymph.models.Unilateral.draw_diagnoses`
Method to draw individual diagnoses from a unilateral model.
:py:meth:`lymph.models.Unilateral.draw_patients`
The unilateral method to draw a synthetic dataset.
"""
# TODO: check if this still works
drawn_t_stages, drawn_diag_times = self.diag_time_dists.draw(
dist=stage_dist, size=num_patients
if rng is None:
rng = np.random.default_rng(seed)

if sum(stage_dist) != 1.:
warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
stage_dist = np.array(stage_dist) / sum(stage_dist)

drawn_t_stages = rng.choice(
a=list(self.diag_time_dists.keys()),
p=stage_dist,
size=num,
)
drawn_diag_times = [
self.diag_time_dists[t_stage].draw_diag_times(rng=rng)
for t_stage in drawn_t_stages
]

drawn_obs_ipsi = self.ipsi.draw_diagnoses(drawn_diag_times)
drawn_obs_contra = self.contra.draw_diagnoses(drawn_diag_times)
drawn_obs_ipsi = self.ipsi.draw_diagnoses(drawn_diag_times, rng=rng)
drawn_obs_contra = self.contra.draw_diagnoses(drawn_diag_times, rng=rng)
drawn_obs = np.concatenate([drawn_obs_ipsi, drawn_obs_contra], axis=1)

# construct MultiIndex for dataset from stored modalities
# construct MultiIndex with "ipsi" and "contra" at top level to allow
# concatenation of the two separate drawn diagnoses
sides = ["ipsi", "contra"]
modalities = list(self.modalities.keys())
lnl_names = [lnl.name for lnl in self.ipsi.graph._lnls]
multi_cols = pd.MultiIndex.from_product([sides, modalities, lnl_names])
modality_names = list(self.modalities.keys())
lnl_names = [lnl for lnl in self.ipsi.graph.lnls.keys()]
multi_cols = pd.MultiIndex.from_product([sides, modality_names, lnl_names])

# create DataFrame
# reorder the column levels and thus also the individual columns to match the
# LyProX format without mixing up the data
dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
dataset = dataset.reorder_levels(order=[1, 0, 2], axis="columns")
dataset = dataset.sort_index(axis="columns", level=0)
dataset[('info', 'tumor', 't_stage')] = drawn_t_stages
dataset[('tumor', '1', 't_stage')] = drawn_t_stages

return dataset
14 changes: 10 additions & 4 deletions lymph/models/unilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,14 +975,22 @@ def draw_patients(
seed: int = 42,
**_kwargs,
) -> pd.DataFrame:
"""Draw a ``num`` random patients from the model.
"""Draw ``num`` random patients from the model.
For this, a ``stage_dist``, i.e., a distribution over the T-stages, needs to
be defined. This must be an iterable of probabilities with as many elements as
there are defined T-stages in the model's :py:attr:`diag_time_dists` attribute.
A random number generator can be provided as ``rng``. If ``None``, a new one
is initialized with the given ``seed`` (or ``42``, by default).
See Also:
:py:meth:`lymph.diagnose_times.Distribution.draw_diag_times`
Method to draw diagnose times from a distribution.
:py:meth:`lymph.models.Unilateral.draw_diagnoses`
Method to draw individual diagnoses.
:py:meth:`lymph.models.Bilateral.draw_patients`
The corresponding bilateral method.
"""
if rng is None:
rng = np.random.default_rng(seed)
Expand All @@ -1001,14 +1009,12 @@ def draw_patients(
for t_stage in drawn_t_stages
]

drawn_obs = self.draw_diagnoses(drawn_diag_times)
drawn_obs = self.draw_diagnoses(drawn_diag_times, rng=rng)

# construct MultiIndex for dataset from stored modalities
modality_names = list(self.modalities.keys())
lnl_names = list(self.graph.lnls.keys())
multi_cols = pd.MultiIndex.from_product([modality_names, ["ipsi"], lnl_names])

# create DataFrame
dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
dataset[("tumor", "1", "t_stage")] = drawn_t_stages

Expand Down

0 comments on commit e80eba4

Please sign in to comment.