Skip to content

Commit

Permalink
Merge pull request #5 from funkelab/multi_hypothesis
Browse files Browse the repository at this point in the history
Multi hypothesis
  • Loading branch information
cmalinmayor authored Apr 3, 2024
2 parents 968cb2f + ab75d9a commit be7b3e9
Show file tree
Hide file tree
Showing 18 changed files with 1,027 additions and 610 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ repos:
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.2.2
hooks:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dev = [
'pdoc',
'pre-commit',
'types-tqdm',
'pytest-unordered'
]

[project.urls]
Expand Down
4 changes: 3 additions & 1 deletion src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .compute_graph import get_candidate_graph
from .graph_attributes import EdgeAttr, NodeAttr
from .graph_from_segmentation import graph_from_segmentation
from .graph_to_nx import graph_to_nx
from .iou import add_iou
from .utils import add_cand_edges, get_node_id, nodes_from_segmentation
78 changes: 78 additions & 0 deletions src/motile_toolbox/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import logging
from typing import Any

import networkx as nx
import numpy as np

from .conflict_sets import compute_conflict_sets
from .iou import add_iou
from .utils import add_cand_edges, nodes_from_segmentation

logger = logging.getLogger(__name__)


def get_candidate_graph(
segmentation: np.ndarray,
max_edge_distance: float,
iou: bool = False,
multihypo: bool = False,
) -> tuple[nx.DiGraph, list[set[Any]] | None]:
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
centroid of each segmentation and edges are added for all nodes in adjacent frames
within max_edge_distance. If segmentation contains multiple hypotheses, will also
return a list of conflicting node ids that cannot be selected together.
Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, [h], [z], y, x), where h is the number of hypotheses.
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
will by connected with a candidate edge.
iou (bool, optional): Whether to include IOU on the candidate graph.
Defaults to False.
multihypo (bool, optional): Whether the segmentation contains multiple
hypotheses. Defaults to False.
Returns:
tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed
to the motile solver, and a list of conflicting node ids.
"""
# add nodes
if multihypo:
cand_graph = nx.DiGraph()
num_frames = segmentation.shape[0]
node_frame_dict: dict[int, list[Any]] = {t: [] for t in range(num_frames)}
num_hypotheses = segmentation.shape[1]
for hypo_id in range(num_hypotheses):
hypothesis = segmentation[:, hypo_id]
node_graph, frame_dict = nodes_from_segmentation(
hypothesis, hypo_id=hypo_id
)
cand_graph.update(node_graph)
for t in range(num_frames):
if t in frame_dict:
node_frame_dict[t].extend(frame_dict[t])
else:
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation)
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")

# add edges
add_cand_edges(
cand_graph,
max_edge_distance=max_edge_distance,
node_frame_dict=node_frame_dict,
)
if iou:
add_iou(cand_graph, segmentation, node_frame_dict, multihypo=multihypo)

logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")

# Compute conflict sets between segmentations
if multihypo:
conflicts = []
for time, segs in enumerate(segmentation):
conflicts.extend(compute_conflict_sets(segs, time))
else:
conflicts = None

return cand_graph, conflicts
44 changes: 44 additions & 0 deletions src/motile_toolbox/candidate_graph/conflict_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from itertools import combinations

import numpy as np

from .utils import (
get_node_id,
)


def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]:
"""Compute all sets of node ids that conflict with each other.
Note: Results might include redundant sets, for example {a, b, c} and {a, b}
might both appear in the results.
Args:
segmentation_frame (np.ndarray): One frame of the multiple hypothesis
segmentation. Dimensions are (h, [z], y, x), where h is the number of
hypotheses.
time (int): Time frame, for computing node_ids.
Returns:
list[set]: list of sets of node ids that overlap. Might include some sets
that are subsets of others.
"""
flattened_segs = [seg.flatten() for seg in segmentation_frame]

# get locations where at least two hypotheses have labels
# This approach may be inefficient, but likely doesn't matter compared to np.unique
conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool)
for seg1, seg2 in combinations(flattened_segs, 2):
non_zero_indices = np.logical_and(seg1, seg2)
conflict_indices = np.logical_or(conflict_indices, non_zero_indices)

flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs])
values = np.unique(flattened_stacked, axis=1)
values = np.transpose(values)
conflict_sets = []
for conflicting_labels in values:
id_set = set()
for hypo_id, label in enumerate(conflicting_labels):
if label != 0:
id_set.add(get_node_id(time, label, hypo_id))
conflict_sets.append(id_set)
return conflict_sets
5 changes: 4 additions & 1 deletion src/motile_toolbox/candidate_graph/graph_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ class NodeAttr(Enum):
implementations of commonly used ones, listed here.
"""

SEG_ID = "segmentation_id"
POS = "pos"
TIME = "time"
SEG_ID = "seg_id"
SEG_HYPO = "seg_hypo"


class EdgeAttr(Enum):
Expand Down
Loading

0 comments on commit be7b3e9

Please sign in to comment.