diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce0cab7..e0f6442 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,11 @@ repos: - id: check-yaml - id: check-added-large-files + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.2.2 hooks: diff --git a/pyproject.toml b/pyproject.toml index ac3310c..8706b10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dev = [ 'pdoc', 'pre-commit', 'types-tqdm', + 'pytest-unordered' ] [project.urls] diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index d06fd4e..8b67fe2 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,3 +1,5 @@ +from .compute_graph import get_candidate_graph from .graph_attributes import EdgeAttr, NodeAttr -from .graph_from_segmentation import graph_from_segmentation from .graph_to_nx import graph_to_nx +from .iou import add_iou +from .utils import add_cand_edges, get_node_id, nodes_from_segmentation diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py new file mode 100644 index 0000000..7649c62 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -0,0 +1,78 @@ +import logging +from typing import Any + +import networkx as nx +import numpy as np + +from .conflict_sets import compute_conflict_sets +from .iou import add_iou +from .utils import add_cand_edges, nodes_from_segmentation + +logger = logging.getLogger(__name__) + + +def get_candidate_graph( + segmentation: np.ndarray, + max_edge_distance: float, + iou: bool = False, + multihypo: bool = False, +) -> tuple[nx.DiGraph, list[set[Any]] | None]: + """Construct a candidate graph from a segmentation array. Nodes are placed at the + centroid of each segmentation and edges are added for all nodes in adjacent frames + within max_edge_distance. If segmentation contains multiple hypotheses, will also + return a list of conflicting node ids that cannot be selected together. + + Args: + segmentation (np.ndarray): A numpy array with integer labels and dimensions + (t, [h], [z], y, x), where h is the number of hypotheses. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + iou (bool, optional): Whether to include IOU on the candidate graph. + Defaults to False. + multihypo (bool, optional): Whether the segmentation contains multiple + hypotheses. Defaults to False. + + Returns: + tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed + to the motile solver, and a list of conflicting node ids. + """ + # add nodes + if multihypo: + cand_graph = nx.DiGraph() + num_frames = segmentation.shape[0] + node_frame_dict: dict[int, list[Any]] = {t: [] for t in range(num_frames)} + num_hypotheses = segmentation.shape[1] + for hypo_id in range(num_hypotheses): + hypothesis = segmentation[:, hypo_id] + node_graph, frame_dict = nodes_from_segmentation( + hypothesis, hypo_id=hypo_id + ) + cand_graph.update(node_graph) + for t in range(num_frames): + if t in frame_dict: + node_frame_dict[t].extend(frame_dict[t]) + else: + cand_graph, node_frame_dict = nodes_from_segmentation(segmentation) + logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") + + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + if iou: + add_iou(cand_graph, segmentation, node_frame_dict, multihypo=multihypo) + + logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") + + # Compute conflict sets between segmentations + if multihypo: + conflicts = [] + for time, segs in enumerate(segmentation): + conflicts.extend(compute_conflict_sets(segs, time)) + else: + conflicts = None + + return cand_graph, conflicts diff --git a/src/motile_toolbox/candidate_graph/conflict_sets.py b/src/motile_toolbox/candidate_graph/conflict_sets.py new file mode 100644 index 0000000..4747c29 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/conflict_sets.py @@ -0,0 +1,44 @@ +from itertools import combinations + +import numpy as np + +from .utils import ( + get_node_id, +) + + +def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]: + """Compute all sets of node ids that conflict with each other. + Note: Results might include redundant sets, for example {a, b, c} and {a, b} + might both appear in the results. + + Args: + segmentation_frame (np.ndarray): One frame of the multiple hypothesis + segmentation. Dimensions are (h, [z], y, x), where h is the number of + hypotheses. + time (int): Time frame, for computing node_ids. + + Returns: + list[set]: list of sets of node ids that overlap. Might include some sets + that are subsets of others. + """ + flattened_segs = [seg.flatten() for seg in segmentation_frame] + + # get locations where at least two hypotheses have labels + # This approach may be inefficient, but likely doesn't matter compared to np.unique + conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) + for seg1, seg2 in combinations(flattened_segs, 2): + non_zero_indices = np.logical_and(seg1, seg2) + conflict_indices = np.logical_or(conflict_indices, non_zero_indices) + + flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) + values = np.unique(flattened_stacked, axis=1) + values = np.transpose(values) + conflict_sets = [] + for conflicting_labels in values: + id_set = set() + for hypo_id, label in enumerate(conflicting_labels): + if label != 0: + id_set.add(get_node_id(time, label, hypo_id)) + conflict_sets.append(id_set) + return conflict_sets diff --git a/src/motile_toolbox/candidate_graph/graph_attributes.py b/src/motile_toolbox/candidate_graph/graph_attributes.py index e6a9d49..478c2b3 100644 --- a/src/motile_toolbox/candidate_graph/graph_attributes.py +++ b/src/motile_toolbox/candidate_graph/graph_attributes.py @@ -7,7 +7,10 @@ class NodeAttr(Enum): implementations of commonly used ones, listed here. """ - SEG_ID = "segmentation_id" + POS = "pos" + TIME = "time" + SEG_ID = "seg_id" + SEG_HYPO = "seg_hypo" class EdgeAttr(Enum): diff --git a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py b/src/motile_toolbox/candidate_graph/graph_from_segmentation.py deleted file mode 100644 index 638d003..0000000 --- a/src/motile_toolbox/candidate_graph/graph_from_segmentation.py +++ /dev/null @@ -1,277 +0,0 @@ -import logging -import math -from typing import Any - -import networkx as nx -import numpy as np -from skimage.measure import regionprops -from tqdm import tqdm - -from .graph_attributes import EdgeAttr, NodeAttr - -logger = logging.getLogger(__name__) - - -def _get_location( - node_data: dict[str, Any], position_keys: tuple[str, ...] | list[str] -) -> list[Any]: - """Convenience function to get the location of a networkx node when each dimension - is stored in a different attribute. - - Args: - node_data (dict[str, Any]): Dictionary of attributes of a networkx node. - Assumes the provided position keys are in the dictionary. - position_keys (tuple[str, ...] | list[str], optional): Keys to use to get - location information from node_data (assumes they are present in node_data). - Defaults to ("z", "y", "x"). - - Returns: - list: _description_ - Raises: - KeyError if position keys not in node_data - """ - return [node_data[k] for k in position_keys] - - -def nodes_from_segmentation( - segmentation: np.ndarray, - attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", -) -> tuple[nx.DiGraph, dict[int, list[Any]]]: - """Extract candidate nodes from a segmentation. Also computes specified attributes. - Returns a networkx graph with only nodes, and also a dictionary from frames to - node_ids for efficient edge adding. - - Args: - segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels - (0 is background, all pixels with value 1 belong to one cell, etc.). The - time dimension is first, followed by two or three position dimensions. If - the position dims are not (y, x), use `position_keys` to specify the names - of the dimensions. - attributes (tuple[str, ...] | list[str] , optional): Set of attributes to - compute and add to graph nodes. Valid attributes are: "segmentation_id". - Defaults to ("segmentation_id",). - position_keys (tuple[str, ...]| list[str] , optional): What to label the - position dimensions in the candidate graph. The order of the names - corresponds to the order of the dimensions in `segmentation`. Defaults to - ("y", "x"). - frame_key (str, optional): What to label the time dimension in the candidate - graph. Defaults to 't'. - - Returns: - tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, - and a mapping from time frames to node ids. - """ - cand_graph = nx.DiGraph() - # also construct a dictionary from time frame to node_id for efficiency - node_frame_dict = {} - print("Extracting nodes from segmentaiton") - for t in tqdm(range(len(segmentation))): - nodes_in_frame = [] - props = regionprops(segmentation[t]) - for regionprop in props: - node_id = f"{t}_{regionprop.label}" - attrs = { - frame_key: t, - } - if NodeAttr.SEG_ID in attributes: - attrs[NodeAttr.SEG_ID.value] = regionprop.label - centroid = regionprop.centroid # [z,] y, x - for label, value in zip(position_keys, centroid): - attrs[label] = value - cand_graph.add_node(node_id, **attrs) - nodes_in_frame.append(node_id) - if nodes_in_frame: - node_frame_dict[t] = nodes_in_frame - return cand_graph, node_frame_dict - - -def _compute_node_frame_dict( - cand_graph: nx.DiGraph, frame_key: str = "t" -) -> dict[int, list[Any]]: - """Compute dictionary from time frames to node ids for candidate graph. - - Args: - cand_graph (nx.DiGraph): A networkx graph - frame_key (str, optional): Attribute key that holds the time frame of each - node in cand_graph. Defaults to "t". - - Returns: - dict[int, list[Any]]: A mapping from time frames to lists of node ids. - """ - node_frame_dict: dict[int, list[Any]] = {} - for node, data in cand_graph.nodes(data=True): - t = data[frame_key] - if t not in node_frame_dict: - node_frame_dict[t] = [] - node_frame_dict[t].append(node) - return node_frame_dict - - -def add_cand_edges( - cand_graph: nx.DiGraph, - max_edge_distance: float, - attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", - node_frame_dict: None | dict[int, list[Any]] = None, - segmentation: None | np.ndarray = None, -) -> None: - """Add candidate edges to a candidate graph by connecting all nodes in adjacent - frames that are closer than max_edge_distance. Also adds attributes to the edges. - - Args: - cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will - be modified in-place to add edges. - max_edge_distance (float): Maximum distance that objects can travel between - frames. All nodes within this distance in adjacent frames will by connected - with a candidate edge. - attributes (tuple[EdgeAttr, ...], optional): Set of attributes to compute and - add to graph. Defaults to (EdgeAttr.DISTANCE,). - position_keys (tuple[str, ...], optional): What the position dimensions of nodes - in the candidate graph are labeled. Defaults to ("y", "x"). - frame_key (str, optional): The label of the time dimension in the candidate - graph. Defaults to "t". - node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames - to node ids. If not provided, it will be computed from cand_graph. Defaults - to None. - segmentation (np.ndarray, optional): The segmentation array for optionally - computing attributes such as IOU. Defaults to None. - """ - print("Extracting candidate edges") - if not node_frame_dict: - node_frame_dict = _compute_node_frame_dict(cand_graph, frame_key=frame_key) - - frames = sorted(node_frame_dict.keys()) - for frame in tqdm(frames): - if frame + 1 not in node_frame_dict: - continue - next_nodes = node_frame_dict[frame + 1] - next_locs = [ - _get_location(cand_graph.nodes[n], position_keys=position_keys) - for n in next_nodes - ] - if EdgeAttr.IOU in attributes: - if segmentation is None: - raise ValueError("Can't compute IOU without segmentation.") - ious = compute_ious(segmentation[frame], segmentation[frame + 1]) - for node in node_frame_dict[frame]: - loc = _get_location(cand_graph.nodes[node], position_keys=position_keys) - for next_id, next_loc in zip(next_nodes, next_locs): - dist = math.dist(next_loc, loc) - if dist <= max_edge_distance: - attrs = {} - if EdgeAttr.DISTANCE in attributes: - attrs[EdgeAttr.DISTANCE.value] = dist - if EdgeAttr.IOU in attributes: - node_seg_id = cand_graph.nodes[node][NodeAttr.SEG_ID.value] - next_seg_id = cand_graph.nodes[next_id][NodeAttr.SEG_ID.value] - attrs[EdgeAttr.IOU.value] = ious.get(node_seg_id, {}).get( - next_seg_id, 0 - ) - cand_graph.add_edge(node, next_id, **attrs) - - -def compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> dict[int, dict[int, float]]: - """Compute label IOUs between two label arrays of the same shape. Ignores background - (label 0). - - Args: - frame1 (np.ndarray): Array with integer labels - frame2 (np.ndarray): Array with integer labels - - Returns: - dict[int, dict[int, float]]: Dictionary from labels in frame 1 to labels in - frame 2 to iou values. Nodes that have no overlap are not included. - """ - frame1 = frame1.flatten() - frame2 = frame2.flatten() - # get indices where both are not zero (ignore background) - # this speeds up computation significantly - non_zero_indices = np.logical_and(frame1, frame2) - flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]]) - - values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) - frame1_values, frame1_counts = np.unique(frame1, return_counts=True) - frame1_label_sizes = dict(zip(frame1_values, frame1_counts)) - frame2_values, frame2_counts = np.unique(frame2, return_counts=True) - frame2_label_sizes = dict(zip(frame2_values, frame2_counts)) - iou_dict: dict[int, dict[int, float]] = {} - for index in range(values.shape[1]): - pair = values[:, index] - intersection = counts[index] - id1, id2 = pair - union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection - if id1 not in iou_dict: - iou_dict[id1] = {} - iou_dict[id1][id2] = intersection / union - return iou_dict - - -def graph_from_segmentation( - segmentation: np.ndarray, - max_edge_distance: float, - node_attributes: tuple[NodeAttr, ...] | list[NodeAttr] = (NodeAttr.SEG_ID,), - edge_attributes: tuple[EdgeAttr, ...] | list[EdgeAttr] = (EdgeAttr.DISTANCE,), - position_keys: tuple[str, ...] | list[str] = ("y", "x"), - frame_key: str = "t", -): - """Construct a candidate graph from a segmentation array. Nodes are placed at the - centroid of each segmentation and edges are added for all nodes in adjacent frames - within max_edge_distance. The specified attributes are computed during construction. - Node ids are strings with format "{time}_{label id}". - - Args: - segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels - (0 is background, all pixels with value 1 belong to one cell, etc.). The - time dimension is first, followed by two or three position dimensions. If - the position dims are not (y, x), use `position_keys` to specify the names - of the dimensions. - max_edge_distance (float): Maximum distance that objects can travel between - frames. All nodes within this distance in adjacent frames will by connected - with a candidate edge. - node_attributes (tuple[str, ...] | list[str], optional): Set of attributes to - compute and add to nodes in graph. Valid attributes are: "segmentation_id". - Defaults to ("segmentation_id",). - edge_attributes (tuple[str, ...] | list[str], optional): Set of attributes to - compute and add to edges in graph. Valid attributes are: "distance". - Defaults to ("distance",). - position_keys (tuple[str, ...], optional): What to label the position dimensions - in the candidate graph. The order of the names corresponds to the order of - the dimensions in `segmentation`. Defaults to ("y", "x"). - frame_key (str, optional): What to label the time dimension in the candidate - graph. Defaults to 't'. - - Returns: - nx.DiGraph: A candidate graph that can be passed to the motile solver. - - Raises: - ValueError: if unsupported attribute strings are passed in to the attributes - arguments, or if the number of position keys provided does not match the - number of position dimensions. - """ - if len(position_keys) != segmentation.ndim - 1: - raise ValueError( - f"Position labels {position_keys} does not match number of spatial dims " - f"({segmentation.ndim - 1})" - ) - # add nodes - - cand_graph, node_frame_dict = nodes_from_segmentation( - segmentation, node_attributes, position_keys=position_keys, frame_key=frame_key - ) - logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") - - # add edges - add_cand_edges( - cand_graph, - max_edge_distance=max_edge_distance, - attributes=edge_attributes, - position_keys=position_keys, - node_frame_dict=node_frame_dict, - segmentation=segmentation, - ) - - logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") - return cand_graph diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py new file mode 100644 index 0000000..0f3d4f7 --- /dev/null +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -0,0 +1,115 @@ +from itertools import product +from typing import Any + +import networkx as nx +import numpy as np +from tqdm import tqdm + +from .graph_attributes import EdgeAttr +from .utils import _compute_node_frame_dict, get_node_id + + +def _compute_ious( + frame1: np.ndarray, frame2: np.ndarray +) -> list[tuple[int, int, float]]: + """Compute label IOUs between two label arrays of the same shape. Ignores background + (label 0). + + Args: + frame1 (np.ndarray): Array with integer labels + frame2 (np.ndarray): Array with integer labels + + Returns: + list[tuple[int, int, float]]: List of tuples of label in frame 1, label in + frame 2, and iou values. Labels that have no overlap are not included. + """ + frame1 = frame1.flatten() + frame2 = frame2.flatten() + # get indices where both are not zero (ignore background) + # this speeds up computation significantly + non_zero_indices = np.logical_and(frame1, frame2) + flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]]) + + values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) + frame1_values, frame1_counts = np.unique(frame1, return_counts=True) + frame1_label_sizes = dict(zip(frame1_values, frame1_counts)) + frame2_values, frame2_counts = np.unique(frame2, return_counts=True) + frame2_label_sizes = dict(zip(frame2_values, frame2_counts)) + ious: list[tuple[int, int, float]] = [] + for index in range(values.shape[1]): + pair = values[:, index] + intersection = counts[index] + id1, id2 = pair + union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection + ious.append((id1, id2, intersection / union)) + return ious + + +def _get_iou_dict(segmentation, multihypo=False) -> dict[str, dict[str, float]]: + """Get all ious values for the provided segmentation (all frames). + Will return as map from node_id -> dict[node_id] -> iou for easy + navigation when adding to candidate graph. + + Args: + segmentation (np.ndarray): Segmentation that was used to create cand_graph. + Has shape (t, [h], [z], y, x), where h is the number of hypotheses. + multihypo (bool, optional): Whether or not the segmentation is multi hypothesis. + Defaults to False. + + Returns: + dict[str, dict[str, float]]: A map from node id to another dictionary, which + contains node_ids to iou values. + """ + iou_dict: dict[str, dict[str, float]] = {} + hypo_pairs: list[tuple[int | None, ...]] + if multihypo: + num_hypotheses = segmentation.shape[1] + hypo_pairs = list(product(range(num_hypotheses), repeat=2)) + else: + hypo_pairs = [(None, None)] + + for frame in range(len(segmentation) - 1): + for hypo1, hypo2 in hypo_pairs: + seg1 = segmentation[frame][hypo1] + seg2 = segmentation[frame + 1][hypo2] + ious = _compute_ious(seg1, seg2) + for label1, label2, iou in ious: + node_id1 = get_node_id(frame, label1, hypo1) + if node_id1 not in iou_dict: + iou_dict[node_id1] = {} + node_id2 = get_node_id(frame + 1, label2, hypo2) + iou_dict[node_id1][node_id2] = iou + return iou_dict + + +def add_iou( + cand_graph: nx.DiGraph, + segmentation: np.ndarray, + node_frame_dict: dict[int, list[Any]] | None = None, + multihypo: bool = False, +) -> None: + """Add IOU to the candidate graph. + + Args: + cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated + segmentation (np.ndarray): segmentation that was used to create cand_graph. + Has shape (t, [h], [z], y, x), where h is the number of hypotheses. + node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from + time frames to nodes in that frame. Will be computed if not provided, + but can be provided for efficiency (e.g. after running + nodes_from_segmentation). Defaults to None. + multihypo (bool, optional): Whether the segmentation contains multiple + hypotheses. Defaults to False. + """ + if node_frame_dict is None: + node_frame_dict = _compute_node_frame_dict(cand_graph) + frames = sorted(node_frame_dict.keys()) + ious = _get_iou_dict(segmentation, multihypo=multihypo) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + next_nodes = node_frame_dict[frame + 1] + for node_id in node_frame_dict[frame]: + for next_id in next_nodes: + iou = ious.get(node_id, {}).get(next_id, 0) + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py new file mode 100644 index 0000000..f75870b --- /dev/null +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -0,0 +1,132 @@ +import logging +import math +from typing import Any + +import networkx as nx +import numpy as np +from skimage.measure import regionprops +from tqdm import tqdm + +from .graph_attributes import EdgeAttr, NodeAttr + +logger = logging.getLogger(__name__) + + +def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str: + """Construct a node id given the time frame, segmentation label id, and + optionally the hypothesis id. This function is not designed for candidate graphs + that do not come from segmentations, but could be used if there is a similar + "detection id" that is unique for all cells detected in a given frame. + + Args: + time (int): The time frame the node is in + label_id (int): The label the node has in the segmentation. + hypothesis_id (int | None, optional): An integer representing which hypothesis + the segmentation came from, if applicable. Defaults to None. + + Returns: + str: A string to use as the node id in the candidate graph. Assuming that label + ids are not repeated in the same time frame and hypothesis, it is unique. + """ + if hypothesis_id is not None: + return f"{time}_{hypothesis_id}_{label_id}" + else: + return f"{time}_{label_id}" + + +def nodes_from_segmentation( + segmentation: np.ndarray, hypo_id: int | None = None +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + """Extract candidate nodes from a segmentation. Also computes specified attributes. + Returns a networkx graph with only nodes, and also a dictionary from frames to + node_ids for efficient edge adding. + + Args: + segmentation (np.ndarray): A 3 or 4 dimensional numpy array with integer labels + (0 is background, all pixels with value 1 belong to one cell, etc.). The + time dimension is first, followed by two or three position dimensions. + hypo_id (int | None, optional): An id to identify which layer of the multi- + hypothesis segmentation this is. Used to create node id, and is added + to each node if not None. Defaults to None. + + Returns: + tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, + and a mapping from time frames to node ids. + """ + cand_graph = nx.DiGraph() + # also construct a dictionary from time frame to node_id for efficiency + node_frame_dict = {} + print("Extracting nodes from segmentation") + for t in tqdm(range(len(segmentation))): + nodes_in_frame = [] + props = regionprops(segmentation[t]) + for regionprop in props: + node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id) + attrs = { + NodeAttr.TIME.value: t, + } + attrs[NodeAttr.SEG_ID.value] = regionprop.label + if hypo_id is not None: + attrs[NodeAttr.SEG_HYPO.value] = hypo_id + centroid = regionprop.centroid # [z,] y, x + attrs[NodeAttr.POS.value] = centroid + cand_graph.add_node(node_id, **attrs) + nodes_in_frame.append(node_id) + if nodes_in_frame: + node_frame_dict[t] = nodes_in_frame + return cand_graph, node_frame_dict + + +def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]: + """Compute dictionary from time frames to node ids for candidate graph. + + Args: + cand_graph (nx.DiGraph): A networkx graph + + Returns: + dict[int, list[Any]]: A mapping from time frames to lists of node ids. + """ + node_frame_dict: dict[int, list[Any]] = {} + for node, data in cand_graph.nodes(data=True): + t = data[NodeAttr.TIME.value] + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].append(node) + return node_frame_dict + + +def add_cand_edges( + cand_graph: nx.DiGraph, + max_edge_distance: float, + node_frame_dict: None | dict[int, list[Any]] = None, +) -> None: + """Add candidate edges to a candidate graph by connecting all nodes in adjacent + frames that are closer than max_edge_distance. Also adds attributes to the edges. + + Args: + cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will + be modified in-place to add edges. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes within this distance in adjacent frames will by connected + with a candidate edge. + node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames + to node ids. If not provided, it will be computed from cand_graph. Defaults + to None. + """ + print("Extracting candidate edges") + if not node_frame_dict: + node_frame_dict = _compute_node_frame_dict(cand_graph) + + frames = sorted(node_frame_dict.keys()) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + next_nodes = node_frame_dict[frame + 1] + next_locs = [cand_graph.nodes[n][NodeAttr.POS.value] for n in next_nodes] + for node in node_frame_dict[frame]: + loc = cand_graph.nodes[node][NodeAttr.POS.value] + for next_id, next_loc in zip(next_nodes, next_locs): + dist = math.dist(next_loc, loc) + if dist <= max_edge_distance: + attrs = {EdgeAttr.DISTANCE.value: dist} + cand_graph.add_edge(node, next_id, **attrs) diff --git a/src/motile_toolbox/utils/saving_utils.py b/src/motile_toolbox/utils/saving_utils.py index 6bd98f2..868c755 100644 --- a/src/motile_toolbox/utils/saving_utils.py +++ b/src/motile_toolbox/utils/saving_utils.py @@ -7,7 +7,6 @@ def relabel_segmentation( solution_nx_graph: nx.DiGraph, segmentation: np.ndarray, - frame_key="t", ) -> np.ndarray: """Relabel a segmentation based on tracking results so that nodes in same track share the same id. IDs do change at division. @@ -33,7 +32,7 @@ def relabel_segmentation( soln_copy.remove_edges_from(out_edges) for node_set in nx.weakly_connected_components(soln_copy): for node in node_set: - time_frame = solution_nx_graph.nodes[node][frame_key] + time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value] previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] previous_seg_mask = segmentation[time_frame] == previous_seg_id tracked_masks[time_frame][previous_seg_mask] = id_counter diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a4643a2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,348 @@ +import networkx as nx +import numpy as np +import pytest +from motile_toolbox.candidate_graph.graph_attributes import EdgeAttr, NodeAttr +from skimage.draw import disk + + +@pytest.fixture +def segmentation_2d(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + # second cell centered at (60, 45) with label 2 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 1 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 2 + + return segmentation + + +@pytest.fixture +def multi_hypothesis_segmentation_2d(): + """ + Creates a multi-hypothesis version of the `segmentation_2d` fixture defined above. + + """ + frame_shape = (100, 100) + total_shape = (2, 2, *frame_shape) # 2 time points, 2 hypotheses layers, H, W + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 (hypo 1) + rr0, cc0 = disk(center=(50, 50), radius=20, shape=frame_shape) + # make frame with one cell at (45, 45) with label 1 (hypo 2) + rr1, cc1 = disk(center=(45, 45), radius=15, shape=frame_shape) + + segmentation[0, 0][rr0, cc0] = 1 + segmentation[0, 1][rr1, cc1] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + rr0, cc0 = disk(center=(20, 80), radius=10, shape=frame_shape) + rr1, cc1 = disk(center=(15, 75), radius=15, shape=frame_shape) + + segmentation[1, 0][rr0, cc0] = 1 + segmentation[1, 1][rr1, cc1] = 1 + + # second cell centered at (60, 45) with label 2 + rr0, cc0 = disk(center=(60, 45), radius=15, shape=frame_shape) + rr1, cc1 = disk(center=(55, 40), radius=20, shape=frame_shape) + + segmentation[1, 0][rr0, cc0] = 2 + segmentation[1, 1][rr1, cc1] = 2 + + return segmentation + + +@pytest.fixture +def graph_2d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_1", + { + NodeAttr.POS.value: (50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_1", + { + NodeAttr.POS.value: (20, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_2", + { + NodeAttr.POS.value: (60, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + }, + ), + ] + edges = [ + ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43, EdgeAttr.IOU.value: 0.0}), + ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18, EdgeAttr.IOU.value: 0.395}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +@pytest.fixture +def multi_hypothesis_graph_2d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_0_1", + { + NodeAttr.POS.value: (50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_HYPO.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "0_1_1", + { + NodeAttr.POS.value: (45, 45), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_HYPO.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_1", + { + NodeAttr.POS.value: (20, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_1_1", + { + NodeAttr.POS.value: (15, 75), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_2", + { + NodeAttr.POS.value: (60, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 0, + NodeAttr.SEG_ID.value: 2, + }, + ), + ( + "1_1_2", + { + NodeAttr.POS.value: (55, 40), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 1, + NodeAttr.SEG_ID.value: 2, + }, + ), + ] + + edges = [ + ("0_0_1", "1_0_1", {EdgeAttr.DISTANCE.value: 42.426, EdgeAttr.IOU.value: 0.0}), + ("0_0_1", "1_1_1", {EdgeAttr.DISTANCE.value: 43.011, EdgeAttr.IOU.value: 0.0}), + ( + "0_0_1", + "1_0_2", + {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.3931}, + ), + ( + "0_0_1", + "1_1_2", + {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.4768}, + ), + ("0_1_1", "1_0_1", {EdgeAttr.DISTANCE.value: 43.011, EdgeAttr.IOU.value: 0.0}), + ("0_1_1", "1_1_1", {EdgeAttr.DISTANCE.value: 42.426, EdgeAttr.IOU.value: 0.0}), + ("0_1_1", "1_0_2", {EdgeAttr.DISTANCE.value: 15.0, EdgeAttr.IOU.value: 0.2402}), + ( + "0_1_1", + "1_1_2", + {EdgeAttr.DISTANCE.value: 11.180, EdgeAttr.IOU.value: 0.3931}, + ), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def sphere(center, radius, shape): + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) + mask = distance <= radius + return mask + + +@pytest.fixture +def segmentation_3d(): + frame_shape = (100, 100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0][mask] = 1 + + # make frame with two cells + # first cell centered at (20, 50, 80) with label 1 + # second cell centered at (60, 50, 45) with label 2 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1][mask] = 1 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1][mask] = 2 + + return segmentation + + +@pytest.fixture +def multi_hypothesis_segmentation_3d(): + """ + Creates a multi-hypothesis version of the `segmentation_3d` fixture defined above. + + """ + frame_shape = (100, 100, 100) + total_shape = (2, 2, *frame_shape) # 2 time points, 2 hypotheses + segmentation = np.zeros(total_shape, dtype="int32") + # make first frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0, 0][mask] = 1 + mask = sphere(center=(45, 50, 55), radius=20, shape=frame_shape) + segmentation[0, 1][mask] = 1 + + # make second frame, first hypothesis with two cells + # first cell centered at (20, 50, 80) with label 1 + # second cell centered at (60, 50, 45) with label 2 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1, 0][mask] = 1 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1, 0][mask] = 2 + + # make second frame, second hypothesis with one cell + # first cell centered at (15, 50, 70) with label 1 + # second cell centered at (55, 55, 45) with label 2 + mask = sphere(center=(15, 50, 70), radius=10, shape=frame_shape) + segmentation[1, 1][mask] = 1 + + return segmentation + + +@pytest.fixture +def graph_3d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_1", + { + NodeAttr.POS.value: (50, 50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_1", + { + NodeAttr.POS.value: (20, 50, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_2", + { + NodeAttr.POS.value: (60, 50, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + }, + ), + ] + edges = [ + # math.dist([50, 50], [20, 80]) + ("0_1", "1_1", {EdgeAttr.DISTANCE.value: 42.43}), + # math.dist([50, 50], [60, 45]) + ("0_1", "1_2", {EdgeAttr.DISTANCE.value: 11.18}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +@pytest.fixture +def multi_hypothesis_graph_3d(): + graph = nx.DiGraph() + nodes = [ + ( + "0_0_1", + { + NodeAttr.POS.value: (50, 50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_HYPO.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "0_1_1", + { + NodeAttr.POS.value: (45, 50, 55), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_1", + { + NodeAttr.POS.value: (20, 50, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 0, + NodeAttr.SEG_ID.value: 1, + }, + ), + ( + "1_0_2", + { + NodeAttr.POS.value: (60, 50, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 0, + NodeAttr.SEG_ID.value: 2, + }, + ), + ( + "1_1_1", + { + NodeAttr.POS.value: (15, 50, 70), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_HYPO.value: 1, + NodeAttr.SEG_ID.value: 1, + }, + ), + ] + edges = [ + ("0_0_1", "1_0_1", {EdgeAttr.DISTANCE.value: 42.4264}), + ("0_0_1", "1_0_2", {EdgeAttr.DISTANCE.value: 11.1803}), + ("0_1_1", "1_0_1", {EdgeAttr.DISTANCE.value: 35.3553}), + ("0_1_1", "1_0_2", {EdgeAttr.DISTANCE.value: 18.0277}), + ("0_0_1", "1_1_1", {EdgeAttr.DISTANCE.value: 40.3112}), + ("0_1_1", "1_1_1", {EdgeAttr.DISTANCE.value: 33.5410}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py new file mode 100644 index 0000000..2d4a8a1 --- /dev/null +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -0,0 +1,97 @@ +from collections import Counter + +import pytest +from motile_toolbox.candidate_graph import EdgeAttr, get_candidate_graph + + +def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): + # test with 2D segmentation + cand_graph, _ = get_candidate_graph( + segmentation=segmentation_2d, + max_edge_distance=100, + iou=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) + for edge in cand_graph.edges: + print(cand_graph.edges[edge]) + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) + == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] + ) + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) + == graph_2d.edges[edge][EdgeAttr.IOU.value] + ) + + # lower edge distance + cand_graph, _ = get_candidate_graph( + segmentation=segmentation_2d, + max_edge_distance=15, + ) + assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) + assert cand_graph.edges[("0_1", "1_2")][EdgeAttr.DISTANCE.value] == pytest.approx( + 11.18, abs=0.01 + ) + + +def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): + # test with 3D segmentation + cand_graph, _ = get_candidate_graph( + segmentation=segmentation_3d, + max_edge_distance=100, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] + + +def test_graph_from_multi_segmentation_2d( + multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d +): + # test with 2D segmentation + cand_graph, conflict_set = get_candidate_graph( + segmentation=multi_hypothesis_segmentation_2d, + max_edge_distance=100, + iou=True, + multihypo=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter( + list(multi_hypothesis_graph_2d.nodes) + ) + assert Counter(list(cand_graph.edges)) == Counter( + list(multi_hypothesis_graph_2d.edges) + ) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter( + multi_hypothesis_graph_2d.nodes[node] + ) + for edge in cand_graph.edges: + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) + == multi_hypothesis_graph_2d.edges[edge][EdgeAttr.DISTANCE.value] + ) + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) + == multi_hypothesis_graph_2d.edges[edge][EdgeAttr.IOU.value] + ) + # TODO: Test conflict set + + # lower edge distance + cand_graph, _ = get_candidate_graph( + segmentation=multi_hypothesis_segmentation_2d, + max_edge_distance=14, + multihypo=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter( + list(multi_hypothesis_graph_2d.nodes) + ) + assert Counter(list(cand_graph.edges)) == Counter( + [("0_0_1", "1_0_2"), ("0_0_1", "1_1_2"), ("0_1_1", "1_1_2")] + ) diff --git a/tests/test_candidate_graph/test_conflict_sets.py b/tests/test_candidate_graph/test_conflict_sets.py new file mode 100644 index 0000000..c07e4e8 --- /dev/null +++ b/tests/test_candidate_graph/test_conflict_sets.py @@ -0,0 +1,39 @@ +import numpy as np +from motile_toolbox.candidate_graph.conflict_sets import compute_conflict_sets +from pytest_unordered import unordered + + +def test_conflict_sets_2d(multi_hypothesis_segmentation_2d): + for t in range(multi_hypothesis_segmentation_2d.shape[0]): + conflict_set = compute_conflict_sets(multi_hypothesis_segmentation_2d[t], t) + if t == 0: + expected = [{"0_1_1", "0_0_1"}] + assert len(conflict_set) == 1 + assert conflict_set == unordered(expected) + elif t == 1: + expected = [{"1_0_2", "1_1_2"}, {"1_0_1", "1_1_1"}] + assert len(conflict_set) == 2 + assert conflict_set == unordered(expected) + + +def test_conflict_sets_2d_reshaped(multi_hypothesis_segmentation_2d): + """Reshape segmentation array just to provide a slightly difficult example.""" + + reshaped = np.asarray( + [ + multi_hypothesis_segmentation_2d[0, 0], # hypothesis 0 + multi_hypothesis_segmentation_2d[1, 0], # hypothesis 1 + multi_hypothesis_segmentation_2d[1, 1], + ] + ) # hypothesis 2 + conflict_set = compute_conflict_sets(reshaped, 0) + # note the expected ids are not really there since the + # reshaped array is artifically constructed + expected = [ + {"0_0_1", "0_1_2", "0_2_2"}, + {"0_1_1", "0_2_1"}, + {"0_0_1", "0_1_2"}, + {"0_1_2", "0_2_2"}, + {"0_0_1", "0_2_2"}, + ] + assert conflict_set == unordered(expected) diff --git a/tests/test_candidate_graph/test_graph_from_segmentation.py b/tests/test_candidate_graph/test_graph_from_segmentation.py deleted file mode 100644 index 2b5b7db..0000000 --- a/tests/test_candidate_graph/test_graph_from_segmentation.py +++ /dev/null @@ -1,271 +0,0 @@ -from collections import Counter - -import networkx as nx -import numpy as np -import pytest -from motile_toolbox.candidate_graph import EdgeAttr, NodeAttr -from motile_toolbox.candidate_graph.graph_from_segmentation import ( - add_cand_edges, - compute_ious, - graph_from_segmentation, - nodes_from_segmentation, -) -from skimage.draw import disk - - -@pytest.fixture -def segmentation_2d(): - frame_shape = (100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - segmentation[0][rr, cc] = 1 - - # make frame with two cells - # first cell centered at (20, 80) with label 1 - # second cell centered at (60, 45) with label 2 - rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 1 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) - segmentation[1][rr, cc] = 2 - - return segmentation - - -@pytest.fixture -def graph_2d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 1}), - ("1_2", {"y": 60, "x": 45, "t": 1, "segmentation_id": 2}), - ] - edges = [ - ("0_1", "1_1", {"distance": 42.43, "iou": 0.0}), - ("0_1", "1_2", {"distance": 11.18, "iou": 0.395}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -def sphere(center, radius, shape): - assert len(center) == len(shape) - indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index - distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) - mask = distance <= radius - return mask - - -@pytest.fixture -def segmentation_3d(): - frame_shape = (100, 100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) - segmentation[0][mask] = 1 - - # make frame with two cells - # first cell centered at (20, 50, 80) with label 1 - # second cell centered at (60, 50, 45) with label 2 - mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) - segmentation[1][mask] = 1 - mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) - segmentation[1][mask] = 2 - - return segmentation - - -@pytest.fixture -def graph_3d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}), - ("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}), - ] - edges = [ - # math.dist([50, 50], [20, 80]) - ("0_1", "1_1", {"distance": 42.43}), - # math.dist([50, 50], [60, 45]) - ("0_1", "1_2", {"distance": 11.18}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -# nodes_from_segmentation -def test_nodes_from_segmentation_empty(): - # test with empty segmentation - empty_graph, node_frame_dict = nodes_from_segmentation( - np.zeros((3, 10, 10), dtype="int32") - ) - assert Counter(empty_graph.nodes) == Counter([]) - assert node_frame_dict == {} - - -def test_nodes_from_segmentation_2d(segmentation_2d): - # test with 2D segmentation - node_graph, node_frame_dict = nodes_from_segmentation( - segmentation=segmentation_2d, - ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"]["segmentation_id"] == 1 - assert node_graph.nodes["1_1"]["t"] == 1 - assert node_graph.nodes["1_1"]["y"] == 20 - assert node_graph.nodes["1_1"]["x"] == 80 - - assert node_frame_dict[0] == ["0_1"] - assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) - - # remove attrs - node_graph, _ = nodes_from_segmentation( - segmentation=segmentation_2d, - attributes=[], - ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert "segmentation_id" not in node_graph.nodes["0_1"] - - -def test_nodes_from_segmentation_3d(segmentation_3d): - # test with 3D segmentation - node_graph, node_frame_dict = nodes_from_segmentation( - segmentation=segmentation_3d, - attributes=[NodeAttr.SEG_ID], - position_keys=("pos_z", "pos_y", "pos_x"), - ) - assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert node_graph.nodes["1_1"]["segmentation_id"] == 1 - assert node_graph.nodes["1_1"]["t"] == 1 - assert node_graph.nodes["1_1"]["pos_z"] == 20 - assert node_graph.nodes["1_1"]["pos_y"] == 50 - assert node_graph.nodes["1_1"]["pos_x"] == 80 - - assert node_frame_dict[0] == ["0_1"] - assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) - - -# add_cand_edges -def test_add_cand_edges_2d(graph_2d): - cand_graph = nx.create_empty_copy(graph_2d) - add_cand_edges(cand_graph, max_edge_distance=50) - assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) - for edge in cand_graph.edges: - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) - == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] - ) - - -def test_add_cand_edges_3d(graph_3d): - cand_graph = nx.create_empty_copy(graph_3d) - add_cand_edges(cand_graph, max_edge_distance=15, position_keys=("z", "y", "x")) - graph_3d.remove_edge("0_1", "1_1") - assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) - for edge in cand_graph.edges: - assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] - - -# graph_from_segmentation -def test_graph_from_segmentation_invalid(): - # test invalid attributes - with pytest.raises(ValueError): - graph_from_segmentation( - np.zeros((3, 10, 10, 10), dtype="int32"), - 10, - edge_attributes=["invalid"], - ) - with pytest.raises(ValueError): - graph_from_segmentation( - np.zeros((3, 10, 10, 10), dtype="int32"), - 10, - node_attributes=["invalid"], - ) - - with pytest.raises(ValueError): - graph_from_segmentation( - np.zeros((3, 10, 10), dtype="int32"), 100, position_keys=["z", "y", "x"] - ) - - -def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): - # test with 2D segmentation - cand_graph = graph_from_segmentation( - segmentation=segmentation_2d, - max_edge_distance=100, - edge_attributes=[EdgeAttr.DISTANCE, EdgeAttr.IOU], - ) - assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) - assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) - for node in cand_graph.nodes: - assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) - for edge in cand_graph.edges: - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) - == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] - ) - assert ( - pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) - == graph_2d.edges[edge][EdgeAttr.IOU.value] - ) - - # lower edge distance - cand_graph = graph_from_segmentation( - segmentation=segmentation_2d, - max_edge_distance=15, - ) - assert Counter(list(cand_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) - assert Counter(list(cand_graph.edges)) == Counter([("0_1", "1_2")]) - assert cand_graph.edges[("0_1", "1_2")]["distance"] == pytest.approx( - 11.18, abs=0.01 - ) - - -def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): - # test with 3D segmentation - cand_graph = graph_from_segmentation( - segmentation=segmentation_3d, - max_edge_distance=100, - position_keys=("z", "y", "x"), - ) - assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) - assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) - for node in cand_graph.nodes: - assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) - for edge in cand_graph.edges: - assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] - - -def test_compute_ious_2d(segmentation_2d): - ious = compute_ious(segmentation_2d[0], segmentation_2d[1]) - expected = {1: {2: 555.46 / 1408.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) - - ious = compute_ious(segmentation_2d[1], segmentation_2d[1]) - expected = {1: {1: 1.0}, 2: {2: 1.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) - assert ious[2].keys() == expected[2].keys() - assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) - - -def test_compute_ious_3d(segmentation_3d): - ious = compute_ious(segmentation_3d[0], segmentation_3d[1]) - expected = {1: {2: 0.30}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][2] == pytest.approx(expected[1][2], abs=0.1) - - ious = compute_ious(segmentation_3d[1], segmentation_3d[1]) - expected = {1: {1: 1.0}, 2: {2: 1.0}} - assert ious.keys() == expected.keys() - assert ious[1].keys() == expected[1].keys() - assert ious[1][1] == pytest.approx(expected[1][1], abs=0.1) - assert ious[2].keys() == expected[2].keys() - assert ious[2][2] == pytest.approx(expected[2][2], abs=0.1) diff --git a/tests/test_candidate_graph/test_graph_to_nx.py b/tests/test_candidate_graph/test_graph_to_nx.py index b52d864..3d08e89 100644 --- a/tests/test_candidate_graph/test_graph_to_nx.py +++ b/tests/test_candidate_graph/test_graph_to_nx.py @@ -1,30 +1,10 @@ import networkx as nx -import pytest from motile import TrackGraph from motile_toolbox.candidate_graph import graph_to_nx from networkx.utils import graphs_equal -@pytest.fixture -def graph_3d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {"z": 50, "y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"z": 20, "y": 50, "x": 80, "t": 1, "segmentation_id": 1}), - ("1_2", {"z": 60, "y": 50, "x": 45, "t": 1, "segmentation_id": 2}), - ] - edges = [ - # math.dist([50, 50], [20, 80]) - ("0_1", "1_1", {"distance": 42.43}), - # math.dist([50, 50], [60, 45]) - ("0_1", "1_2", {"distance": 11.18}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - def test_graph_to_nx(graph_3d: nx.DiGraph): - track_graph = TrackGraph(nx_graph=graph_3d, frame_attribute="t") + track_graph = TrackGraph(nx_graph=graph_3d, frame_attribute="time") nx_graph = graph_to_nx(track_graph) assert graphs_equal(graph_3d, nx_graph) diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py new file mode 100644 index 0000000..d5d88d4 --- /dev/null +++ b/tests/test_candidate_graph/test_iou.py @@ -0,0 +1,55 @@ +import networkx as nx +import pytest +from motile_toolbox.candidate_graph import EdgeAttr, add_iou +from motile_toolbox.candidate_graph.iou import _compute_ious + + +def test_compute_ious_2d(segmentation_2d): + ious = _compute_ious(segmentation_2d[0], segmentation_2d[1]) + expected = [ + (1, 2, 555.46 / 1408.0), + ] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) + + ious = _compute_ious(segmentation_2d[1], segmentation_2d[1]) + expected = [(1, 1, 1.0), (2, 2, 1.0)] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) + + +def test_compute_ious_3d(segmentation_3d): + ious = _compute_ious(segmentation_3d[0], segmentation_3d[1]) + expected = [(1, 2, 0.30)] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) + + ious = _compute_ious(segmentation_3d[1], segmentation_3d[1]) + expected = [(1, 1, 1.0), (2, 2, 1.0)] + for iou, expected_iou in zip(ious, expected): + assert iou == pytest.approx(expected_iou, abs=0.01) + + +def test_add_iou_2d(segmentation_2d, graph_2d): + expected = graph_2d + input_graph = graph_2d.copy() + nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) + add_iou(input_graph, segmentation_2d) + for s, t, attrs in expected.edges(data=True): + assert ( + pytest.approx(attrs[EdgeAttr.IOU.value], abs=0.01) + == input_graph.edges[(s, t)][EdgeAttr.IOU.value] + ) + + +def test_multi_hypo_iou_2d(multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d): + expected = multi_hypothesis_graph_2d + input_graph = multi_hypothesis_graph_2d.copy() + nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) + add_iou(input_graph, multi_hypothesis_segmentation_2d, multihypo=True) + for s, t, attrs in expected.edges(data=True): + print(s, t) + assert ( + pytest.approx(attrs[EdgeAttr.IOU.value], abs=0.01) + == input_graph.edges[(s, t)][EdgeAttr.IOU.value] + ) diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py new file mode 100644 index 0000000..46f1832 --- /dev/null +++ b/tests/test_candidate_graph/test_utils.py @@ -0,0 +1,103 @@ +from collections import Counter + +import networkx as nx +import numpy as np +import pytest +from motile_toolbox.candidate_graph import ( + EdgeAttr, + NodeAttr, + add_cand_edges, + get_node_id, + nodes_from_segmentation, +) +from motile_toolbox.candidate_graph.utils import _compute_node_frame_dict + + +# nodes_from_segmentation +def test_nodes_from_segmentation_empty(): + # test with empty segmentation + empty_graph, node_frame_dict = nodes_from_segmentation( + np.zeros((3, 10, 10), dtype="int32") + ) + assert Counter(empty_graph.nodes) == Counter([]) + assert node_frame_dict == {} + + +def test_nodes_from_segmentation_2d(segmentation_2d): + # test with 2D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 80) + + assert node_frame_dict[0] == ["0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + + +def test_nodes_from_segmentation_2d_hypo(segmentation_2d): + # test with 2D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, hypo_id=0 + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_0_1", "1_0_1", "1_0_2"]) + assert node_graph.nodes["1_0_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_0_1"][NodeAttr.SEG_HYPO.value] == 0 + assert node_graph.nodes["1_0_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_0_1"][NodeAttr.POS.value] == (20, 80) + + assert node_frame_dict[0] == ["0_0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_0_1", "1_0_2"]) + + +def test_nodes_from_segmentation_3d(segmentation_3d): + # test with 3D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_3d, + ) + assert Counter(list(node_graph.nodes)) == Counter(["0_1", "1_1", "1_2"]) + assert node_graph.nodes["1_1"][NodeAttr.SEG_ID.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.TIME.value] == 1 + assert node_graph.nodes["1_1"][NodeAttr.POS.value] == (20, 50, 80) + + assert node_frame_dict[0] == ["0_1"] + assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) + + +# add_cand_edges +def test_add_cand_edges_2d(graph_2d): + cand_graph = nx.create_empty_copy(graph_2d) + add_cand_edges(cand_graph, max_edge_distance=50) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + for edge in cand_graph.edges: + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.DISTANCE.value], abs=0.01) + == graph_2d.edges[edge][EdgeAttr.DISTANCE.value] + ) + + +def test_add_cand_edges_3d(graph_3d): + cand_graph = nx.create_empty_copy(graph_3d) + add_cand_edges(cand_graph, max_edge_distance=15) + graph_3d.remove_edge("0_1", "1_1") + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] + + +def test_get_node_id(): + assert get_node_id(0, 2) == "0_2" + assert get_node_id(2, 10, 3) == "2_3_10" + + +def test_compute_node_frame_dict(graph_2d): + node_frame_dict = _compute_node_frame_dict(graph_2d) + expected = { + 0: [ + "0_1", + ], + 1: ["1_1", "1_2"], + } + assert node_frame_dict == expected diff --git a/tests/test_utils/test_saving_utils.py b/tests/test_utils/test_saving_utils.py index c4ff2ac..57d796f 100644 --- a/tests/test_utils/test_saving_utils.py +++ b/tests/test_utils/test_saving_utils.py @@ -1,46 +1,9 @@ -import networkx as nx import numpy as np -import pytest from motile_toolbox.utils import relabel_segmentation from numpy.testing import assert_array_equal from skimage.draw import disk -@pytest.fixture -def segmentation_2d(): - frame_shape = (100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - segmentation[0][rr, cc] = 1 - - # make frame with two cells - # first cell centered at (20, 80) with label 2 - # second cell centered at (60, 45) with label 3 - rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 2 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) - segmentation[1][rr, cc] = 3 - - return segmentation - - -@pytest.fixture -def graph_2d(): - graph = nx.DiGraph() - nodes = [ - ("0_1", {"y": 50, "x": 50, "t": 0, "segmentation_id": 1}), - ("1_1", {"y": 20, "x": 80, "t": 1, "segmentation_id": 2}), - ] - edges = [ - ("0_1", "1_1", {"distance": 42.43}), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - def test_relabel_segmentation(segmentation_2d, graph_2d): frame_shape = (100, 100) expected = np.zeros(segmentation_2d.shape, dtype="int32") @@ -52,6 +15,7 @@ def test_relabel_segmentation(segmentation_2d, graph_2d): rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) expected[1][rr, cc] = 1 + graph_2d.remove_node("1_2") relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") print(f"Nonzero expected: {np.count_nonzero(expected)}")