diff --git a/scilpy/io/streamlines.py b/scilpy/io/streamlines.py index ba40bf57d..6aa1b9ed2 100644 --- a/scilpy/io/streamlines.py +++ b/scilpy/io/streamlines.py @@ -82,11 +82,28 @@ def ichunk(sequence, n): def is_argument_set(args, arg_name): # Check that attribute is not None - return not getattr(args, 'reference', None) is None + return not getattr(args, arg_name, None) is None -def load_tractogram_with_reference(parser, args, filepath, - bbox_check=True, arg_name=None): +def load_tractogram_with_reference(parser, args, filepath, arg_name=None): + """ + Parameters + ---------- + parser: Argument Parser + Used to print errors, if any. + args: Namespace + Parsed arguments. Used to get the 'ref' and 'bbox_check' args. + See scilpy.io.utils to add the arguments to your parser. + filepath: str + Path of the tractogram file. + arg_name: str, optional + Name of the reference argument. By default the args.ref is used. If + arg_name is given, then args.arg_name_ref will be used instead. + """ + if is_argument_set(args, 'bbox_check'): + bbox_check = args.bbox_check + else: + bbox_check = True _, ext = os.path.splitext(filepath) if ext == '.trk': diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index 4e167ec08..2e082ceea 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -198,6 +198,15 @@ def add_verbose_arg(parser): help='If set, produces verbose output.') +def add_bbox_arg(parser): + parser.add_argument('--no_bbox_check', dest='bbox_check', + action='store_false', + help='Activate to ignore validity of the bounding ' + 'box during loading / saving of \n' + 'tractograms (ignores the presence of invalid ' + 'streamlines).') + + def add_sh_basis_args(parser, mandatory=False): """Add spherical harmonics (SH) bases argument. @@ -440,7 +449,8 @@ def verify_compatibility_with_reference_sft(ref_sft, files_to_verify, parser: argument parser Will raise an error if a file is not compatible. args: Namespace - Should contain a args.reference if any file is a .tck. + Should contain a args.reference if any file is a .tck, and possibly a + args.bbox_check (set to True by default). """ save_ref = args.reference @@ -455,8 +465,7 @@ def verify_compatibility_with_reference_sft(ref_sft, files_to_verify, args.reference = None else: args.reference = save_ref - mask = load_tractogram_with_reference(parser, args, file, - bbox_check=False) + mask = load_tractogram_with_reference(parser, args, file) else: # should be a nifti file. mask = file compatible = is_header_compatible(ref_sft, mask) diff --git a/scilpy/segment/tractogram_from_roi.py b/scilpy/segment/tractogram_from_roi.py index 6c1440476..3bd7028e4 100644 --- a/scilpy/segment/tractogram_from_roi.py +++ b/scilpy/segment/tractogram_from_roi.py @@ -41,9 +41,11 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False): gt_files: list List of either StatefulTractograms or niftis. parser: ArgumentParser - Argument parser which handles the script's arguments. + Argument parser which handles the script's arguments. Used to print + parser errors, if any. args: Namespace - List of arguments passed to the script. + List of arguments passed to the script. Used for its 'ref' and + 'bbox_check' arguments. inverse_mask: bool If true, returns the list of inversed masks instead. @@ -74,7 +76,7 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False): else: args.reference = save_ref gt_sft = load_tractogram_with_reference( - parser, args, gt_bundle, bbox_check=False) + parser, args, gt_bundle) gt_sft.to_vox() gt_sft.to_corner() _, dimensions, _, _ = gt_sft.space_attributes @@ -653,7 +655,7 @@ def segment_tractogram_from_roi( nc_sft = sft[all_nc_ids] if len(nc_sft) > 0 or not args.no_empty: save_tractogram(nc_sft, os.path.join( - args.out_dir, filename), bbox_valid_check=False) + args.out_dir, filename), bbox_valid_check=args.bbox_check) return (vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names, bundle_stats) diff --git a/scripts/scil_apply_transform_to_tractogram.py b/scripts/scil_apply_transform_to_tractogram.py index f27d9f706..2cf3ff38f 100755 --- a/scripts/scil_apply_transform_to_tractogram.py +++ b/scripts/scil_apply_transform_to_tractogram.py @@ -41,7 +41,8 @@ import numpy as np from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_bbox_arg, + add_overwrite_arg, add_reference_arg, add_verbose_arg, assert_inputs_exist, @@ -55,7 +56,9 @@ def _build_arg_parser(): description=__doc__) p.add_argument('in_moving_tractogram', - help='Path of the tractogram to be transformed.') + help='Path of the tractogram to be transformed.\n' + 'Bounding box validity will not be checked (could ' + 'contain invalid streamlines).') p.add_argument('in_target_file', help='Path of the reference target file (trk or nii).') p.add_argument('in_transfo', @@ -105,9 +108,9 @@ def main(): args.in_transfo], args.in_deformation) assert_outputs_exist(parser, args, args.out_tractogram) + args.bbox_check = False # Adding manually bbox_check argument. moving_sft = load_tractogram_with_reference(parser, args, - args.in_moving_tractogram, - bbox_check=False) + args.in_moving_tractogram) transfo = load_matrix_in_any_format(args.in_transfo) deformation_data = None diff --git a/scripts/scil_clean_qbx_clusters.py b/scripts/scil_clean_qbx_clusters.py index b9ec6c1ff..40c410f80 100755 --- a/scripts/scil_clean_qbx_clusters.py +++ b/scripts/scil_clean_qbx_clusters.py @@ -27,7 +27,8 @@ from dipy.io.utils import is_header_compatible from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_bbox_arg, + add_overwrite_arg, add_reference_arg, add_verbose_arg, assert_inputs_exist, @@ -65,6 +66,7 @@ def _build_arg_parser(): add_reference_arg(p) add_overwrite_arg(p) add_verbose_arg(p) + add_bbox_arg(p) return p @@ -175,12 +177,11 @@ def keypress_callback(obj, _): concat_streamlines = [] ref_bundle = load_tractogram_with_reference( - parser, args, args.in_bundles[0], bbox_check=False) + parser, args, args.in_bundles[0]) for filename in args.in_bundles: basename = os.path.basename(filename) - sft = load_tractogram_with_reference(parser, args, filename, - bbox_check=False) + sft = load_tractogram_with_reference(parser, args, filename) if not is_header_compatible(ref_bundle, sft): return if len(sft) >= args.min_cluster_size: @@ -240,32 +241,38 @@ def keypress_callback(obj, _): accepted_streamlines = save_clusters(sft_accepted_on_size, accepted_streamlines, args.out_accepted_dir, - filename_accepted_on_size) + filename_accepted_on_size, + args.bbox_check) accepted_sft = StatefulTractogram(accepted_streamlines, sft_accepted_on_size[0], Space.RASMM) - save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False) + save_tractogram(accepted_sft, args.out_accepted, + bbox_valid_check=args.bbox_check) # Save rejected clusters (by GUI) rejected_streamlines = save_clusters(sft_accepted_on_size, rejected_streamlines, args.out_rejected_dir, - filename_accepted_on_size) + filename_accepted_on_size, + args.bbox_check) # Save rejected clusters (by size) rejected_streamlines.extend(save_clusters(sft_rejected_on_size, range(len(sft_rejected_on_size)), args.out_rejected_dir, - filename_rejected_on_size)) + filename_rejected_on_size, + args.bbox_check)) rejected_sft = StatefulTractogram(rejected_streamlines, sft_accepted_on_size[0], Space.RASMM) - save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False) + save_tractogram(rejected_sft, args.out_rejected, + bbox_valid_check=args.bbox_check) -def save_clusters(cluster_lists, indexes_list, directory, basenames_list): +def save_clusters(cluster_lists, indexes_list, directory, basenames_list, + bbox_check): output_streamlines = [] for idx in indexes_list: streamlines = cluster_lists[idx].streamlines @@ -277,7 +284,8 @@ def save_clusters(cluster_lists, indexes_list, directory, basenames_list): Space.RASMM) tmp_filename = os.path.join(directory, basenames_list[idx]) - save_tractogram(tmp_sft, tmp_filename, bbox_valid_check=False) + save_tractogram(tmp_sft, tmp_filename, + bbox_valid_check=bbox_check) return output_streamlines diff --git a/scripts/scil_compute_seed_density_map.py b/scripts/scil_compute_seed_density_map.py index 7fa6dfc00..23f7b47cf 100755 --- a/scripts/scil_compute_seed_density_map.py +++ b/scripts/scil_compute_seed_density_map.py @@ -11,10 +11,11 @@ from nibabel import Nifti1Image from nibabel.streamlines import detect_format, TrkFile import numpy as np -from scilpy.io.utils import ( - add_overwrite_arg, - assert_inputs_exist, - assert_outputs_exist) + +from scilpy.io.utils import (add_bbox_arg, + add_overwrite_arg, + assert_inputs_exist, + assert_outputs_exist) def _build_arg_parser(): @@ -34,6 +35,7 @@ def _build_arg_parser(): 'without a value, 1 is used.\n If a value is given, ' 'will be used as the stored value.') add_overwrite_arg(p) + add_bbox_arg(p) return p @@ -57,9 +59,9 @@ def main(): .format(args.binary, max_)) # Load files and data. TRKs can have 'same' as reference - # Can handle streamlines outside of bbox + # Can handle streamlines outside of bbox, if asked by user. sft = load_tractogram(args.tractogram_filename, 'same', - bbox_valid_check=False) + bbox_valid_check=args.bbox_check) # IMPORTANT # Origin should be center when creating the seeds (see below, we diff --git a/scripts/scil_convert_tractogram.py b/scripts/scil_convert_tractogram.py index eada06b01..eec511b59 100755 --- a/scripts/scil_convert_tractogram.py +++ b/scripts/scil_convert_tractogram.py @@ -12,8 +12,9 @@ from dipy.io.streamline import save_tractogram from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, - assert_inputs_exist, assert_outputs_exist) +from scilpy.io.utils import (add_bbox_arg, add_overwrite_arg, + add_reference_arg, assert_inputs_exist, + assert_outputs_exist) def _build_arg_parser(): @@ -30,6 +31,7 @@ def _build_arg_parser(): add_reference_arg(p) add_overwrite_arg(p) + add_bbox_arg(p) return p @@ -48,9 +50,8 @@ def main(): assert_outputs_exist(parser, args, args.output_name) - sft = load_tractogram_with_reference(parser, args, args.in_tractogram, - bbox_check=False) - save_tractogram(sft, args.output_name, bbox_valid_check=False) + sft = load_tractogram_with_reference(parser, args, args.in_tractogram) + save_tractogram(sft, args.output_name, bbox_valid_check=args.bbox_check) if __name__ == "__main__": diff --git a/scripts/scil_decompose_connectivity.py b/scripts/scil_decompose_connectivity.py index 18a181eaa..e3e87017e 100755 --- a/scripts/scil_decompose_connectivity.py +++ b/scripts/scil_decompose_connectivity.py @@ -46,7 +46,8 @@ from scilpy.image.labels import get_data_as_labels from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_bbox_arg, + add_overwrite_arg, add_processes_arg, add_verbose_arg, add_reference_arg, @@ -228,6 +229,7 @@ def _build_arg_parser(): add_processes_arg(p) add_verbose_arg(p) add_overwrite_arg(p) + add_bbox_arg(p) return p @@ -274,9 +276,13 @@ def main(): logging.info('*** Loading streamlines ***') time1 = time.time() - sft = load_tractogram_with_reference(parser, args, args.in_tractogram, - bbox_check=False) - sft.remove_invalid_streamlines() + sft = load_tractogram_with_reference(parser, args, args.in_tractogram) + + # If loaded with invalid (bbox_check False), remove invalid streamlines + # before continuing. + if not args.bbox_check: + sft.remove_invalid_streamlines() + time2 = time.time() logging.info(' Loading {} streamlines took {} sec.'.format( len(sft), round(time2 - time1, 2))) diff --git a/scripts/scil_extract_ushape.py b/scripts/scil_extract_ushape.py index 553d6b0a0..e9d93e2b6 100755 --- a/scripts/scil_extract_ushape.py +++ b/scripts/scil_extract_ushape.py @@ -70,8 +70,7 @@ def main(): parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) + 'must be between -1 and 1.') - sft = load_tractogram_with_reference( - parser, args, args.in_tractogram) + sft = load_tractogram_with_reference(parser, args, args.in_tractogram) ids_c = detect_ushape(sft, args.minU, args.maxU) ids_l = np.setdiff1d(np.arange(len(sft.streamlines)), ids_c) diff --git a/scripts/scil_fix_dsi_studio_trk.py b/scripts/scil_fix_dsi_studio_trk.py index 413b098e0..73f902981 100755 --- a/scripts/scil_fix_dsi_studio_trk.py +++ b/scripts/scil_fix_dsi_studio_trk.py @@ -41,7 +41,8 @@ import nibabel as nib import numpy as np -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_bbox_arg, + add_overwrite_arg, assert_inputs_exist, assert_outputs_exist) from scilpy.utils.streamlines import (transform_warp_sft, @@ -82,10 +83,9 @@ def _build_arg_parser(): invalid.add_argument('--remove_invalid', action='store_true', help='Remove the streamlines landing out of the ' 'bounding box.') - invalid.add_argument('--keep_invalid', action='store_true', - help='Keep the streamlines landing out of the ' - 'bounding box.') + add_overwrite_arg(p) + add_bbox_arg(p) return p @@ -122,7 +122,7 @@ def main(): assert_outputs_exist(parser, args, args.out_tractogram) sft = load_tractogram(args.in_dsi_tractogram, 'same', - bbox_valid_check=False) + bbox_valid_check=args.bbox_check) # LPS -> RAS convention in voxel space sft.to_vox() @@ -143,7 +143,7 @@ def main(): elif args.remove_invalid: sft_flip.remove_invalid_streamlines() save_tractogram(sft_flip, args.out_tractogram, - bbox_valid_check=not args.keep_invalid) + bbox_valid_check=args.bbox_check) else: static_img = nib.load(args.in_native_fa) static_data = static_img.get_fdata() @@ -206,7 +206,7 @@ def main(): elif args.remove_invalid: new_sft.remove_invalid_streamlines() save_tractogram(new_sft, args.out_tractogram, - bbox_valid_check=not args.keep_invalid) + bbox_valid_check=args.bbox_check) if __name__ == "__main__": diff --git a/scripts/scil_remove_invalid_streamlines.py b/scripts/scil_remove_invalid_streamlines.py index 7233c7885..359b19bbe 100755 --- a/scripts/scil_remove_invalid_streamlines.py +++ b/scripts/scil_remove_invalid_streamlines.py @@ -18,8 +18,9 @@ import numpy as np from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, - assert_inputs_exist, assert_outputs_exist) +from scilpy.io.utils import (add_bbox_arg, add_overwrite_arg, + add_reference_arg, assert_inputs_exist, + assert_outputs_exist) from scilpy.utils.streamlines import cut_invalid_streamlines @@ -57,14 +58,17 @@ def main(): parser = _build_arg_parser() args = parser.parse_args() + # Equivalent of add_bbox_arg(p): always ignoring invalid streamlines for + # this script. + args.bbox_check = False + assert_inputs_exist(parser, args.in_tractogram, args.reference) assert_outputs_exist(parser, args, args.out_tractogram) if args.threshold < 0: parser.error("Threshold must be positive.") - sft = load_tractogram_with_reference(parser, args, args.in_tractogram, - bbox_check=False) + sft = load_tractogram_with_reference(parser, args, args.in_tractogram) ori_len = len(sft) if args.cut_invalid: sft, cutting_counter = cut_invalid_streamlines(sft) diff --git a/scripts/scil_score_bundles.py b/scripts/scil_score_bundles.py index 0163b46eb..d8bc72232 100755 --- a/scripts/scil_score_bundles.py +++ b/scripts/scil_score_bundles.py @@ -37,10 +37,12 @@ from dipy.io.streamline import load_tractogram -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_bbox_arg, + add_overwrite_arg, add_json_args, add_reference_arg, - add_verbose_arg, assert_inputs_exist, + add_verbose_arg, + assert_inputs_exist, assert_outputs_exist) from scilpy.segment.tractogram_from_roi import compute_masks_from_bundles from scilpy.tractanalysis.scoring import compute_tractometry @@ -71,20 +73,16 @@ def _build_arg_parser(): "gt_config.\nIf not set, filenames in the config " "file are considered \nas absolute paths.") - g = p.add_argument_group("Preprocessing") - g.add_argument("--ignore_invalid", action="store_true", - help="Ignore invalid streamlines in loaded tractograms.") - add_json_args(p) add_overwrite_arg(p) add_reference_arg(p) add_verbose_arg(p) + add_bbox_arg(p) return p def load_and_verify_everything(parser, args): - bbox_check = False if args.ignore_invalid else True assert_inputs_exist(parser, [args.gt_config]) if not os.path.isdir(args.bundles_dir): @@ -138,7 +136,8 @@ def load_and_verify_everything(parser, args): for bundle in bundle_names: vb_name = os.path.join(vb_path, bundle + '_VS.trk') if os.path.isfile(vb_name): - sft = load_tractogram(vb_name, 'same', bbox_valid_check=bbox_check) + sft = load_tractogram(vb_name, 'same', + bbox_valid_check=args.bbox_check) vb_sft_list.append(sft) if ref_sft is None: ref_sft = sft @@ -151,7 +150,8 @@ def load_and_verify_everything(parser, args): if wpc_path is not None: logging.info("Loading WPC bundles") for bundle in glob.glob(wpc_path + '/*'): - sft = load_tractogram(bundle, 'same', bbox_valid_check=bbox_check) + sft = load_tractogram(bundle, 'same', + bbox_valid_check=args.bbox_check) wpc_sft_list.append(sft) if ref_sft is None: ref_sft = sft @@ -165,7 +165,8 @@ def load_and_verify_everything(parser, args): logging.info("Loading invalid bundles") for bundle in glob.glob(ib_path + '/*'): ib_names.append(os.path.basename(bundle)) - sft = load_tractogram(bundle, 'same', bbox_valid_check=bbox_check) + sft = load_tractogram(bundle, 'same', + bbox_valid_check=args.bbox_check) ib_sft_list.append(ref_sft) if ref_sft is None: ref_sft = sft @@ -175,7 +176,7 @@ def load_and_verify_everything(parser, args): # Load either NC or IS if nc_filename is not None: nc_sft = load_tractogram(nc_filename, 'same', - bbox_valid_check=bbox_check) + bbox_valid_check=args.bbox_check) ref_sft = nc_sft else: nc_sft = None diff --git a/scripts/scil_score_tractogram.py b/scripts/scil_score_tractogram.py index 003346659..9ad281f07 100755 --- a/scripts/scil_score_tractogram.py +++ b/scripts/scil_score_tractogram.py @@ -79,10 +79,12 @@ from dipy.io.utils import is_header_compatible from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_bbox_arg, + add_overwrite_arg, add_json_args, add_reference_arg, - add_verbose_arg, assert_inputs_exist, + add_verbose_arg, + assert_inputs_exist, assert_output_dirs_exist_and_empty, verify_compatibility_with_reference_sft, assert_outputs_exist) @@ -155,6 +157,7 @@ def _build_arg_parser(): add_overwrite_arg(p) add_reference_arg(p) add_verbose_arg(p) + add_bbox_arg(p) return p @@ -198,8 +201,7 @@ def load_and_verify_everything(parser, args): list_masks_files_o) logging.info("Loading tractogram.") - sft = load_tractogram_with_reference( - parser, args, args.in_tractogram, bbox_check=False) + sft = load_tractogram_with_reference(parser, args, args.in_tractogram) _, dimensions, _, _ = sft.space_attributes if args.remove_invalid: @@ -428,18 +430,18 @@ def main(): filename = "segmented_VB/{}_VS.trk".format(bundle_names[i]) save_tractogram(vb_sft_list[i], os.path.join(args.out_dir, filename), - bbox_valid_check=False) + bbox_valid_check=args.bbox_check) if (args.save_wpc_separately and wpc_sft_list[i] is not None and (len(wpc_sft_list[i]) > 0 or not args.no_empty)): filename = "segmented_WPC/{}_wpc.trk".format(bundle_names[i]) save_tractogram(wpc_sft_list[i], os.path.join(args.out_dir, filename), - bbox_valid_check=False) + bbox_valid_check=args.bbox_check) for i in range(len(ib_sft_list)): if len(ib_sft_list[i]) > 0 or not args.no_empty: file = "segmented_IB/{}_IC.trk".format(ib_names[i]) save_tractogram(ib_sft_list[i], os.path.join(args.out_dir, file), - bbox_valid_check=False) + bbox_valid_check=args.bbox_check) # Tractometry on bundles final_results = compute_tractometry( diff --git a/scripts/scil_streamlines_math.py b/scripts/scil_streamlines_math.py index f42a53d83..26f99d766 100755 --- a/scripts/scil_streamlines_math.py +++ b/scripts/scil_streamlines_math.py @@ -49,7 +49,8 @@ import numpy as np from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_json_args, +from scilpy.io.utils import (add_bbox_arg, + add_json_args, add_overwrite_arg, add_reference_arg, add_verbose_arg, @@ -105,14 +106,11 @@ def _build_arg_parser(): help='Save the streamline indices to the supplied ' 'json file.') - p.add_argument('--ignore_invalid', action='store_true', - help='If set, does not crash because of invalid ' - 'streamlines.') - add_json_args(p) add_reference_arg(p) add_verbose_arg(p) add_overwrite_arg(p) + add_bbox_arg(p) return p @@ -176,8 +174,7 @@ def list_generator_from_nib(filenames): sft_list = [] for f in args.in_tractograms: logging.info("Loading file {}".format(f)) - sft_list.append(load_tractogram_with_reference( - parser, args, f, bbox_check=not args.ignore_invalid)) + sft_list.append(load_tractogram_with_reference(parser, args, f)) # Apply the requested operation to each input file. logging.info('Performing operation \'{}\'.'.format(args.operation)) @@ -217,7 +214,7 @@ def list_generator_from_nib(filenames): logging.info('Saving {} streamlines to {}.'.format(len(indices), args.out_tractogram)) save_tractogram(new_sft[indices], args.out_tractogram, - bbox_valid_check=not args.ignore_invalid) + bbox_valid_check=args.bbox_check) if __name__ == "__main__": diff --git a/scripts/tests/test_score_bundles.py b/scripts/tests/test_score_bundles.py index 5990ce89f..2396ddd97 100644 --- a/scripts/tests/test_score_bundles.py +++ b/scripts/tests/test_score_bundles.py @@ -43,6 +43,6 @@ def test_score_bundles(script_runner): json.dump(json_contents, f) ret = script_runner.run('scil_score_bundles.py', - "config_file.json", "./", '--ignore_invalid') + "config_file.json", "./", '--no_bbox_check') assert ret.success