diff --git a/run_dipy_cpu.py b/run_dipy_cpu.py
index 9b32d98..63d5dd0 100644
--- a/run_dipy_cpu.py
+++ b/run_dipy_cpu.py
@@ -49,6 +49,9 @@
 import nibabel as nib
 from nibabel.orientations import aff2axcodes
 
+import numpy.linalg as npl
+from dipy.tracking.streamline import transform_streamlines
+
 # Import custom module
 import cuslines.cuslines as cuslines
 
@@ -88,6 +91,15 @@ def get_img(ep2_seq):
 roi_data = roi.get_data()
 mask = mask.get_data()
 
+img_affine = img.affine
+
+fa_threshold = 0.1
+min_relative_peak = 0.25
+min_peak_spacing = 0.7
+min_peak_deg = 45
+sm_lambda = 0
+seed_density = 5
+
 tenmodel = dti.TensorModel(gtab, fit_method='WLS')
 print('Fitting Tensor')
 tenfit = tenmodel.fit(data, mask)
@@ -96,26 +108,26 @@ def get_img(ep2_seq):
 FA[np.isnan(FA)] = 0
 
 # Setup tissue_classifier args
-tissue_classifier = ThresholdStoppingCriterion(FA, 0.1)
+tissue_classifier = ThresholdStoppingCriterion(FA, fa_threshold)
 metric_map = np.asarray(FA, 'float64')
 
 # Create seeds for ROI
-seed_mask = utils.seeds_from_mask(roi_data, density=3, affine=np.eye(4))
+seed_mask = utils.seeds_from_mask(roi_data, density=seed_density, affine=img_affine)
 
 # Setup model
 print('slowadcodf')
-sh_order = 6
-model = OpdtModel(gtab, sh_order=sh_order, min_signal=1)
+sh_order = 4
+model = OpdtModel(gtab, sh_order=sh_order, smooth=sm_lambda, min_signal=1)
 
 # Setup direction getter args
 print('Bootstrap direction getter')
-boot_dg = BootDirectionGetter.from_data(data, model, max_angle=60., sphere=small_sphere)
+boot_dg = BootDirectionGetter.from_data(data, model, max_angle=60., sphere=small_sphere, sh_order=sh_order, relative_peak_threshold=min_relative_peak, min_separation_angle=min_peak_deg)
 
 print('streamline gen')
 global_chunk_size = args.chunk_size
 nchunks = (seed_mask.shape[0] + global_chunk_size - 1) // global_chunk_size
 
-streamline_generator = LocalTracking(boot_dg, tissue_classifier, seed_mask, affine=np.eye(4), step_size=.5)
+streamline_generator = LocalTracking(boot_dg, tissue_classifier, seed_mask, affine=img_affine, step_size=.5)
 t1 = time.time()
 streamline_time = 0
 
@@ -130,12 +142,18 @@ def get_img(ep2_seq):
                                             seed_mask[idx*global_chunk_size:(idx+1)*global_chunk_size].shape[0],
                                             te-ts))
 
+    # Invert streamline affine
+    inv_affine = npl.inv(img_affine)
+    streamlines = transform_streamlines(streamlines,inv_affine)
+
+
     # Save tracklines file
     if args.output_prefix:
         fname = "{}.{}_{}.trk".format(args.output_prefix, idx+1, nchunks)
         ts = time.time()
         #save_tractogram(fname, streamlines, img.affine, vox_size=roi.header.get_zooms(), shape=roi_data.shape)
         sft = StatefulTractogram(streamlines, args.nifti_file, Space.VOX)
+        sft.to_rasmm()
         save_tractogram(sft, fname)
         te = time.time()
         print("Saved streamlines to {}, time {} s".format(fname, te-ts))