Skip to content

Commit

Permalink
Take into account existing points when adding new trunks with no give…
Browse files Browse the repository at this point in the history
…n orientation
  • Loading branch information
adrien-berchet committed Nov 23, 2021
1 parent a384c1c commit 14e11b2
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 32 deletions.
21 changes: 17 additions & 4 deletions neurots/generate/grower.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def _convert_orientation2points(self, orientation, n_trees, distr, params):
* None: creates a list of orientations according to the biological distributions.
* 'from_space': generates orientations depending on spatial input (not implemented yet).
"""
# pylint: disable=too-many-locals
if isinstance(orientation, list): # Gets major orientations externally
assert np.all(
np.linalg.norm(orientation, axis=1) > 0
Expand Down Expand Up @@ -240,10 +241,22 @@ def _convert_orientation2points(self, orientation, n_trees, distr, params):
else:
raise ValueError("Not enough orientation points!")
elif orientation is None: # Samples from trunk_angles
trunk_angles = sample.trunk_angles(distr, n_trees, self._rng)
trunk_z = sample.azimuth_angles(distr, n_trees, self._rng)
phis, thetas = _oris.trunk_to_spherical_angles(trunk_angles, trunk_z)
orientations = _oris.spherical_angles_to_orientations(phis, thetas)
phi_intervals, interval_n_trees = _oris.compute_interval_n_tree(
self.soma_grower.soma,
n_trees,
self._rng,
)

# Create trunks in each interval
orientations_i = []
for phi_interval, i_n_trees in zip(phi_intervals, interval_n_trees):
phis, thetas = _oris.trunk_to_spherical_angles(
sample.trunk_angles(distr, i_n_trees, self._rng),
sample.azimuth_angles(distr, i_n_trees, self._rng),
phi_interval,
)
orientations_i.append(_oris.spherical_angles_to_orientations(phis, thetas))
orientations = np.vstack(orientations_i)

elif orientation == "from_space":
raise ValueError("Not implemented yet!")
Expand Down
91 changes: 80 additions & 11 deletions neurots/generate/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from neurots.morphmath.utils import normalize_vectors
from neurots.utils import NeuroTSError

_TWOPI = 2.0 * np.pi


class OrientationManagerBase:
"""Base class that automatically registers orientation modes.
Expand Down Expand Up @@ -157,12 +159,22 @@ def _mode_sample_pairwise_angles(self, _, tree_type):

n_orientations = sample.n_neurites(tree_type_distrs["num_trees"], self._rng)

phis, thetas = trunk_to_spherical_angles(
trunk_angles=sample.trunk_angles(tree_type_distrs, n_orientations, self._rng),
z_angles=sample.azimuth_angles(tree_type_distrs, n_orientations, self._rng),
phi_intervals, interval_n_trees = compute_interval_n_tree(
self._soma,
n_orientations,
self._rng,
)

return spherical_angles_to_orientations(phis, thetas)
# Create trunks in each interval
orientations_i = []
for phi_interval, i_n_trees in zip(phi_intervals, interval_n_trees):
phis, thetas = trunk_to_spherical_angles(
sample.trunk_angles(tree_type_distrs, i_n_trees, self._rng),
sample.azimuth_angles(tree_type_distrs, i_n_trees, self._rng),
phi_interval,
)
orientations_i.append(spherical_angles_to_orientations(phis, thetas))
return np.vstack(orientations_i)


def spherical_angles_to_orientations(phis, thetas):
Expand Down Expand Up @@ -208,15 +220,24 @@ def orientations_to_sphere_points(oris, sphere_center, sphere_radius):
return sphere_center + oris * sphere_radius


def trunk_to_spherical_angles(trunk_angles, z_angles):
def trunk_to_spherical_angles(trunk_angles, z_angles, phi_interval=None):
"""Generate spherical angles from a list of NeuroM angles.
trunk_angles correspond to the angles on the x-y plane,
while z_angles correspond to the equivalent z-direction.
trunk angles correspond to polar angles, phi
trunk_angles correspond to polar angles, phi
z_angles correspond to azimuthal angles, theta
"""
if phi_interval is None:
phi_interval = (0.0, _TWOPI)

nb_intervals_min = 0
else:
assert len(phi_interval) == 2, "'phi_interval' must be a sequence of 2 elements."
assert phi_interval[0] < phi_interval[1], "'phi_interval' must be sorted ascending."

# Add 1 so the equiangle is computed such that angles are not equal to any boundary of the
# given interval
nb_intervals_min = 1

trunk_angles = np.asarray(trunk_angles)
z_angles = np.asarray(z_angles)

Expand All @@ -227,8 +248,8 @@ def trunk_to_spherical_angles(trunk_angles, z_angles):

thetas = z_angles[sorted_ids]

equiangle = 2.0 * np.pi / n_angles
phis = np.arange(1, n_angles + 1) * equiangle + sorted_phi_devs
equiangle = (phi_interval[1] - phi_interval[0]) / (n_angles + nb_intervals_min)
phis = np.arange(1, n_angles + 1) * equiangle + sorted_phi_devs + phi_interval[0]

return phis, thetas

Expand All @@ -247,3 +268,51 @@ def trunk_absolute_orientation_to_spherical_angles(orientation, trunk_absolute_a
thetas = theta + sorted_thetas - 0.5 * np.pi

return phis, thetas


def compute_interval_n_tree(soma, n_trees, _rng=np.random):
"""Compute the number of trunks to add between each pair of consecutive existing trunks.
If points already exist in the soma, the algorithm is the following:
- build the intervals between each pair of consecutive points.
- compute the size of each interval.
- randomly select the interval in which each new point will be added (the intervals are
weighted by their sizes to ensure the new trunks are created isotropically).
- count the number of new points in each interval.
- return the intervals in which at least one point must be added.
If no point exists in the soma, the interval [0, 2pi] contains all the new trunks.
"""
if soma and len(soma.points) > 0:
# Get angles of existing trunk origins
phis = []
for pt in soma.points:
pt_orientation = soma.orientation_from_point(pt)
phi, _ = rotation.spherical_from_vector(pt_orientation)
phis.append(phi)

phis = sorted(phis)

# The last interval goes beyond 2 * pi but the function
# self.soma.add_points_from_trunk_angles can deal with it.
phis += [phis[0] + _TWOPI]
phi_intervals = np.column_stack((phis[:-1], phis[1:]))

# Compute the number of trunks to create in each interval: each interval is weighted by
# its size to ensure the new trunks are created isotropically.
sizes = phi_intervals[:, 1] - phi_intervals[:, 0]
interval_i, interval_i_n_trees = np.unique(
_rng.choice(range(len(sizes)), size=n_trees, p=sizes / sizes.sum()), return_counts=True
)

# Keep only intervals with n_trees > 0
phi_intervals = phi_intervals[interval_i].tolist()
interval_n_trees = interval_i_n_trees

else:
# If there is no existing trunk, we create a None interval
phi_intervals = [None]
interval_n_trees = np.array([n_trees])

return phi_intervals, interval_n_trees
9 changes: 4 additions & 5 deletions neurots/generate/soma.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,15 @@ def __init__(self, soma, context=None, rng=np.random):
self.context = context # for future, hypothetical use
self._rng = rng

def add_points_from_trunk_angles(self, trunk_angles, z_angles):
def add_points_from_trunk_angles(self, trunk_angles, z_angles, phi_interval=None):
"""Generate points on the soma surface from a list of angles.
trunk_angles correspond to the angles on the x-y plane,
while z_angles correspond to the equivalent z-direction.
trunk angles correspond to polar angles, phi
z_angles correspond to azimuthal angles, theta
phi_interval correspond to the interval in which the trunk angles must fit (the values can
be in [-inf, inf] as a modulo '2 pi' is applied internaly)
"""
phis, thetas = orientations.trunk_to_spherical_angles(trunk_angles, z_angles)
phis, thetas = orientations.trunk_to_spherical_angles(trunk_angles, z_angles, phi_interval)
new_directions = orientations.spherical_angles_to_orientations(phis, thetas)
return self.add_points_from_orientations(new_directions)

Expand Down
5 changes: 3 additions & 2 deletions neurots/morphmath/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import numpy as np

from neurots.morphmath.utils import norm


def spherical_from_vector(vect):
"""Return the spherical coordinates of a vector: phi, theta."""
x, y, z = vect

phi = np.arctan2(y, x)
# pylint: disable=assignment-from-no-return
theta = np.arccos(z / np.linalg.norm(vect))
theta = np.arccos(z / norm(vect))

return phi, theta

Expand Down
4 changes: 1 addition & 3 deletions neurots/morphmath/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Util functions useful for general purposes."""

from math import sqrt

import numpy as np

# TODO: use KDTree when python3.6 is dropped and scipy>=1.6 is available
Expand Down Expand Up @@ -32,7 +30,7 @@ def get_random_point(D=1.0, random_generator=np.random):

def norm(vector):
"""Return the norm of the numpy array."""
return sqrt(vector.dot(vector))
return np.sqrt(vector.dot(vector))


def normalize_inplace(vector):
Expand Down
105 changes: 105 additions & 0 deletions tests/test_extract_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import neurom
import numpy as np
import pytest
import tmd
from neurom import load_morphologies
from neurom import stats
from numpy.testing import assert_array_almost_equal
Expand Down Expand Up @@ -515,3 +516,107 @@ def test_parameters():
method="trunk",
diameter_parameters=object(),
)


def test_from_TMD():
files = sorted([os.path.join(POP_PATH, neuron_dir) for neuron_dir in os.listdir(POP_PATH)])
pop = tmd.io.load_population(files)
angles = extract_input.from_TMD.persistent_homology_angles(pop, neurite_type="basal")
expected = [
[
[39.8034782, 22.0587348, 0.2204977, -0.0031118, -0.7250693, 0.4353605],
[47.0833969, 16.8049907, 0.4378265, 0.5057776, -1.4803076, 0.4328786],
[142.8110504, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[31.5143489, 13.1089029, -0.0196220, 0.0082083, -0.8850177, 0.0621305],
[31.6991310, 4.0347867, 0.5645827, -1.2206746, -1.958114, 0.3290293],
[37.1221504, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[80.1259918, 45.2574195, 0.5985462, -0.7937724, -1.4556563, 0.2089598],
[138.8324737, 137.2655334, 0.7664448, -0.5276853, -1.5705068, -0.1490920],
[129.4999694, 58.3956718, -0.1864954, 0.2471917, -0.1825209, -0.8895693],
[116.7485580, 29.0435256, 0.1938211, 0.8495636, -0.6798108, -0.3523395],
[116.3598709, 4.6591954, -0.0381157, -0.5630145, -0.9937219, -0.0060555],
[140.7718811, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[108.1560363, 24.4956016, 0.0795795, -0.2466572, -0.818391, 0.4079079],
[84.4512786, 73.8114471, -0.0682486, -0.1316394, 0.7043639, 0.3319016],
[41.1267166, 22.6543693, 0.3379619, -0.3496508, 0.2551283, 0.1817587],
[58.6278800, 55.9588279, 0.1952034, -0.6733589, 0.0355827, 0.2554649],
[67.1058807, 67.8727340, -0.0696913, -0.1445076, 4.92025, -0.8872188],
[78.9355468, 16.0958881, -0.3179875, 0.5255199, -1.5740335, -0.2523424],
[92.3746414, 50.6163444, 0.0482639, 0.1450342, 0.8074381, -0.2487332],
[88.7393646, 14.1915130, -0.1820048, 0.6027378, 0.7695344, 0.4431591],
[107.2000274, 13.7320985, 0.4757962, -1.473934, -1.7596798, 1.2384288],
[157.3328552, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[124.3847961, 78.5777359, -0.2851194, 0.1851800, 0.7059138, -0.4451136],
[67.1271133, 46.9302597, -0.0895297, 1.0650232, 6.1442575, -1.1252722],
[50.9799156, 27.6712512, -5.344535, 0.164577, -0.0902247, 0.6043591],
[147.5868072, 45.4994354, 0.0477559, -0.3307448, 0.5073497, -0.0292972],
[125.9251632, 32.7475662, -0.3614119, 0.3678546, -0.4462253, -0.5532181],
[119.1635360, 18.5245800, -0.0425105, 0.3792363, 0.3548733, -0.3755915],
[115.0881805, 26.0076904, -0.1438069, -0.4749578, 5.3513756, -1.0207773],
[44.8667869, 20.1308231, -0.5238513, 0.4554575, -0.3537357, 0.6082043],
[124.8795852, 9.7746734, 0.1573835, -0.6259746, 0.5335315, 0.3545703],
[159.7980194, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[73.5007705, 44.1673202, -0.2861464, 0.6478886, 0.7622488, -0.6312206],
[187.1019897, 185.1600341, 0.5756854, -0.9131198, 4.854946, -0.0758894],
[111.3322906, 23.4609394, -1.1881468, 1.4903252, -0.7528698, 0.3027869],
[110.6484298, 24.8346958, 0.8869759, -0.5435520, 0.1794759, -1.2358185],
[82.1550140, 12.0827064, 0.5742538, -0.2249454, 4.78082, 0.0218874],
[187.2543487, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[128.7404022, 44.2018013, 0.1672787, 0.3934768, 0.4804887, -1.2018781],
[202.8781890, 18.3915138, -0.4089744, 0.1219186, 5.5452833, 0.4808956],
[219.8619384, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[224.0865325, 10.8910417, 0.0639219, -0.2720557, -0.536548, 0.8312207],
[196.6136322, 2.6748218, -0.5220707, -0.7471831, 3.6071417, 0.5896363],
[265.9921875, 0, np.nan, np.nan, np.nan, np.nan],
],
]
for a, b in zip(angles["persistence_diagram"], expected):
for ai, bi in zip(a, b):
assert_array_almost_equal(ai, bi)

angles = extract_input.from_TMD.persistent_homology_angles(
pop, neurite_type="basal", threshold=9
)
expected = [
[
[108.1560363, 24.4956016, 0.0795795, -0.2466572, -0.818391, 0.4079079],
[84.4512786, 73.8114471, -0.0682486, -0.1316394, 0.7043639, 0.3319016],
[41.1267166, 22.6543693, 0.3379619, -0.3496508, 0.2551283, 0.1817587],
[58.6278800, 55.9588279, 0.1952034, -0.6733589, 0.0355827, 0.2554649],
[67.1058807, 67.8727340, -0.0696913, -0.1445076, 4.92025, -0.8872188],
[78.9355468, 16.0958881, -0.3179875, 0.5255199, -1.5740335, -0.2523424],
[92.3746414, 50.6163444, 0.0482639, 0.1450342, 0.8074381, -0.2487332],
[88.7393646, 14.1915130, -0.1820048, 0.6027378, 0.7695344, 0.4431591],
[107.2000274, 13.7320985, 0.4757962, -1.473934, -1.7596798, 1.2384288],
[157.3328552, 0, np.nan, np.nan, np.nan, np.nan],
],
[
[124.3847961, 78.5777359, -0.2851194, 0.1851800, 0.7059138, -0.4451136],
[67.1271133, 46.9302597, -0.0895297, 1.0650232, 6.1442575, -1.1252722],
[50.9799156, 27.6712512, -5.344535, 0.164577, -0.0902247, 0.6043591],
[147.5868072, 45.4994354, 0.0477559, -0.3307448, 0.5073497, -0.0292972],
[125.9251632, 32.7475662, -0.3614119, 0.3678546, -0.4462253, -0.5532181],
[119.1635360, 18.5245800, -0.0425105, 0.3792363, 0.3548733, -0.3755915],
[115.0881805, 26.0076904, -0.1438069, -0.4749578, 5.3513756, -1.0207773],
[44.8667869, 20.1308231, -0.5238513, 0.4554575, -0.3537357, 0.6082043],
[124.8795852, 9.7746734, 0.1573835, -0.6259746, 0.5335315, 0.3545703],
[159.7980194, 0, np.nan, np.nan, np.nan, np.nan],
],
]
for a, b in zip(angles["persistence_diagram"], expected):
for ai, bi in zip(a, b):
assert_array_almost_equal(ai, bi)
15 changes: 11 additions & 4 deletions tests/test_neuron_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def test_convert_orientation2points():
pts = ng._convert_orientation2points([[0, 1, 0]], 1, distributions["apical"], {})
assert_array_almost_equal(pts, [[0, 15.27995, 0]])

# Test with no existing trunk
ng = NeuronGrower(parameters, distributions)
pts = ng._convert_orientation2points(None, 2, distributions["apical"], {})
assert_array_almost_equal(
Expand All @@ -208,6 +209,12 @@ def test_convert_orientation2points():
with pytest.raises(ValueError):
ng._convert_orientation2points("from_space", 1, distributions["apical"], {})

# Test with existing trunks
ng.grow()
pts = ng._convert_orientation2points(None, 2, distributions["apical"], {})

assert_array_almost_equal(pts, [[2.770599, 4.868847, 8.813554], [-6.314678, 6.2103, 5.533321]])

with pytest.raises(ValueError):
ng._convert_orientation2points(object(), 1, distributions["apical"], {})

Expand Down Expand Up @@ -247,11 +254,11 @@ def test_breaker_of_tmd_algo():
assert_array_equal(N.apical_sections, [33])
assert_array_almost_equal(
n.sections[169].points[-1],
np.array([117.20551, -41.12157, 189.57013]),
np.array([-220.93813, -21.49141, -55.93323]),
decimal=5,
)
assert_array_almost_equal(
n.sections[122].points[-1], np.array([77.08879, 115.79825, -0.99393]), decimal=5
n.sections[122].points[-1], np.array([-17.31787, 151.4876, -6.67741]), decimal=5
)

# Test with a specific random generator
Expand All @@ -263,11 +270,11 @@ def test_breaker_of_tmd_algo():
assert_array_equal(N.apical_sections, [33])
assert_array_almost_equal(
n.sections[169].points[-1],
np.array([117.20551, -41.12157, 189.57013]),
np.array([-220.93813, -21.49141, -55.93323]),
decimal=5,
)
assert_array_almost_equal(
n.sections[122].points[-1], np.array([77.08879, 115.79825, -0.99393]), decimal=5
n.sections[122].points[-1], np.array([-17.31787, 151.4876, -6.67741]), decimal=5
)


Expand Down
Loading

0 comments on commit 14e11b2

Please sign in to comment.