Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions run_dipy_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
import nibabel as nib
from nibabel.orientations import aff2axcodes

import numpy.linalg as npl
from dipy.tracking.streamline import transform_streamlines

# Import custom module
import cuslines.cuslines as cuslines

Expand Down Expand Up @@ -88,6 +91,15 @@ def get_img(ep2_seq):
roi_data = roi.get_data()
mask = mask.get_data()

img_affine = img.affine

fa_threshold = 0.1
min_relative_peak = 0.25
min_peak_spacing = 0.7
min_peak_deg = 45
sm_lambda = 0
seed_density = 5

tenmodel = dti.TensorModel(gtab, fit_method='WLS')
print('Fitting Tensor')
tenfit = tenmodel.fit(data, mask)
Expand All @@ -96,26 +108,26 @@ def get_img(ep2_seq):
FA[np.isnan(FA)] = 0

# Setup tissue_classifier args
tissue_classifier = ThresholdStoppingCriterion(FA, 0.1)
tissue_classifier = ThresholdStoppingCriterion(FA, fa_threshold)
metric_map = np.asarray(FA, 'float64')

# Create seeds for ROI
seed_mask = utils.seeds_from_mask(roi_data, density=3, affine=np.eye(4))
seed_mask = utils.seeds_from_mask(roi_data, density=seed_density, affine=img_affine)

# Setup model
print('slowadcodf')
sh_order = 6
model = OpdtModel(gtab, sh_order=sh_order, min_signal=1)
sh_order = 4
model = OpdtModel(gtab, sh_order=sh_order, smooth=sm_lambda, min_signal=1)

# Setup direction getter args
print('Bootstrap direction getter')
boot_dg = BootDirectionGetter.from_data(data, model, max_angle=60., sphere=small_sphere)
boot_dg = BootDirectionGetter.from_data(data, model, max_angle=60., sphere=small_sphere, sh_order=sh_order, relative_peak_threshold=min_relative_peak, min_separation_angle=min_peak_deg)

print('streamline gen')
global_chunk_size = args.chunk_size
nchunks = (seed_mask.shape[0] + global_chunk_size - 1) // global_chunk_size

streamline_generator = LocalTracking(boot_dg, tissue_classifier, seed_mask, affine=np.eye(4), step_size=.5)
streamline_generator = LocalTracking(boot_dg, tissue_classifier, seed_mask, affine=img_affine, step_size=.5)

t1 = time.time()
streamline_time = 0
Expand All @@ -130,12 +142,18 @@ def get_img(ep2_seq):
seed_mask[idx*global_chunk_size:(idx+1)*global_chunk_size].shape[0],
te-ts))

# Invert streamline affine
inv_affine = npl.inv(img_affine)
streamlines = transform_streamlines(streamlines,inv_affine)


# Save tracklines file
if args.output_prefix:
fname = "{}.{}_{}.trk".format(args.output_prefix, idx+1, nchunks)
ts = time.time()
#save_tractogram(fname, streamlines, img.affine, vox_size=roi.header.get_zooms(), shape=roi_data.shape)
sft = StatefulTractogram(streamlines, args.nifti_file, Space.VOX)
sft.to_rasmm()
save_tractogram(sft, fname)
te = time.time()
print("Saved streamlines to {}, time {} s".format(fname, te-ts))
Expand Down