
Commit 4d37896

Merge pull request #728 from frheault/multi_centro_labels_map
Multi centro labels map
arnaudbore authored Jun 29, 2023
2 parents eb5d3fb + 3579053 commit 4d37896
Showing 3 changed files with 122 additions and 177 deletions.
77 changes: 16 additions & 61 deletions scilpy/tractanalysis/distance_to_centroid.py
@@ -1,70 +1,25 @@
# -*- coding: utf-8 -*-

import numpy as np
import tempfile
import os
from scipy.spatial import KDTree


def min_dist_to_centroid(bundle_pts, centroid_pts):
nb_bundle_points = len(bundle_pts)
nb_centroid_points = len(centroid_pts)
total_len = nb_bundle_points*nb_centroid_points
def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts):
tree = KDTree(centroid_pts, copy_data=True)
dists, labels = tree.query(bundle_pts, k=1)
dists, labels = np.expand_dims(
dists, axis=1), np.expand_dims(labels, axis=1)

# bundle_points will be shaped like
# [[bundle_pt1], ⸣
# [bundle_pt1], ⸠ → Repeated # of centroid points time
# [bundle_pt1], ⸥
# ...
# [bundle_ptN],
# [bundle_ptN],
# [bundle_ptN]]
with tempfile.TemporaryDirectory() as tmp_path:
bundle_points = np.memmap(os.path.join(tmp_path, 'bundle_points'),
dtype='float16', mode='w+',
shape=(total_len, 3))
bundle_points[:] = np.repeat(bundle_pts,
nb_centroid_points,
axis=0)
labels = np.mod(labels, nb_pts)

# centroid_points will be shaped like
# [[centroid_pt1], ⸣
# [centroid_pt2], |
# ... ⸠ → Repeated # of points in bundle times
# [centroid_pt20], ⸥
# [centroid_pt1],
# [centroid_pt2],
# ...
# [centroid_pt20]]
centroid_points = np.memmap(os.path.join(tmp_path, 'centroid_points'),
dtype='float16', mode='w+',
shape=(total_len, 3))
centroid_points[:] = np.tile(centroid_pts, (nb_bundle_points, 1))
sum_dist = np.expand_dims(np.sum(dists, axis=1), axis=1)
weights = np.exp(-dists / sum_dist)

# norm will be shaped like
# [[bundle_pt1 - centroid_pt1],
# [bundle_pt1 - centroid_pt2],
# [bundle_pt1 - centroid_pt3],
# ...
# [bundle_ptN - centroid_pt1]]
# [bundle_ptN - centroid_pt2]]
# ...
# [bundle_ptN - centroid_pt20]]
norm = np.memmap(os.path.join(tmp_path, 'norm'),
dtype='float16', mode='w+',
shape=(total_len,))
norm[:] = np.linalg.norm(bundle_points - centroid_points, axis=1)
votes = []
for i in range(len(bundle_pts)):
vote = np.bincount(labels[i], weights=weights[i])
total = np.arange(np.amax(labels[i])+1)
winner = total[np.argmax(vote)]
votes.append(winner)

# Reshape so we have the distance to each centroid for each
# bundle point
dist_to_centroid = np.memmap(os.path.join(tmp_path, 'dist_to_centroid'),
dtype='float16', mode='w+',
shape=(nb_bundle_points, nb_centroid_points))
dist_to_centroid[:] = norm.reshape(nb_bundle_points,
nb_centroid_points)

# Find the closest centroid (label and distance) for each point of the
# bundle
min_dist_label = np.argmin(dist_to_centroid, axis=1)
min_dist = np.amin(dist_to_centroid, axis=1)

return min_dist_label, min_dist
return np.array(votes, dtype=np.uint16), np.average(dists, axis=1)
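
For context, a minimal sketch (not part of the commit) of how the rewritten min_dist_to_centroid assigns labels: each bundle point takes the index of its nearest centroid point from a KDTree query, and np.mod folds the indices of several concatenated centroids back onto a single 0..nb_pts-1 label range. All arrays and printed values below are hypothetical.

    # Toy example: two stacked centroids of nb_pts points each, so raw
    # nearest-neighbour indices 0..7 fold onto labels 0..3 via np.mod.
    import numpy as np
    from scipy.spatial import KDTree

    nb_pts = 4
    centroid_pts = np.array([[i, 0.0, 0.0] for i in range(nb_pts)] +
                            [[i, 1.0, 0.0] for i in range(nb_pts)])
    bundle_pts = np.array([[0.1, 0.4, 0.0],   # nearest to section 0
                           [2.2, 0.6, 0.0]])  # nearest to section 2

    tree = KDTree(centroid_pts)
    dists, labels = tree.query(bundle_pts, k=1)  # nearest centroid point
    labels = np.mod(labels, nb_pts)              # fold onto one label range
    print(labels)                                # -> [0 2]
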
202 changes: 90 additions & 112 deletions scripts/scil_compute_bundle_voxel_label_map.py
@@ -12,10 +12,10 @@
import logging
import os

from dipy.align.streamlinear import StreamlineLinearRegistration
from dipy.io.streamline import save_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram, set_sft_logger_level
from dipy.io.utils import is_header_compatible
from dipy.segment.clustering import qbx_and_merge
import matplotlib.pyplot as plt
import nibabel as nib
from nibabel.streamlines.array_sequence import ArraySequence
@@ -52,17 +52,11 @@ def _build_arg_parser():
p.add_argument('--nb_pts', type=int,
help='Number of divisions for the bundles.\n'
'Default is the number of points of the centroid.')
p.add_argument('--new_labeling', action='store_true',
help='Activate the new labeling method based on clusters.')
p.add_argument('--min_streamline_count', type=int, default=100000,
help='Minimum number of streamlines for filtering/cutting'
'operation [%(default)s].')
p.add_argument('--min_voxel_count', type=int, default=1000000,
help='Minimum number of voxels for filtering/cutting'
'operation [%(default)s].')
p.add_argument('--colormap', default='jet',
help='Select the colormap for colored trk (data_per_point) '
'[%(default)s].')
p.add_argument('--new_labelling', action='store_true',
help='Use the new labelling method (multi-centroids).')

add_reference_arg(p)
add_overwrite_arg(p)
@@ -81,28 +75,23 @@ def main():
sft_centroid = load_tractogram_with_reference(parser, args,
args.in_centroid)

if len(sft_centroid.streamlines) < 1 \
or len(sft_centroid.streamlines) > 1:
logging.error('Centroid file {} should contain one streamline. '
'Skipping'.format(args.in_centroid))
raise ValueError
sft_centroid.to_vox()
sft_centroid.to_corner()

sft_list = []
for filename in args.in_bundles:
sft = load_tractogram_with_reference(parser, args, filename)
if not len(sft.streamlines):
logging.error('Empty bundle file {}. '
raise IOError('Empty bundle file {}. '
'Skipping'.format(args.in_bundle))
raise ValueError
sft.to_vox()
sft.to_corner()
sft_list.append(sft)

if len(sft_list):
if not is_header_compatible(sft_list[0], sft_list[-1]):
parser.error('ERROR HEADER')
parser.error('Header of {} and {} are not compatible'.format(
args.in_bundles[0], filename))

density_list = []
binary_list = []
@@ -112,10 +101,6 @@ def main():
binary = np.zeros(sft.dimensions)
binary[density > 0] = 1
binary_list.append(binary)

# density = ndi.gaussian_filter(density, 1) * binary
# density[binary < 1] += np.random.normal(0.0, 1.0,
# binary[binary < 1].shape)
density_list.append(density)

if not is_header_compatible(sft_centroid, sft_list[0]):
@@ -132,18 +117,6 @@ def main():
# with no neighbor. Remove isolated voxels to keep a single 'blob'
binary_bundle = np.zeros(corr_map.shape, dtype=bool)
binary_bundle[corr_map > 0.5] = 1
min_streamlines_count = 1e16
for sft in sft_list:
min_streamlines_count = min(len(sft), min_streamlines_count)

structure_cross = ndi.generate_binary_structure(3, 1)
if np.count_nonzero(binary_bundle) > args.min_voxel_count \
and min_streamlines_count > args.min_streamline_count:
binary_bundle = ndi.binary_dilation(binary_bundle,
structure=structure_cross)
binary_bundle = ndi.binary_erosion(binary_bundle,
structure=structure_cross,
iterations=2)

bundle_disjoint, _ = ndi.label(binary_bundle)
unique, count = np.unique(bundle_disjoint, return_counts=True)
@@ -162,97 +135,102 @@ def main():
if len(sft_list[i]):
concat_sft += sft_list[i]

if args.nb_pts is not None:
sft_centroid = resample_streamlines_num_points(sft_centroid,
args.nb_pts)
else:
args.nb_pts = len(sft_centroid.streamlines[0])
args.nb_pts = len(sft_centroid.streamlines[0]) if args.nb_pts is None \
else args.nb_pts

sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts)
tmp_sft = resample_streamlines_num_points(concat_sft, args.nb_pts)


thresholds = [24, 18, 12, 6] if args.new_labeling else [200]
clusters_map = qbx_and_merge(concat_sft.streamlines, thresholds,
nb_pts=args.nb_pts, verbose=False,
rng=np.random.RandomState(1))
if not args.new_labelling:
new_streamlines = sft_centroid.streamlines.copy()
sft_centroid = StatefulTractogram.from_sft([new_streamlines[0]],
sft_centroid)
else:
srr = StreamlineLinearRegistration()
srm = srr.optimize(static=tmp_sft.streamlines,
moving=sft_centroid.streamlines)
sft_centroid.streamlines = srm.transform(sft_centroid.streamlines)

uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0])
labels, dists = min_dist_to_centroid(concat_sft.streamlines._data,
sft_centroid.streamlines._data,
args.nb_pts)
labels += 1 # 0 means no labels

# Labels are not allowed to jump, for consistency
# Streamlines should have continuous labels
final_streamlines = []
final_label = []
final_dist = []
for _, cluster in enumerate(clusters_map):
tmp_sft = StatefulTractogram.from_sft([cluster.centroid], concat_sft)
uniformize_bundle_sft(tmp_sft, ref_bundle=sft_centroid)
cluster_centroid = tmp_sft.streamlines[0] if args.new_labeling \
else sft_centroid.streamlines[0]
cluster_streamlines = ArraySequence(cluster[:])
min_dist_label, min_dist = min_dist_to_centroid(cluster_streamlines._data,
cluster_centroid)
min_dist_label += 1 # 0 means no labels

# Labels are not allowed to jump, for consistency
# Streamlines should have continuous labels
curr_ind = 0
for i, streamline in enumerate(cluster_streamlines):
next_ind = curr_ind + len(streamline)
curr_labels = min_dist_label[curr_ind:next_ind]
curr_dist = min_dist[curr_ind:next_ind]
curr_ind = next_ind

# Flip streamlines so the labels increase (facilitate if/else)
# Should always be ordered in nextflow pipeline
gradient = np.gradient(curr_labels)
if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)):
streamline = streamline[::-1]
curr_labels = curr_labels[::-1]
curr_dist = curr_dist[::-1]

# Find jumps, cut them and find the longest
gradient = np.ediff1d(curr_labels)
max_jump = max(args.nb_pts // 5, 1)
if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
split_chunk = np.split(curr_labels,
pos_jump)
max_len = 0
max_pos = 0
for j, chunk in enumerate(split_chunk):
if len(chunk) > max_len:
max_len = len(chunk)
max_pos = j

curr_labels = split_chunk[max_pos]
gradient_chunk = np.ediff1d(chunk)
if len(np.unique(np.sign(gradient_chunk))) > 1:
continue
streamline = np.split(streamline,
pos_jump)[max_pos]
curr_dist = np.split(curr_dist,
pos_jump)[max_pos]

final_streamlines.append(streamline)
final_label.append(curr_labels)
final_dist.append(curr_dist)
final_dists = []
curr_ind = 0
for i, streamline in enumerate(concat_sft.streamlines):
next_ind = curr_ind + len(streamline)
curr_labels = labels[curr_ind:next_ind]
curr_dists = dists[curr_ind:next_ind]
curr_ind = next_ind

# Flip streamlines so the labels increase (facilitate if/else)
# Should always be ordered in nextflow pipeline
gradient = np.gradient(curr_labels)
if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)):
streamline = streamline[::-1]
curr_labels = curr_labels[::-1]
curr_dists = curr_dists[::-1]

# Find jumps, cut them and find the longest
gradient = np.ediff1d(curr_labels)
max_jump = max(args.nb_pts // 5, 1)
if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
split_chunk = np.split(curr_labels,
pos_jump)

max_len = 0
max_pos = 0
for j, chunk in enumerate(split_chunk):
if len(chunk) > max_len:
max_len = len(chunk)
max_pos = j

curr_labels = split_chunk[max_pos]
gradient_chunk = np.ediff1d(chunk)
if len(np.unique(np.sign(gradient_chunk))) > 1:
continue
streamline = np.split(streamline,
pos_jump)[max_pos]
curr_dists = np.split(curr_dists,
pos_jump)[max_pos]

final_streamlines.append(streamline)
final_label.append(curr_labels)
final_dists.append(curr_dists)

final_streamlines = ArraySequence(final_streamlines)
labels_array = ArraySequence(final_label)
dist_array = ArraySequence(final_dist)
final_labels = ArraySequence(final_label)
final_dists = ArraySequence(final_dists)

kd_tree = cKDTree(final_streamlines._data)
labels_map = np.zeros(binary_bundle.shape, dtype=np.int16)
distance_map = np.zeros(binary_bundle.shape, dtype=float)
indices = np.array(np.nonzero(binary_bundle), dtype=int).T

for ind in indices:
neighbor_ids = kd_tree.query_ball_point(ind, 2.0)
if not neighbor_ids:
_, neighbor_ids = kd_tree.query(ind, k=5)

if not len(neighbor_ids):
continue
labels_val = labels_array._data[neighbor_ids]
dist_centro = dist_array._data[neighbor_ids]
dist_vox = np.linalg.norm(final_streamlines._data[neighbor_ids] - ind,
axis=1)
if np.sum(dist_centro, dtype=np.int64) > 0:
labels_map[tuple(ind)] = np.round(
np.average(labels_val, weights=dist_centro*dist_vox))
distance_map[tuple(ind)] = np.average(dist_centro*dist_vox)
else:
labels_map[tuple(ind)] = np.round(
np.average(labels_val, weights=dist_vox))
distance_map[tuple(ind)] = np.average(dist_vox)

labels_val = final_labels._data[neighbor_ids]
dists_val = final_dists._data[neighbor_ids]
sum_dists_vox = np.sum(dists_val)
weights_vox = np.exp(-dists_val / sum_dists_vox)

vote = np.bincount(labels_val, weights=weights_vox)
total = np.arange(np.amax(labels_val+1))
winner = total[np.argmax(vote)]
labels_map[ind[0], ind[1], ind[2]] = winner
distance_map[ind[0], ind[1], ind[2]] = np.average(dists_val)

cmap = get_colormap(args.colormap)

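A small aside on the jump-cutting pass in scil_compute_bundle_voxel_label_map.py above: when consecutive labels along a streamline differ by more than max_jump, the streamline is split at those positions and only the longest chunk is kept. A minimal sketch with assumed values:

    # Toy labels with one jump of 11 between indices 3 and 4.
    import numpy as np

    nb_pts = 20
    curr_labels = np.array([4, 5, 6, 7, 18, 19, 20])
    max_jump = max(nb_pts // 5, 1)                 # 4 here

    gradient = np.ediff1d(curr_labels)
    pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
    split_chunk = np.split(curr_labels, pos_jump)  # [4 5 6 7] and [18 19 20]
    longest = max(split_chunk, key=len)            # the chunk that is kept
    print(longest)                                 # -> [4 5 6 7]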
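Similarly, the voxel-filling loop above gives each voxel inside the binary mask the label with the largest exponentially distance-weighted count among nearby streamline points. A minimal sketch with made-up numbers:

    # Five hypothetical nearby points: three labelled 4, two labelled 3.
    import numpy as np

    labels_val = np.array([3, 3, 4, 4, 4], dtype=np.uint16)
    dists_val = np.array([0.5, 0.8, 1.2, 1.5, 2.0])

    weights_vox = np.exp(-dists_val / np.sum(dists_val))
    vote = np.bincount(labels_val, weights=weights_vox)
    winner = np.argmax(vote)  # label written into labels_map at this voxel
    print(winner)             # -> 4 (three moderate votes beat two strong)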

