Skip to content

Commit

Permalink
fix(uni): update draw_patients() method
Browse files Browse the repository at this point in the history
Update the method to generate synthetic data from the parametrized model
in the same LyProX format that is also required to load the data in the
first place.

Fixes: #65
  • Loading branch information
rmnldwg committed Dec 29, 2023
1 parent 2818fb5 commit 1e125ed
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 36 deletions.
18 changes: 15 additions & 3 deletions lymph/diagnose_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,21 @@ def set_params(self, **kwargs) -> None:
warnings.warn("Distribution is not updateable, skipping...")


def draw(self) -> np.ndarray:
    """Draw one diagnose time from the stored PMF.

    The value is picked from ``self.support`` with the probabilities given in
    ``self.distribution``, using numpy's global random state.
    """
    support, pmf = self.support, self.distribution
    return np.random.choice(a=support, p=pmf)
def draw_diag_times(
self,
num: int | None = None,
rng: np.random.Generator | None = None,
seed: int = 42,
) -> np.ndarray:
"""Draw ``num`` samples of diagnose times from the stored PMF.
A random number generator can be provided as ``rng``. If ``None``, a new one
is initialized with the given ``seed`` (or ``42``, by default).
"""
if rng is None:
rng = np.random.default_rng(seed)

return rng.choice(a=self.support, p=self.distribution, size=num)


class DistributionsUserDict(AbstractLookupDict):
Expand Down
4 changes: 2 additions & 2 deletions lymph/models/bilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,8 @@ def generate_dataset(
dist=stage_dist, size=num_patients
)

drawn_obs_ipsi = self.ipsi._draw_patient_diagnoses(drawn_diag_times)
drawn_obs_contra = self.contra._draw_patient_diagnoses(drawn_diag_times)
drawn_obs_ipsi = self.ipsi.draw_diagnoses(drawn_diag_times)
drawn_obs_contra = self.contra.draw_diagnoses(drawn_diag_times)
drawn_obs = np.concatenate([drawn_obs_ipsi, drawn_obs_contra], axis=1)

# construct MultiIndex for dataset from stored modalities
Expand Down
75 changes: 44 additions & 31 deletions lymph/models/unilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,58 +945,71 @@ def risk(
return marginalize_over_states @ posterior_state_dist


def _draw_patient_diagnoses(
def draw_diagnoses(
self,
diag_times: list[int],
rng: np.random.Generator | None = None,
seed: int = 42,
) -> np.ndarray:
"""Draw random possible observations for a list of T-stages and
diagnose times.
"""Given some ``diag_times``, draw diagnoses for each LNL."""
if rng is None:
rng = np.random.default_rng(seed)

Args:
diag_times: List of diagnose times for each patient who's diagnose
is supposed to be drawn.
"""
# use the drawn diagnose times to compute probabilities over states and
# diagnoses
per_time_state_probs = self.comp_dist_evolution()
per_patient_state_probs = per_time_state_probs[diag_times]
per_patient_obs_probs = per_patient_state_probs @ self.observation_matrix

# then, draw a diagnose from the possible ones
obs_idx = np.arange(len(self.obs_list))
state_probs_given_time = self.comp_dist_evolution()[diag_times]
obs_probs_given_time = state_probs_given_time @ self.observation_matrix

obs_indices = np.arange(len(self.obs_list))
drawn_obs_idx = [
np.random.choice(obs_idx, p=obs_prob)
for obs_prob in per_patient_obs_probs
np.random.choice(obs_indices, p=obs_prob)
for obs_prob in obs_probs_given_time
]

return self.obs_list[drawn_obs_idx].astype(bool)


def generate_dataset(
def draw_patients(
self,
num_patients: int,
stage_dist: dict[str, float],
num: int,
stage_dist: Iterable[float],
rng: np.random.Generator | None = None,
seed: int = 42,
**_kwargs,
) -> pd.DataFrame:
"""Generate/sample a pandas :class:`DataFrame` from the defined network
using the samples and diagnostic modalities that have been set.
"""Draw a ``num`` random patients from the model.
For this, a ``stage_dist``, i.e., a distribution over the T-stages, needs to
be defined. This must be an iterable of probabilities with as many elements as
there are defined T-stages in the model's :py:attr:`diag_time_dists` attribute.
Args:
num_patients: Number of patients to generate.
stage_dist: Probability to find a patient in a certain T-stage.
A random number generator can be provided as ``rng``. If ``None``, a new one
is initialized with the given ``seed`` (or ``42``, by default).
"""
drawn_t_stages, drawn_diag_times = self.diag_time_dists.draw(
prob_of_t_stage=stage_dist, size=num_patients
if rng is None:
rng = np.random.default_rng(seed)

if sum(stage_dist) != 1.:
warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
stage_dist = np.array(stage_dist) / sum(stage_dist)

drawn_t_stages = rng.choice(
a=list(self.diag_time_dists.keys()),
p=stage_dist,
size=num,
)
drawn_diag_times = [
self.diag_time_dists[t_stage].draw_diag_times(rng=rng)
for t_stage in drawn_t_stages
]

drawn_obs = self._draw_patient_diagnoses(drawn_diag_times)
drawn_obs = self.draw_diagnoses(drawn_diag_times)

# construct MultiIndex for dataset from stored modalities
modality_names = list(self.modalities.keys())
lnl_names = self.graph.lnls.keys()
multi_cols = pd.MultiIndex.from_product([modality_names, lnl_names])
lnl_names = list(self.graph.lnls.keys())
multi_cols = pd.MultiIndex.from_product([modality_names, ["ipsi"], lnl_names])

# create DataFrame
dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
dataset[('info', 't_stage')] = drawn_t_stages
dataset[("tumor", "1", "t_stage")] = drawn_t_stages

return dataset

0 comments on commit 1e125ed

Please sign in to comment.