Skip to content

Commit

Permalink
Merge pull request #84 from rmnldwg/release-1.2.1
Browse files Browse the repository at this point in the history
Release 1.2.1
  • Loading branch information
rmnldwg authored May 28, 2024
2 parents b7f453a + dba519a commit 473b4e3
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 55 deletions.
47 changes: 46 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,50 @@

All notable changes to this project will be documented in this file.


<a name="1.2.1"></a>
## [1.2.1] - 2024-05-28

### Bug Fixes

- (**uni**) `load_patient_data` should accept `None`.
- (**mid**) Correct type hint of `marginalize`.
- (**graph**) Wrong dict when trinary.\
The `to_dict()` method returned a wrong graph dictionary when trinary
due to growth edges. This is fixed now.
- Skip `marginalize` only when safe.\
The marginalization should only be skipped (and 1 returned), when the
entire disease state of interest is `None`. In the midline case, this
disease state includes the midline extension.\
Previously, only the involvement pattern was checked. Now, the model is
more careful about when to take shortcuts.


### Features

- (**graph**) Modify mermaid graph.\
The `get_mermaid()` and `get_mermaid_url()` methods now accept arguments
that allow some modifications of the output.
- (**uni**) Add `__repr__()`.

### Refactor

- (**uni**) Use pandas `map` instead of `apply`.\
This saves us a couple of lines in the `load_patient_data` method and is
more readable.


### Merge

- Branch 'main' into 'dev'.

### Remove

- Remains of callbacks.\
Some callback functionality that was tested in a pre-release has been
forgotten in the code base and is now deleted.


<a name="1.2.0"></a>
## [1.2.0] - 2024-03-29

Expand Down Expand Up @@ -668,7 +712,8 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
- add pre-commit hook to check commit msg


[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.2.0...HEAD
[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.2.1...HEAD
[1.2.1]: https://github.com/rmnldwg/lymph/compare/1.1.0...1.2.1
[1.2.0]: https://github.com/rmnldwg/lymph/compare/1.1.0...1.2.0
[1.1.0]: https://github.com/rmnldwg/lymph/compare/1.0.0...1.1.0
[1.0.0]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc2...1.0.0
Expand Down
50 changes: 30 additions & 20 deletions lymph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
import base64
import warnings
from itertools import product
from typing import Literal

import numpy as np

from lymph import types
from lymph.utils import (
check_unique_names,
comp_transition_tensor,
flatten,
popfirst,
set_params_for,
trigger,
)


Expand Down Expand Up @@ -224,7 +225,6 @@ def __init__(
child: LymphNodeLevel,
spread_prob: float = 0.,
micro_mod: float = 1.,
callbacks: list[callable] | None = None,
) -> None:
"""Create a new edge between two nodes.
Expand All @@ -235,10 +235,6 @@ def __init__(
spread to the next LNL. The ``micro_mod`` parameter is a modifier for the spread
probability in case of only a microscopic node involvement.
"""
self.trigger_callbacks = []
if callbacks is not None:
self.trigger_callbacks += callbacks

self.parent: Tumor | LymphNodeLevel = parent
self.child: LymphNodeLevel = child

Expand Down Expand Up @@ -353,7 +349,6 @@ def get_micro_mod(self) -> float:
self._micro_mod = 1.
return self._micro_mod

@trigger
def set_micro_mod(self, new_micro_mod: float | None) -> None:
"""Set the spread modifier for LNLs with microscopic involvement."""
if new_micro_mod is None:
Expand All @@ -380,7 +375,6 @@ def get_spread_prob(self) -> float:
self._spread_prob = 0.
return self._spread_prob

@trigger
def set_spread_prob(self, new_spread_prob: float | None) -> None:
"""Set the spread probability of the edge."""
if new_spread_prob is None:
Expand Down Expand Up @@ -493,7 +487,6 @@ def __init__(
graph_dict: dict[tuple[str], list[str]],
tumor_state: int | None = None,
allowed_states: list[int] | None = None,
on_edge_change: list[callable] | None = None,
) -> None:
"""Create a new graph representation of nodes and edges.
Expand All @@ -512,7 +505,7 @@ def __init__(

check_unique_names(graph_dict)
self._init_nodes(graph_dict, tumor_state, allowed_states)
self._init_edges(graph_dict, on_edge_change)
self._init_edges(graph_dict)


def _init_nodes(self, graph, tumor_state, allowed_lnl_states):
Expand Down Expand Up @@ -585,7 +578,6 @@ def is_trinary(self) -> bool:
def _init_edges(
self,
graph: dict[tuple[str, str], list[str]],
on_edge_change: list[callable]
) -> None:
"""Initialize the edges of the ``graph``.
Expand All @@ -602,12 +594,12 @@ def _init_edges(
for (_, start_name), end_names in graph.items():
start = self.nodes[start_name]
if isinstance(start, LymphNodeLevel) and start.is_trinary:
growth_edge = Edge(parent=start, child=start, callbacks=on_edge_change)
growth_edge = Edge(parent=start, child=start)
self._edges[growth_edge.get_name()] = growth_edge

for end_name in end_names:
end = self.nodes[end_name]
new_edge = Edge(parent=start, child=end, callbacks=on_edge_change)
new_edge = Edge(parent=start, child=end)
self._edges[new_edge.get_name()] = new_edge


Expand Down Expand Up @@ -669,11 +661,19 @@ def to_dict(self) -> dict[tuple[str, str], set[str]]:
res = {}
for node in self.nodes.values():
node_type = "tumor" if isinstance(node, Tumor) else "lnl"
res[(node_type, node.name)] = [o.child.name for o in node.out]
res[(node_type, node.name)] = [
o.child.name
for o in node.out
if not o.is_growth
]
return res


def get_mermaid(self) -> str:
def get_mermaid(
self,
with_params: bool = True,
direction: Literal["TD", "LR"] = "TD",
) -> str:
"""Prints the graph in mermaid format.
>>> graph_dict = {
Expand All @@ -691,19 +691,29 @@ def get_mermaid(self) -> str:
T-->|20%| III
II-->|30%| III
<BLANKLINE>
>>> print(graph.get_mermaid(with_params=False)) # doctest: +NORMALIZE_WHITESPACE
flowchart TD
T--> II
T--> III
II--> III
<BLANKLINE>
"""
mermaid_graph = "flowchart TD\n"
mermaid_graph = f"flowchart {direction}\n"

for node in self.nodes.values():
for edge in node.out:
mermaid_graph += f"\t{node.name}-->|{edge.spread_prob:.0%}| {edge.child.name}\n"
param_str = f"|{edge.spread_prob:.0%}|" if with_params else ""
mermaid_graph += f"\t{node.name}-->{param_str} {edge.child.name}\n"

return mermaid_graph


def get_mermaid_url(self) -> str:
"""Returns the URL to the rendered graph."""
mermaid_graph = self.get_mermaid()
def get_mermaid_url(self, **mermaid_kwargs) -> str:
"""Returns the URL to the rendered graph.
Keyword arguments are passed to :py:meth:`~Representation.get_mermaid`.
"""
mermaid_graph = self.get_mermaid(**mermaid_kwargs)
graphbytes = mermaid_graph.encode("ascii")
base64_bytes = base64.b64encode(graphbytes)
base64_string = base64_bytes.decode("ascii")
Expand Down
2 changes: 1 addition & 1 deletion lymph/models/bilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def marginalize(
are ignored if ``given_state_dist`` is provided.
"""
if involvement is None:
return 1.
involvement = {}

if given_state_dist is None:
given_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
Expand Down
6 changes: 3 additions & 3 deletions lymph/models/midline.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def posterior_state_dist(

def marginalize(
self,
involvement: types.PatternType | None = None,
involvement: dict[str, types.PatternType] | None = None,
given_state_dist: np.ndarray | None = None,
t_stage: str = "early",
mode: Literal["HMM", "BN"] = "HMM",
Expand All @@ -770,7 +770,7 @@ def marginalize(
:py:meth:`.state_dist` method.
"""
if involvement is None:
return 1.
involvement = {}

if given_state_dist is None:
given_state_dist = self.state_dist(t_stage=t_stage, mode=mode, central=central)
Expand All @@ -787,7 +787,7 @@ def marginalize(
given_state_dist = given_state_dist[int(midext)]
# I think I don't need to normalize here, since I am not computing a
# probability of something *given* midext, but only sum up all states that
# match the involvement pattern (which includes the midext status).
# match the disease state of interest (which includes the midext status).

return self.ext.marginalize(
involvement=involvement,
Expand Down
32 changes: 19 additions & 13 deletions lymph/models/unilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from itertools import product
from typing import Any, Iterable, Literal
from typing import Any, Callable, Iterable, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -117,6 +117,17 @@ def trinary(cls, graph_dict: types.GraphDictType, **kwargs) -> Unilateral:
return cls(graph_dict, allowed_states=[0, 1, 2], **kwargs)


def __repr__(self) -> str:
"""Return a string representation of the instance."""
return (
f"{type(self).__name__}("
f"graph_dict={self.graph.to_dict()}, "
f"tumor_state={list(self.graph.tumors.values())[0].state}, "
f"allowed_states={self.graph.allowed_states}, "
f"max_time={self.max_time})"
)


def __str__(self) -> str:
"""Print info about the instance."""
return f"Unilateral with {len(self.graph.tumors)} tumors and {len(self.graph.lnls)} LNLs"
Expand Down Expand Up @@ -489,7 +500,7 @@ def load_patient_data(
self,
patient_data: pd.DataFrame,
side: str = "ipsi",
mapping: callable | dict[int, Any] | None = None,
mapping: Callable[[int], Any] | dict[int, Any] | None = None,
) -> None:
"""Load patient data in `LyProX`_ format into the model.
Expand All @@ -512,7 +523,6 @@ def load_patient_data(
if mapping is None:
mapping = early_late_mapping

# pylint: disable=unnecessary-lambda-assignment
patient_data = (
patient_data
.copy()
Expand Down Expand Up @@ -545,15 +555,7 @@ def load_patient_data(

patient_data["_model", modality, lnl] = column

if len(patient_data) == 0:
patient_data[MAP_T_COL] = None
else:
mapping = dict_to_func(mapping) if isinstance(mapping, dict) else mapping
patient_data[MAP_T_COL] = patient_data.apply(
lambda row: mapping(row[RAW_T_COL]),
axis=1,
)

patient_data[MAP_T_COL] = patient_data[RAW_T_COL].map(mapping)
self._patient_data = patient_data
self._cache_version += 1

Expand Down Expand Up @@ -833,7 +835,11 @@ def marginalize(
:py:meth:`.state_dist` with the given ``t_stage`` and ``mode``. These arguments
are ignored if ``given_state_dist`` is provided.
"""
if involvement is None:
if (
involvement is None
or not involvement # empty dict is falsey
or all(value is None for value in involvement.values())
):
return 1.

if given_state_dist is None:
Expand Down
13 changes: 1 addition & 12 deletions tests/edge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@ def setUp(self) -> None:
super().setUp()
parent = graph.LymphNodeLevel("parent")
child = graph.LymphNodeLevel("child")
self.was_called = False
self.edge = graph.Edge(parent, child, callbacks=[self.callback])

def callback(self) -> None:
"""Callback function for the edge."""
self.was_called = True
self.edge = graph.Edge(parent, child)

def test_str(self) -> None:
"""Test the string representation of the edge."""
Expand All @@ -41,17 +36,11 @@ def test_repr(self) -> None:
self.assertEqual(self.edge.spread_prob, recreated_edge.spread_prob)
self.assertEqual(self.edge.micro_mod, recreated_edge.micro_mod)

def test_callback_on_param_change(self) -> None:
"""Test if the callback function is called."""
self.edge.spread_prob = 0.5
self.assertTrue(self.was_called)

def test_graph_change(self) -> None:
"""Check if the callback also works when parent/child nodes are changed."""
old_child = self.edge.child
new_child = graph.LymphNodeLevel("new_child")
self.edge.child = new_child
self.assertTrue(self.was_called)
self.assertNotIn(self.edge, old_child.inc)

def test_transition_tensor_row_sums(self) -> None:
Expand Down
5 changes: 0 additions & 5 deletions tests/graph_representation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ def setUp(self) -> None:
self.graph_repr = graph.Representation(
graph_dict=self.graph_dict,
allowed_states=[0, 1],
on_edge_change=[self.callback],
)
self.was_called = False
self.rng = np.random.default_rng(42)

def callback(self) -> None:
"""Callback function for the graph."""
self.was_called = True

def test_nodes(self) -> None:
"""Test the number of nodes."""
self.assertEqual(len(self.graph_repr.nodes), len(self.graph_dict))
Expand Down

0 comments on commit 473b4e3

Please sign in to comment.