-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from funkelab/multi_hypothesis
Multi hypothesis
- Loading branch information
Showing
18 changed files
with
1,027 additions
and
610 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ dev = [ | |
'pdoc', | ||
'pre-commit', | ||
'types-tqdm', | ||
'pytest-unordered' | ||
] | ||
|
||
[project.urls] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from .compute_graph import get_candidate_graph | ||
from .graph_attributes import EdgeAttr, NodeAttr | ||
from .graph_from_segmentation import graph_from_segmentation | ||
from .graph_to_nx import graph_to_nx | ||
from .iou import add_iou | ||
from .utils import add_cand_edges, get_node_id, nodes_from_segmentation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import logging | ||
from typing import Any | ||
|
||
import networkx as nx | ||
import numpy as np | ||
|
||
from .conflict_sets import compute_conflict_sets | ||
from .iou import add_iou | ||
from .utils import add_cand_edges, nodes_from_segmentation | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def get_candidate_graph( | ||
segmentation: np.ndarray, | ||
max_edge_distance: float, | ||
iou: bool = False, | ||
multihypo: bool = False, | ||
) -> tuple[nx.DiGraph, list[set[Any]] | None]: | ||
"""Construct a candidate graph from a segmentation array. Nodes are placed at the | ||
centroid of each segmentation and edges are added for all nodes in adjacent frames | ||
within max_edge_distance. If segmentation contains multiple hypotheses, will also | ||
return a list of conflicting node ids that cannot be selected together. | ||
Args: | ||
segmentation (np.ndarray): A numpy array with integer labels and dimensions | ||
(t, [h], [z], y, x), where h is the number of hypotheses. | ||
max_edge_distance (float): Maximum distance that objects can travel between | ||
frames. All nodes with centroids within this distance in adjacent frames | ||
will by connected with a candidate edge. | ||
iou (bool, optional): Whether to include IOU on the candidate graph. | ||
Defaults to False. | ||
multihypo (bool, optional): Whether the segmentation contains multiple | ||
hypotheses. Defaults to False. | ||
Returns: | ||
tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed | ||
to the motile solver, and a list of conflicting node ids. | ||
""" | ||
# add nodes | ||
if multihypo: | ||
cand_graph = nx.DiGraph() | ||
num_frames = segmentation.shape[0] | ||
node_frame_dict: dict[int, list[Any]] = {t: [] for t in range(num_frames)} | ||
num_hypotheses = segmentation.shape[1] | ||
for hypo_id in range(num_hypotheses): | ||
hypothesis = segmentation[:, hypo_id] | ||
node_graph, frame_dict = nodes_from_segmentation( | ||
hypothesis, hypo_id=hypo_id | ||
) | ||
cand_graph.update(node_graph) | ||
for t in range(num_frames): | ||
if t in frame_dict: | ||
node_frame_dict[t].extend(frame_dict[t]) | ||
else: | ||
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) | ||
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") | ||
|
||
# add edges | ||
add_cand_edges( | ||
cand_graph, | ||
max_edge_distance=max_edge_distance, | ||
node_frame_dict=node_frame_dict, | ||
) | ||
if iou: | ||
add_iou(cand_graph, segmentation, node_frame_dict, multihypo=multihypo) | ||
|
||
logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") | ||
|
||
# Compute conflict sets between segmentations | ||
if multihypo: | ||
conflicts = [] | ||
for time, segs in enumerate(segmentation): | ||
conflicts.extend(compute_conflict_sets(segs, time)) | ||
else: | ||
conflicts = None | ||
|
||
return cand_graph, conflicts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from itertools import combinations | ||
|
||
import numpy as np | ||
|
||
from .utils import ( | ||
get_node_id, | ||
) | ||
|
||
|
||
def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: | ||
"""Compute all sets of node ids that conflict with each other. | ||
Note: Results might include redundant sets, for example {a, b, c} and {a, b} | ||
might both appear in the results. | ||
Args: | ||
segmentation_frame (np.ndarray): One frame of the multiple hypothesis | ||
segmentation. Dimensions are (h, [z], y, x), where h is the number of | ||
hypotheses. | ||
time (int): Time frame, for computing node_ids. | ||
Returns: | ||
list[set]: list of sets of node ids that overlap. Might include some sets | ||
that are subsets of others. | ||
""" | ||
flattened_segs = [seg.flatten() for seg in segmentation_frame] | ||
|
||
# get locations where at least two hypotheses have labels | ||
# This approach may be inefficient, but likely doesn't matter compared to np.unique | ||
conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) | ||
for seg1, seg2 in combinations(flattened_segs, 2): | ||
non_zero_indices = np.logical_and(seg1, seg2) | ||
conflict_indices = np.logical_or(conflict_indices, non_zero_indices) | ||
|
||
flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) | ||
values = np.unique(flattened_stacked, axis=1) | ||
values = np.transpose(values) | ||
conflict_sets = [] | ||
for conflicting_labels in values: | ||
id_set = set() | ||
for hypo_id, label in enumerate(conflicting_labels): | ||
if label != 0: | ||
id_set.add(get_node_id(time, label, hypo_id)) | ||
conflict_sets.append(id_set) | ||
return conflict_sets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.