diff --git a/swcgeom/transforms/__init__.py b/swcgeom/transforms/__init__.py index 51f3fc3..6f07983 100644 --- a/swcgeom/transforms/__init__.py +++ b/swcgeom/transforms/__init__.py @@ -2,6 +2,7 @@ from swcgeom.transforms.base import * from swcgeom.transforms.branch import * +from swcgeom.transforms.branch_tree import * from swcgeom.transforms.geometry import * from swcgeom.transforms.image_preprocess import * from swcgeom.transforms.image_stack import * diff --git a/swcgeom/transforms/branch.py b/swcgeom/transforms/branch.py index 1d6c1a3..74a8cd0 100644 --- a/swcgeom/transforms/branch.py +++ b/swcgeom/transforms/branch.py @@ -76,6 +76,57 @@ def extra_repr(self) -> str: return f"n_nodes={self.n_nodes}" +class BranchIsometricResampler(_BranchResampler): + def __init__(self, distance: float, *, adjust_last_gap: bool = True) -> None: + super().__init__() + self.distance = distance + self.adjust_last_gap = adjust_last_gap + + def resample(self, xyzr: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + """Resampling by isometric interpolation, DO NOT keep original node. + + Parameters + ---------- + xyzr : np.ndarray[np.float32] + The array of shape (N, 4). + + Returns + ------- + new_xyzr : ~numpy.NDArray[float32] + An array of shape (n_nodes, 4). + """ + + # Compute the cumulative distances between consecutive points + diffs = np.diff(xyzr[:, :3], axis=0) + distances = np.sqrt((diffs**2).sum(axis=1)) + cumulative_distances = np.concatenate([[0], np.cumsum(distances)]) + + total_length = cumulative_distances[-1] + n_nodes = int(np.ceil(total_length / self.distance)) + 1 + + # Determine the new distances + if self.adjust_last_gap and n_nodes > 1: + new_distances = np.linspace(0, total_length, n_nodes) + else: + new_distances = np.arange(0, total_length, self.distance) + # keep endpoint + new_distances = np.concatenate([new_distances, total_length]) + + # Interpolate the new points + new_xyzr = np.zeros((n_nodes, 4), dtype=np.float32) + new_xyzr[:, :3] = np.array( + [ + np.interp(new_distances, cumulative_distances, xyzr[:, i]) + for i in range(3) + ] + ).T + new_xyzr[:, 3] = np.interp(new_distances, cumulative_distances, xyzr[:, 3]) + return new_xyzr + + def extra_repr(self) -> str: + return f"distance={self.distance},adjust_last_gap={self.adjust_last_gap}" + + class BranchConvSmoother(Transform[Branch, Branch[DictSWC]]): r"""Smooth the branch by sliding window.""" @@ -88,14 +139,14 @@ def __init__(self, n_nodes: int = 5) -> None: """ super().__init__() self.n_nodes = n_nodes - self.kernal = np.ones(n_nodes) + self.kernel = np.ones(n_nodes) def __call__(self, x: Branch) -> Branch[DictSWC]: x = x.detach() - c = signal.convolve(np.ones(x.number_of_nodes()), self.kernal, mode="same") + c = signal.convolve(np.ones(x.number_of_nodes()), self.kernel, mode="same") for k in ["x", "y", "z"]: v = x.get_ndata(k) - s = signal.convolve(v, self.kernal, mode="same") + s = signal.convolve(v, self.kernel, mode="same") x.attach.ndata[k][1:-1] = (s / c)[1:-1] return x @@ -105,7 +156,7 @@ def extra_repr(self) -> str: class BranchStandardizer(Transform[Branch, Branch[DictSWC]]): - r"""Standarize branch. + r"""Standardize branch. Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y, and scale max radius to 1. @@ -123,7 +174,7 @@ def __call__(self, x: Branch) -> Branch: @staticmethod def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: - r"""Get standarize transformation matrix. + r"""Get standardize transformation matrix. Standardized branch starts at (0, 0, 0), ends at (1, 0, 0), up at y. @@ -136,7 +187,7 @@ def get_matrix(xyz: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: Returns ------- T : np.ndarray[np.float32] - An homogeneous transfomation matrix of shape (4, 4). + An homogeneous transformation matrix of shape (4, 4). """ assert ( diff --git a/swcgeom/transforms/branch_tree.py b/swcgeom/transforms/branch_tree.py new file mode 100644 index 0000000..4c18117 --- /dev/null +++ b/swcgeom/transforms/branch_tree.py @@ -0,0 +1,67 @@ +from typing import Iterable + +import numpy as np + +from swcgeom.core import Branch, BranchTree, Node, Tree +from swcgeom.transforms.base import Transform + +__all__ = ["BranchTreeAssembler"] + + +class BranchTreeAssembler(Transform[BranchTree, Tree]): + EPS = 1e-6 + + def __call__(self, x: BranchTree) -> Tree: + nodes = [x.soma().detach()] + stack = [(x.soma(), 0)] # n_orig, id_new + while len(stack): + n_orig, pid_new = stack.pop() + children = n_orig.children() + + for br, c in self.pair(x.branches.get(n_orig.id, []), children): + s = 1 if np.linalg.norm(br[0].xyz() - n_orig.xyz()) < self.EPS else 0 + e = -2 if np.linalg.norm(br[-1].xyz() - c.xyz()) < self.EPS else -1 + + br_nodes = [n.detach() for n in br[s:e]] + [c.detach()] + for i, n in enumerate(br_nodes): + # reindex + n.id = len(nodes) + i + n.pid = len(nodes) + i - 1 + + br_nodes[0].pid = pid_new + nodes.extend(br_nodes) + stack.append((c, br_nodes[-1].id)) + + return Tree( + len(nodes), + source=x.source, + comments=x.comments, + names=x.names, + **{ + k: np.array([n.__getattribute__(k) for n in nodes]) + for k in x.names.cols() + }, + ) + + def pair( + self, branches: list[Branch], endpoints: list[Node] + ) -> Iterable[tuple[Branch, Node]]: + assert len(branches) == len(endpoints) + xyz1 = [br[-1].xyz() for br in branches] + xyz2 = [n.xyz() for n in endpoints] + v = np.reshape(xyz1, (-1, 1, 3)) - np.reshape(xyz2, (1, -1, 3)) + dis = np.linalg.norm(v, axis=-1) + + # greedy algorithm + pairs = [] + for _ in range(len(branches)): + # find minimal + min_idx = np.argmin(dis) + min_branch_idx, min_endpoint_idx = np.unravel_index(min_idx, dis.shape) + pairs.append((branches[min_branch_idx], endpoints[min_endpoint_idx])) + + # remove current node + dis[min_branch_idx, :] = np.inf + dis[:, min_endpoint_idx] = np.inf + + return pairs diff --git a/swcgeom/transforms/tree.py b/swcgeom/transforms/tree.py index e04e6ed..b98eda7 100644 --- a/swcgeom/transforms/tree.py +++ b/swcgeom/transforms/tree.py @@ -6,10 +6,11 @@ import numpy as np from typing_extensions import deprecated -from swcgeom.core import BranchTree, DictSWC, Path, Tree, cut_tree, to_subtree +from swcgeom.core import Branch, BranchTree, DictSWC, Path, Tree, cut_tree, to_subtree from swcgeom.core.swc_utils import SWCTypes, get_types from swcgeom.transforms.base import Transform -from swcgeom.transforms.branch import BranchConvSmoother +from swcgeom.transforms.branch import BranchConvSmoother, BranchIsometricResampler +from swcgeom.transforms.branch_tree import BranchTreeAssembler from swcgeom.transforms.geometry import Normalizer __all__ = [ @@ -22,6 +23,7 @@ "CutDendriteTree", "CutByFurcationOrder", "CutShortTipBranch", + "IsometricResampler", ] @@ -227,3 +229,27 @@ def _leave( cb(br) return None + + +class Resampler(Transform[Tree, Tree]): + def __init__(self, branch_resampler: Transform[Branch, Branch]) -> None: + super().__init__() + self.resampler = branch_resampler + self.assembler = BranchTreeAssembler() + + def __call__(self, x: Tree) -> Tree: + t = BranchTree.from_tree(x) + t.branches = { + k: [self.resampler(br) for br in brs] for k, brs in t.branches.items() + } + return self.assembler(t) + + +class IsometricResampler(Resampler): + def __init__( + self, distance: float, *, adjust_last_gap: bool = True, **kwargs + ) -> None: + branch_resampler = BranchIsometricResampler( + distance, adjust_last_gap=adjust_last_gap, **kwargs + ) + super().__init__(branch_resampler)