Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Allow for single basal #89

Merged
merged 4 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions neurots/extract_input/from_neurom.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def trunk_neurite(pop, neurite_type=nm.BASAL_DENDRITE, bins=30):
return trunk_data


def number_neurites(pop, neurite_type=nm.BASAL_DENDRITE):
def number_neurites(pop, neurite_type=nm.BASAL_DENDRITE, min_n_basals=1):
"""Extract the number of trees for a specific tree type from a given population.

Args:
Expand Down Expand Up @@ -283,11 +283,11 @@ def number_neurites(pop, neurite_type=nm.BASAL_DENDRITE):
nm.get("number_of_neurites", pop, neurite_type=neurite_type), dtype=np.int32
)
# Clean the data from single basal trees cells
if neurite_type == nm.BASAL_DENDRITE and len(np.where(nneurites == 1)[0]) > 0:
nneurites[np.where(nneurites == 1)[0]] = 2
if neurite_type == nm.BASAL_DENDRITE and len(np.where(nneurites == min_n_basals - 1)[0]) > 0:
nneurites[np.where(nneurites == min_n_basals - 1)[0]] = min_n_basals
print(
"Warning, input population includes cells with single basal trees! "
+ "The distribution has been altered to include 2 basals minimum."
"Warning, input population includes cells with too few basal trees! "
+ f"The distribution has been altered to include {min_n_basals} basal(s) minimum."
)

heights, bins = np.histogram(
Expand Down
4 changes: 3 additions & 1 deletion neurots/extract_input/input_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def distributions(
diameter_input_morph=None,
feature="path_distances",
diameter_model=None,
min_n_basals=1,
):
"""Extracts the input distributions from an input population.

Expand All @@ -66,6 +67,7 @@ def distributions(
``{<neurite type 1>: <feature 1>, ...}``.
diameter_model (str): model for diameters, internal models are `M1`, `M2`, `M3`, `M4` and
`M5`. Can be set to `external` for external model.
min_n_basals (int): minimum number of basals, if less we enforce this value (default=1)

Returns:
dict: The input distributions.
Expand Down Expand Up @@ -112,7 +114,7 @@ def distributions(
nm_type = getattr(NeuriteType, neurite_type)

input_distributions[neurite_type] = _append_dicts(
trunk_neurite(pop_nm, nm_type), number_neurites(pop_nm, nm_type)
trunk_neurite(pop_nm, nm_type), number_neurites(pop_nm, nm_type, min_n_basals)
)
if type_feature in ["path_distances", "radial_distances"]:
_append_dicts(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_extract_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,16 @@ def test_number_neurites_cut_pop(POPUL):
smallest = 0
biggest = 1

for i in list(range(2, len(neurons[biggest].root_sections) - 1))[::-1]:
for i in list(range(0, len(neurons[biggest].root_sections) - 1))[::-1]:
neurons[biggest].delete_section(neurons[biggest].root_sections[i], recursive=True)

POPUL = neurom.core.population.Population(neurons)
assert_equal(len(neurons), 2)
assert_equal(len(neurons[biggest].neurites), 3)
assert_equal(len(neurons[biggest].neurites), 1)
assert_equal(len(neurons[smallest].neurites), 6)
assert_equal(len(list(POPUL.neurites)), 9)
assert_equal(len(list(POPUL.neurites)), 7)
res_cut = extract_input.from_neurom.number_neurites(POPUL)
assert_equal(res_cut, {"num_trees": {"data": {"bins": [2, 3, 4], "weights": [1, 0, 1]}}})
assert_equal(res_cut, {"num_trees": {"data": {"bins": [1, 2, 3, 4], "weights": [1, 0, 0, 1]}}})


def test_parameters():
Expand Down