Skip to content

Commit

Permalink
Add official bbox arg
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Nov 8, 2022
1 parent f4a215d commit bf59191
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 73 deletions.
23 changes: 20 additions & 3 deletions scilpy/io/streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,28 @@ def ichunk(sequence, n):

def is_argument_set(args, arg_name):
# Check that attribute is not None
return not getattr(args, 'reference', None) is None
return not getattr(args, arg_name, None) is None


def load_tractogram_with_reference(parser, args, filepath,
bbox_check=True, arg_name=None):
def load_tractogram_with_reference(parser, args, filepath, arg_name=None):
"""
Parameters
----------
parser: Argument Parser
Used to print errors, if any.
args: Namespace
Parsed arguments. Used to get the 'ref' and 'bbox_check' args.
See scilpy.io.utils to add the arguments to your parser.
filepath: str
Path of the tractogram file.
arg_name: str, optional
Name of the reference argument. By default the args.ref is used. If
arg_name is given, then args.arg_name_ref will be used instead.
"""
if is_argument_set(args, 'bbox_check'):
bbox_check = args.bbox_check
else:
bbox_check = True

_, ext = os.path.splitext(filepath)
if ext == '.trk':
Expand Down
14 changes: 11 additions & 3 deletions scilpy/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ def add_verbose_arg(parser):
help='If set, produces verbose output.')


def add_bbox_arg(parser):
parser.add_argument('--bbox_check', type=bool, default=True,
help='Set to false to ignore validity of the bounding '
'box during loading / saving of \n'
'tractograms (ignores the presence of invalid '
'streamlines). Default: True.')


def add_sh_basis_args(parser, mandatory=False):
"""Add spherical harmonics (SH) bases argument.
Expand Down Expand Up @@ -440,7 +448,8 @@ def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
parser: argument parser
Will raise an error if a file is not compatible.
args: Namespace
Should contain a args.reference if any file is a .tck.
Should contain a args.reference if any file is a .tck, and possibly a
args.bbox_check (set to True by default).
"""
save_ref = args.reference

Expand All @@ -455,8 +464,7 @@ def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
args.reference = None
else:
args.reference = save_ref
mask = load_tractogram_with_reference(parser, args, file,
bbox_check=False)
mask = load_tractogram_with_reference(parser, args, file)
else: # should be a nifti file.
mask = file
compatible = is_header_compatible(ref_sft, mask)
Expand Down
10 changes: 6 additions & 4 deletions scilpy/segment/tractogram_from_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False):
gt_files: list
List of either StatefulTractograms or niftis.
parser: ArgumentParser
Argument parser which handles the script's arguments.
Argument parser which handles the script's arguments. Used to print
parser errors, if any.
args: Namespace
List of arguments passed to the script.
List of arguments passed to the script. Used for its 'ref' and
'bbox_check' arguments.
inverse_mask: bool
If true, returns the list of inversed masks instead.
Expand Down Expand Up @@ -74,7 +76,7 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False):
else:
args.reference = save_ref
gt_sft = load_tractogram_with_reference(
parser, args, gt_bundle, bbox_check=False)
parser, args, gt_bundle)
gt_sft.to_vox()
gt_sft.to_corner()
_, dimensions, _, _ = gt_sft.space_attributes
Expand Down Expand Up @@ -653,7 +655,7 @@ def segment_tractogram_from_roi(
nc_sft = sft[all_nc_ids]
if len(nc_sft) > 0 or not args.no_empty:
save_tractogram(nc_sft, os.path.join(
args.out_dir, filename), bbox_valid_check=False)
args.out_dir, filename), bbox_valid_check=args.bbox_check)

return (vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names,
bundle_stats)
11 changes: 7 additions & 4 deletions scripts/scil_apply_transform_to_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
import numpy as np

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
add_reference_arg,
add_verbose_arg,
assert_inputs_exist,
Expand All @@ -55,7 +56,9 @@ def _build_arg_parser():
description=__doc__)

p.add_argument('in_moving_tractogram',
help='Path of the tractogram to be transformed.')
help='Path of the tractogram to be transformed.\n'
'Bounding box validity will not be checked (could '
'contain invalid streamlines).')
p.add_argument('in_target_file',
help='Path of the reference target file (trk or nii).')
p.add_argument('in_transfo',
Expand Down Expand Up @@ -105,9 +108,9 @@ def main():
args.in_transfo], args.in_deformation)
assert_outputs_exist(parser, args, args.out_tractogram)

args.bbox_check = False # Adding manually bbox_check argument.
moving_sft = load_tractogram_with_reference(parser, args,
args.in_moving_tractogram,
bbox_check=False)
args.in_moving_tractogram)

transfo = load_matrix_in_any_format(args.in_transfo)
deformation_data = None
Expand Down
30 changes: 19 additions & 11 deletions scripts/scil_clean_qbx_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from dipy.io.utils import is_header_compatible

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
add_reference_arg,
add_verbose_arg,
assert_inputs_exist,
Expand Down Expand Up @@ -65,6 +66,7 @@ def _build_arg_parser():
add_reference_arg(p)
add_overwrite_arg(p)
add_verbose_arg(p)
add_bbox_arg(p)

return p

Expand Down Expand Up @@ -175,12 +177,11 @@ def keypress_callback(obj, _):
concat_streamlines = []

ref_bundle = load_tractogram_with_reference(
parser, args, args.in_bundles[0], bbox_check=False)
parser, args, args.in_bundles[0])

for filename in args.in_bundles:
basename = os.path.basename(filename)
sft = load_tractogram_with_reference(parser, args, filename,
bbox_check=False)
sft = load_tractogram_with_reference(parser, args, filename)
if not is_header_compatible(ref_bundle, sft):
return
if len(sft) >= args.min_cluster_size:
Expand Down Expand Up @@ -240,32 +241,38 @@ def keypress_callback(obj, _):
accepted_streamlines = save_clusters(sft_accepted_on_size,
accepted_streamlines,
args.out_accepted_dir,
filename_accepted_on_size)
filename_accepted_on_size,
args.bbox_check)

accepted_sft = StatefulTractogram(accepted_streamlines,
sft_accepted_on_size[0],
Space.RASMM)
save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)
save_tractogram(accepted_sft, args.out_accepted,
bbox_valid_check=args.bbox_check)

# Save rejected clusters (by GUI)
rejected_streamlines = save_clusters(sft_accepted_on_size,
rejected_streamlines,
args.out_rejected_dir,
filename_accepted_on_size)
filename_accepted_on_size,
args.bbox_check)

# Save rejected clusters (by size)
rejected_streamlines.extend(save_clusters(sft_rejected_on_size,
range(len(sft_rejected_on_size)),
args.out_rejected_dir,
filename_rejected_on_size))
filename_rejected_on_size,
args.bbox_check))

rejected_sft = StatefulTractogram(rejected_streamlines,
sft_accepted_on_size[0],
Space.RASMM)
save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
save_tractogram(rejected_sft, args.out_rejected,
bbox_valid_check=args.bbox_check)


def save_clusters(cluster_lists, indexes_list, directory, basenames_list):
def save_clusters(cluster_lists, indexes_list, directory, basenames_list,
bbox_check):
output_streamlines = []
for idx in indexes_list:
streamlines = cluster_lists[idx].streamlines
Expand All @@ -277,7 +284,8 @@ def save_clusters(cluster_lists, indexes_list, directory, basenames_list):
Space.RASMM)
tmp_filename = os.path.join(directory,
basenames_list[idx])
save_tractogram(tmp_sft, tmp_filename, bbox_valid_check=False)
save_tractogram(tmp_sft, tmp_filename,
bbox_valid_check=bbox_check)

return output_streamlines

Expand Down
13 changes: 7 additions & 6 deletions scripts/scil_compute_seed_density_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from nibabel import Nifti1Image
from nibabel.streamlines import detect_format, TrkFile
import numpy as np
from scilpy.io.utils import (
add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)


def _build_arg_parser():
Expand All @@ -34,6 +34,7 @@ def _build_arg_parser():
'without a value, 1 is used.\n If a value is given, '
'will be used as the stored value.')
add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand All @@ -57,9 +58,9 @@ def main():
.format(args.binary, max_))

# Load files and data. TRKs can have 'same' as reference
# Can handle streamlines outside of bbox
# Can handle streamlines outside of bbox, if asked by user.
sft = load_tractogram(args.tractogram_filename, 'same',
bbox_valid_check=False)
bbox_valid_check=args.bbox_check)

# IMPORTANT
# Origin should be center when creating the seeds (see below, we
Expand Down
11 changes: 6 additions & 5 deletions scripts/scil_convert_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from dipy.io.streamline import save_tractogram

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg, add_reference_arg,
assert_inputs_exist, assert_outputs_exist)
from scilpy.io.utils import (add_bbox_arg, add_overwrite_arg,
add_reference_arg, assert_inputs_exist,
assert_outputs_exist)


def _build_arg_parser():
Expand All @@ -30,6 +31,7 @@ def _build_arg_parser():

add_reference_arg(p)
add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand All @@ -48,9 +50,8 @@ def main():

assert_outputs_exist(parser, args, args.output_name)

sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
bbox_check=False)
save_tractogram(sft, args.output_name, bbox_valid_check=False)
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
save_tractogram(sft, args.output_name, bbox_valid_check=args.bbox_check)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions scripts/scil_decompose_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@

from scilpy.io.image import get_data_as_label
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
add_verbose_arg,
add_reference_arg,
assert_inputs_exist,
Expand Down Expand Up @@ -225,6 +226,7 @@ def _build_arg_parser():
add_reference_arg(p)
add_verbose_arg(p)
add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand Down Expand Up @@ -272,8 +274,7 @@ def main():

logging.info('*** Loading streamlines ***')
time1 = time.time()
sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
bbox_check=False)
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
sft.remove_invalid_streamlines()
time2 = time.time()
logging.info(' Loading {} streamlines took {} sec.'.format(
Expand Down
3 changes: 1 addition & 2 deletions scripts/scil_extract_ushape.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def main():
parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) +
'must be between -1 and 1.')

sft = load_tractogram_with_reference(
parser, args, args.in_tractogram)
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

ids_c = detect_ushape(sft, args.minU, args.maxU)
ids_l = np.setdiff1d(np.arange(len(sft.streamlines)), ids_c)
Expand Down
14 changes: 7 additions & 7 deletions scripts/scil_fix_dsi_studio_trk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
import nibabel as nib
import numpy as np

from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)
from scilpy.utils.streamlines import (transform_warp_sft,
Expand Down Expand Up @@ -82,10 +83,9 @@ def _build_arg_parser():
invalid.add_argument('--remove_invalid', action='store_true',
help='Remove the streamlines landing out of the '
'bounding box.')
invalid.add_argument('--keep_invalid', action='store_true',
help='Keep the streamlines landing out of the '
'bounding box.')

add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand Down Expand Up @@ -122,7 +122,7 @@ def main():
assert_outputs_exist(parser, args, args.out_tractogram)

sft = load_tractogram(args.in_dsi_tractogram, 'same',
bbox_valid_check=False)
bbox_valid_check=args.bbox_check)

# LPS -> RAS convention in voxel space
sft.to_vox()
Expand All @@ -143,7 +143,7 @@ def main():
elif args.remove_invalid:
sft_flip.remove_invalid_streamlines()
save_tractogram(sft_flip, args.out_tractogram,
bbox_valid_check=not args.keep_invalid)
bbox_valid_check=args.bbox_check)
else:
static_img = nib.load(args.in_native_fa)
static_data = static_img.get_fdata()
Expand Down Expand Up @@ -206,7 +206,7 @@ def main():
elif args.remove_invalid:
new_sft.remove_invalid_streamlines()
save_tractogram(new_sft, args.out_tractogram,
bbox_valid_check=not args.keep_invalid)
bbox_valid_check=args.bbox_check)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit bf59191

Please sign in to comment.