Skip to content

Commit

Permalink
fix(uni): use cached matrices & simplify stuff
Browse files Browse the repository at this point in the history
The new way of globally caching the matrices required some refactoring
and simplifying.

Fixes: #68
  • Loading branch information
rmnldwg committed Jan 17, 2024
1 parent dc2b43a commit 70b16a8
Showing 1 changed file with 30 additions and 39 deletions.
69 changes: 30 additions & 39 deletions lymph/models/unilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,16 +382,6 @@ def comp_diagnose_prob(
return prob


def _gen_obs_list(self):
"""Generates the list of possible observations."""
possible_obs_list = []
for modality in self.modalities.values():
possible_obs = np.arange(modality.confusion_matrix.shape[1])
for _ in self.graph.lnls:
possible_obs_list.append(possible_obs.copy())

self._obs_list = np.array(list(product(*possible_obs_list)))

@property
def obs_list(self):
"""Return the list of all possible observations.
Expand Down Expand Up @@ -423,17 +413,13 @@ def obs_list(self):
modality CT, the second two columns correspond to the same LNLs under the
pathology modality.
"""
try:
return self._obs_list
except AttributeError:
self._gen_obs_list()
return self._obs_list
possible_obs_list = []
for modality in self.modalities.values():
possible_obs = np.arange(modality.confusion_matrix.shape[1])
for _ in self.graph.lnls:
possible_obs_list.append(possible_obs.copy())

@obs_list.deleter
def obs_list(self):
"""Delete the observation list. Necessary to pass as callback."""
if hasattr(self, "_obs_list"):
del self._obs_list
return np.array(list(product(*possible_obs_list)))


@property
Expand Down Expand Up @@ -487,10 +473,7 @@ def modalities(self) -> modalities.ModalitiesUserDict:
:py:class:`~lymph.descriptors.modalities.ModalitiesUserDict`
:py:class:`~lymph.descriptors.modalities.Modality`
"""
return modalities.ModalitiesUserDict(
is_trinary=self.is_trinary,
trigger_callbacks=[self.delete_obs_list_and_matrix],
)
return modalities.ModalitiesUserDict(is_trinary=self.is_trinary)


@cached_property
Expand All @@ -507,16 +490,9 @@ def observation_matrix(self) -> np.ndarray:
:py:func:`~lymph.descriptors.matrix.generate_observation`
The function actually computing the observation matrix.
"""
return matrix.generate_observation(self)

def delete_obs_list_and_matrix(self):
"""Delete the observation matrix. Necessary to pass as callback."""
try:
del self.observation_matrix
except AttributeError:
pass

del self.obs_list
return matrix.cached_generate_observation(
self.modalities.confusion_matrices_hash(), self
)


@smart_updating_dict_cached_property
Expand Down Expand Up @@ -611,7 +587,7 @@ def load_patient_data(
if side not in patient_data[modality_name]:
raise ValueError(f"{side}lateral involvement data not found.")

for name in self.graph.lnls:
for name in self.graph.lnls.keys():
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
modality_side_data = patient_data[modality_name, side]
Expand All @@ -629,12 +605,27 @@ def load_patient_data(
if t_stage not in patient_data["_model", "#", "t_stage"].values:
warnings.warn(f"No data for T-stage {t_stage} found.")

self._patient_data = patient_data
# Changes to the patient data require a recomputation of the data and
# diagnose matrices. Clearing them will trigger this when they are next
# accessed.
# diagnose matrices. For the data matrix, it is enough to clear the respective
# ``UserDict``. For the diagnose matrices, we need to delete the hash value of
# the patient data, so that the next time it is requested, a cache miss occurs
# and they are recomputed.
self.data_matrices.clear()
self.diagnose_matrices.clear()
self._patient_data = patient_data
try:
del self.patient_data_hash
except AttributeError:
pass


@cached_property
def patient_data_hash(self) -> int:
"""Hash of the patient data.
This is used to check if the patient data has changed since the last time
the data and diagnose matrices were computed. If so, they are recomputed.
"""
return hash(self.patient_data.to_numpy().tobytes())


@property
Expand Down

0 comments on commit 70b16a8

Please sign in to comment.