Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check_bbox argument management #641

Merged
merged 6 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions scilpy/io/streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,28 @@ def ichunk(sequence, n):

def is_argument_set(args, arg_name):
# Check that attribute is not None
return not getattr(args, 'reference', None) is None
return not getattr(args, arg_name, None) is None


def load_tractogram_with_reference(parser, args, filepath,
bbox_check=True, arg_name=None):
def load_tractogram_with_reference(parser, args, filepath, arg_name=None):
"""
Parameters
----------
parser: Argument Parser
Used to print errors, if any.
args: Namespace
Parsed arguments. Used to get the 'ref' and 'bbox_check' args.
See scilpy.io.utils to add the arguments to your parser.
filepath: str
Path of the tractogram file.
arg_name: str, optional
Name of the reference argument. By default the args.ref is used. If
arg_name is given, then args.arg_name_ref will be used instead.
"""
if is_argument_set(args, 'bbox_check'):
bbox_check = args.bbox_check
else:
bbox_check = True

_, ext = os.path.splitext(filepath)
if ext == '.trk':
Expand Down
15 changes: 12 additions & 3 deletions scilpy/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ def add_verbose_arg(parser):
help='If set, produces verbose output.')


def add_bbox_arg(parser):
parser.add_argument('--no_bbox_check', dest='bbox_check',
action='store_false',
help='Activate to ignore validity of the bounding '
'box during loading / saving of \n'
'tractograms (ignores the presence of invalid '
'streamlines).')


def add_sh_basis_args(parser, mandatory=False):
"""Add spherical harmonics (SH) bases argument.

Expand Down Expand Up @@ -440,7 +449,8 @@ def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
parser: argument parser
Will raise an error if a file is not compatible.
args: Namespace
Should contain a args.reference if any file is a .tck.
Should contain a args.reference if any file is a .tck, and possibly a
args.bbox_check (set to True by default).
"""
save_ref = args.reference

Expand All @@ -455,8 +465,7 @@ def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
args.reference = None
else:
args.reference = save_ref
mask = load_tractogram_with_reference(parser, args, file,
bbox_check=False)
mask = load_tractogram_with_reference(parser, args, file)
else: # should be a nifti file.
mask = file
compatible = is_header_compatible(ref_sft, mask)
Expand Down
10 changes: 6 additions & 4 deletions scilpy/segment/tractogram_from_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False):
gt_files: list
List of either StatefulTractograms or niftis.
parser: ArgumentParser
Argument parser which handles the script's arguments.
Argument parser which handles the script's arguments. Used to print
parser errors, if any.
args: Namespace
List of arguments passed to the script.
List of arguments passed to the script. Used for its 'ref' and
'bbox_check' arguments.
inverse_mask: bool
If true, returns the list of inversed masks instead.

Expand Down Expand Up @@ -74,7 +76,7 @@ def compute_masks_from_bundles(gt_files, parser, args, inverse_mask=False):
else:
args.reference = save_ref
gt_sft = load_tractogram_with_reference(
parser, args, gt_bundle, bbox_check=False)
parser, args, gt_bundle)
gt_sft.to_vox()
gt_sft.to_corner()
_, dimensions, _, _ = gt_sft.space_attributes
Expand Down Expand Up @@ -653,7 +655,7 @@ def segment_tractogram_from_roi(
nc_sft = sft[all_nc_ids]
if len(nc_sft) > 0 or not args.no_empty:
save_tractogram(nc_sft, os.path.join(
args.out_dir, filename), bbox_valid_check=False)
args.out_dir, filename), bbox_valid_check=args.bbox_check)

return (vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names,
bundle_stats)
11 changes: 7 additions & 4 deletions scripts/scil_apply_transform_to_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
import numpy as np

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
add_reference_arg,
add_verbose_arg,
assert_inputs_exist,
Expand All @@ -55,7 +56,9 @@ def _build_arg_parser():
description=__doc__)

p.add_argument('in_moving_tractogram',
help='Path of the tractogram to be transformed.')
help='Path of the tractogram to be transformed.\n'
'Bounding box validity will not be checked (could '
'contain invalid streamlines).')
p.add_argument('in_target_file',
help='Path of the reference target file (trk or nii).')
p.add_argument('in_transfo',
Expand Down Expand Up @@ -105,9 +108,9 @@ def main():
args.in_transfo], args.in_deformation)
assert_outputs_exist(parser, args, args.out_tractogram)

args.bbox_check = False # Adding manually bbox_check argument.
moving_sft = load_tractogram_with_reference(parser, args,
args.in_moving_tractogram,
bbox_check=False)
args.in_moving_tractogram)

transfo = load_matrix_in_any_format(args.in_transfo)
deformation_data = None
Expand Down
30 changes: 19 additions & 11 deletions scripts/scil_clean_qbx_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from dipy.io.utils import is_header_compatible

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
add_reference_arg,
add_verbose_arg,
assert_inputs_exist,
Expand Down Expand Up @@ -65,6 +66,7 @@ def _build_arg_parser():
add_reference_arg(p)
add_overwrite_arg(p)
add_verbose_arg(p)
add_bbox_arg(p)

return p

Expand Down Expand Up @@ -175,12 +177,11 @@ def keypress_callback(obj, _):
concat_streamlines = []

ref_bundle = load_tractogram_with_reference(
parser, args, args.in_bundles[0], bbox_check=False)
parser, args, args.in_bundles[0])

for filename in args.in_bundles:
basename = os.path.basename(filename)
sft = load_tractogram_with_reference(parser, args, filename,
bbox_check=False)
sft = load_tractogram_with_reference(parser, args, filename)
if not is_header_compatible(ref_bundle, sft):
return
if len(sft) >= args.min_cluster_size:
Expand Down Expand Up @@ -240,32 +241,38 @@ def keypress_callback(obj, _):
accepted_streamlines = save_clusters(sft_accepted_on_size,
accepted_streamlines,
args.out_accepted_dir,
filename_accepted_on_size)
filename_accepted_on_size,
args.bbox_check)

accepted_sft = StatefulTractogram(accepted_streamlines,
sft_accepted_on_size[0],
Space.RASMM)
save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)
save_tractogram(accepted_sft, args.out_accepted,
bbox_valid_check=args.bbox_check)

# Save rejected clusters (by GUI)
rejected_streamlines = save_clusters(sft_accepted_on_size,
rejected_streamlines,
args.out_rejected_dir,
filename_accepted_on_size)
filename_accepted_on_size,
args.bbox_check)

# Save rejected clusters (by size)
rejected_streamlines.extend(save_clusters(sft_rejected_on_size,
range(len(sft_rejected_on_size)),
args.out_rejected_dir,
filename_rejected_on_size))
filename_rejected_on_size,
args.bbox_check))

rejected_sft = StatefulTractogram(rejected_streamlines,
sft_accepted_on_size[0],
Space.RASMM)
save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
save_tractogram(rejected_sft, args.out_rejected,
bbox_valid_check=args.bbox_check)


def save_clusters(cluster_lists, indexes_list, directory, basenames_list):
def save_clusters(cluster_lists, indexes_list, directory, basenames_list,
bbox_check):
output_streamlines = []
for idx in indexes_list:
streamlines = cluster_lists[idx].streamlines
Expand All @@ -277,7 +284,8 @@ def save_clusters(cluster_lists, indexes_list, directory, basenames_list):
Space.RASMM)
tmp_filename = os.path.join(directory,
basenames_list[idx])
save_tractogram(tmp_sft, tmp_filename, bbox_valid_check=False)
save_tractogram(tmp_sft, tmp_filename,
bbox_valid_check=bbox_check)

return output_streamlines

Expand Down
14 changes: 8 additions & 6 deletions scripts/scil_compute_seed_density_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from nibabel import Nifti1Image
from nibabel.streamlines import detect_format, TrkFile
import numpy as np
from scilpy.io.utils import (
add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)

from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)


def _build_arg_parser():
Expand All @@ -34,6 +35,7 @@ def _build_arg_parser():
'without a value, 1 is used.\n If a value is given, '
'will be used as the stored value.')
add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand All @@ -57,9 +59,9 @@ def main():
.format(args.binary, max_))

# Load files and data. TRKs can have 'same' as reference
# Can handle streamlines outside of bbox
# Can handle streamlines outside of bbox, if asked by user.
sft = load_tractogram(args.tractogram_filename, 'same',
bbox_valid_check=False)
bbox_valid_check=args.bbox_check)

# IMPORTANT
# Origin should be center when creating the seeds (see below, we
Expand Down
11 changes: 6 additions & 5 deletions scripts/scil_convert_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from dipy.io.streamline import save_tractogram

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg, add_reference_arg,
assert_inputs_exist, assert_outputs_exist)
from scilpy.io.utils import (add_bbox_arg, add_overwrite_arg,
add_reference_arg, assert_inputs_exist,
assert_outputs_exist)


def _build_arg_parser():
Expand All @@ -30,6 +31,7 @@ def _build_arg_parser():

add_reference_arg(p)
add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand All @@ -48,9 +50,8 @@ def main():

assert_outputs_exist(parser, args, args.output_name)

sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
bbox_check=False)
save_tractogram(sft, args.output_name, bbox_valid_check=False)
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
save_tractogram(sft, args.output_name, bbox_valid_check=args.bbox_check)


if __name__ == "__main__":
Expand Down
14 changes: 10 additions & 4 deletions scripts/scil_decompose_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@

from scilpy.image.labels import get_data_as_labels
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
add_processes_arg,
add_verbose_arg,
add_reference_arg,
Expand Down Expand Up @@ -228,6 +229,7 @@ def _build_arg_parser():
add_processes_arg(p)
add_verbose_arg(p)
add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand Down Expand Up @@ -274,9 +276,13 @@ def main():

logging.info('*** Loading streamlines ***')
time1 = time.time()
sft = load_tractogram_with_reference(parser, args, args.in_tractogram,
bbox_check=False)
sft.remove_invalid_streamlines()
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

# If loaded with invalid (bbox_check False), remove invalid streamlines
# before continuing.
if not args.bbox_check:
sft.remove_invalid_streamlines()

time2 = time.time()
logging.info(' Loading {} streamlines took {} sec.'.format(
len(sft), round(time2 - time1, 2)))
Expand Down
3 changes: 1 addition & 2 deletions scripts/scil_extract_ushape.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def main():
parser.error('Min-Max ufactor "{},{}" '.format(args.minU, args.maxU) +
'must be between -1 and 1.')

sft = load_tractogram_with_reference(
parser, args, args.in_tractogram)
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing add_bbox_check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bbox_check=False was not there to start with. I just removed the unnecessary newline.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean with "no distinction between load and save" @arnaudbore


ids_c = detect_ushape(sft, args.minU, args.maxU)
ids_l = np.setdiff1d(np.arange(len(sft.streamlines)), ids_c)
Expand Down
14 changes: 7 additions & 7 deletions scripts/scil_fix_dsi_studio_trk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
import nibabel as nib
import numpy as np

from scilpy.io.utils import (add_overwrite_arg,
from scilpy.io.utils import (add_bbox_arg,
add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)
from scilpy.utils.streamlines import (transform_warp_sft,
Expand Down Expand Up @@ -82,10 +83,9 @@ def _build_arg_parser():
invalid.add_argument('--remove_invalid', action='store_true',
help='Remove the streamlines landing out of the '
'bounding box.')
invalid.add_argument('--keep_invalid', action='store_true',
help='Keep the streamlines landing out of the '
'bounding box.')

add_overwrite_arg(p)
add_bbox_arg(p)

return p

Expand Down Expand Up @@ -122,7 +122,7 @@ def main():
assert_outputs_exist(parser, args, args.out_tractogram)

sft = load_tractogram(args.in_dsi_tractogram, 'same',
bbox_valid_check=False)
bbox_valid_check=args.bbox_check)

# LPS -> RAS convention in voxel space
sft.to_vox()
Expand All @@ -143,7 +143,7 @@ def main():
elif args.remove_invalid:
sft_flip.remove_invalid_streamlines()
save_tractogram(sft_flip, args.out_tractogram,
bbox_valid_check=not args.keep_invalid)
bbox_valid_check=args.bbox_check)
else:
static_img = nib.load(args.in_native_fa)
static_data = static_img.get_fdata()
Expand Down Expand Up @@ -206,7 +206,7 @@ def main():
elif args.remove_invalid:
new_sft.remove_invalid_streamlines()
save_tractogram(new_sft, args.out_tractogram,
bbox_valid_check=not args.keep_invalid)
bbox_valid_check=args.bbox_check)


if __name__ == "__main__":
Expand Down
Loading