From 5a1c9071215e237d587f34a5c9ca7bb95a194bb9 Mon Sep 17 00:00:00 2001 From: frheault Date: Tue, 10 Oct 2023 21:19:22 -0400 Subject: [PATCH 1/5] New test to match Dipy potential issue with sft copy --- scripts/tests/test_workflows.py | 4 +- trx/io.py | 2 +- trx/tests/test_io.py | 69 ++++++++++++++++++++------------- trx/tests/test_memmap.py | 8 ++-- trx/trx_file_memmap.py | 12 +++--- trx/workflows.py | 4 +- 6 files changed, 57 insertions(+), 42 deletions(-) diff --git a/scripts/tests/test_workflows.py b/scripts/tests/test_workflows.py index 740fdd2..7640886 100644 --- a/scripts/tests/test_workflows.py +++ b/scripts/tests/test_workflows.py @@ -15,7 +15,7 @@ from trx.fetcher import (get_testing_files_dict, fetch_data, get_home) -from trx.io import get_trx_tmpdir +from trx.io import get_trx_tmp_dir import trx.trx_file_memmap as tmm from trx.workflows import (convert_dsi_studio, convert_tractogram, @@ -26,7 +26,7 @@ # If they already exist, this only takes 5 seconds (check md5sum) fetch_data(get_testing_files_dict(), keys=['DSI.zip', 'trx_from_scratch.zip']) -tmp_dir = get_trx_tmpdir() +tmp_dir = get_trx_tmp_dir() def test_help_option_convert_dsi(script_runner): diff --git a/trx/io.py b/trx/io.py index da2e8eb..86ecc27 100644 --- a/trx/io.py +++ b/trx/io.py @@ -15,7 +15,7 @@ from trx.utils import split_name_with_gz -def get_trx_tmpdir(): +def get_trx_tmp_dir(): if os.getenv('TRX_TMPDIR') is not None: if os.getenv('TRX_TMPDIR') == 'use_working_dir': trx_tmp_dir = os.getcwd() diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py index f9628cf..889bba2 100644 --- a/trx/tests/test_io.py +++ b/trx/tests/test_io.py @@ -11,30 +11,45 @@ try: import dipy + from dipy.io.streamline import save_tractogram dipy_available = True except ImportError: dipy_available = False import trx.trx_file_memmap as tmm from trx.trx_file_memmap import TrxFile -from trx.io import load, save, get_trx_tmpdir +from trx.io import load, save, get_trx_tmp_dir from trx.fetcher import (get_testing_files_dict, fetch_data, get_home) fetch_data(get_testing_files_dict(), keys=['gold_standard.zip']) -tmp_dir = get_trx_tmpdir() +tmp_gs_dir = get_trx_tmp_dir() + + +@pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"), + ("gs.vtk")]) +@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +def test_load_vox(path): + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, path) + + obj = load(path, os.path.join(gs_dir, 'gs.nii')) + sft = obj.to_sft() + save_tractogram(sft, path) + obj.close() + save_tractogram(sft, path) @pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"), ("gs.vtk")]) @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') def test_load_vox(path): - dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(dir, path) + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, path) coord = np.loadtxt(os.path.join(get_home(), 'gold_standard', 'gs_vox_space.txt')) - obj = load(path, os.path.join(dir, 'gs.nii')) + obj = load(path, os.path.join(gs_dir, 'gs.nii')) sft = obj.to_sft() if isinstance(obj, TrxFile) else obj sft.to_vox() @@ -48,11 +63,11 @@ def test_load_vox(path): ("gs.vtk")]) @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') def test_load_voxmm(path): - dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(dir, path) + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, path) coord = np.loadtxt(os.path.join(get_home(), 'gold_standard', 'gs_voxmm_space.txt')) - obj = load(path, os.path.join(dir, 'gs.nii')) + obj = load(path, os.path.join(gs_dir, 'gs.nii')) sft = obj.to_sft() if isinstance(obj, TrxFile) else obj sft.to_voxmm() @@ -65,33 +80,33 @@ def test_load_voxmm(path): @pytest.mark.parametrize("path", [("gs.trk"), ("gs.trx"), ("gs_fldr.trx")]) @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') def test_multi_load_save_rasmm(path): - dir = os.path.join(get_home(), 'gold_standard') + gs_dir = os.path.join(get_home(), 'gold_standard') basename, ext = os.path.splitext(path) - out_path = os.path.join(tmp_dir.name, '{}_tmp{}'.format(basename, ext)) - path = os.path.join(dir, path) + out_path = os.path.join(tmp_gs_dir.name, '{}_tmp{}'.format(basename, ext)) + path = os.path.join(gs_dir, path) coord = np.loadtxt(os.path.join(get_home(), 'gold_standard', 'gs_rasmm_space.txt')) - obj = load(path, os.path.join(dir, 'gs.nii')) + obj = load(path, os.path.join(gs_dir, 'gs.nii')) for _ in range(100): save(obj, out_path) if isinstance(obj, TrxFile): obj.close() - obj = load(out_path, os.path.join(dir, 'gs.nii')) + obj = load(out_path, os.path.join(gs_dir, 'gs.nii')) assert_allclose(obj.streamlines._data, coord, rtol=1e-04, atol=1e-06) @pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")]) @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') -def test_delete_tmp_dir(path): - dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(dir, path) +def test_delete_tmp_gs_dir(path): + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, path) trx1 = tmm.load(path) if os.path.isfile(path): - tmp_dir = deepcopy(trx1._uncompressed_folder_handle.name) - assert os.path.isdir(tmp_dir) + tmp_gs_dir = deepcopy(trx1._uncompressed_folder_handle.name) + assert os.path.isdir(tmp_gs_dir) sft = trx1.to_sft() trx1.close() @@ -102,7 +117,7 @@ def test_delete_tmp_dir(path): # The folder trx representation does not need tmp files if os.path.isfile(path): - assert not os.path.isdir(tmp_dir) + assert not os.path.isdir(tmp_gs_dir) assert_allclose(sft.streamlines._data, coord_rasmm, rtol=1e-04, atol=1e-06) @@ -124,8 +139,8 @@ def test_delete_tmp_dir(path): @pytest.mark.parametrize("path", [("gs.trx")]) @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') def test_close_tmp_files(path): - dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(dir, path) + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, path) trx = tmm.load(path) process = psutil.Process(os.getpid()) @@ -149,8 +164,8 @@ def test_close_tmp_files(path): @pytest.mark.parametrize("tmp_path", [("~"), ("use_working_dir")]) def test_change_tmp_dir(tmp_path): - dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(dir, 'gs.trx') + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, 'gs.trx') if tmp_path == 'use_working_dir': os.environ['TRX_TMPDIR'] = 'use_working_dir' @@ -158,12 +173,12 @@ def test_change_tmp_dir(tmp_path): os.environ['TRX_TMPDIR'] = os.path.expanduser(tmp_path) trx = tmm.load(path) - tmp_dir = deepcopy(trx._uncompressed_folder_handle.name) + tmp_gs_dir = deepcopy(trx._uncompressed_folder_handle.name) if tmp_path == 'use_working_dir': - assert os.path.dirname(tmp_dir) == os.getcwd() + assert os.path.dirname(tmp_gs_dir) == os.getcwd() else: - assert os.path.dirname(tmp_dir) == os.path.expanduser(tmp_path) + assert os.path.dirname(tmp_gs_dir) == os.path.expanduser(tmp_path) trx.close() - assert not os.path.isdir(tmp_dir) + assert not os.path.isdir(tmp_gs_dir) diff --git a/trx/tests/test_memmap.py b/trx/tests/test_memmap.py index a0f2703..7ae189d 100644 --- a/trx/tests/test_memmap.py +++ b/trx/tests/test_memmap.py @@ -14,14 +14,14 @@ except ImportError: dipy_available = False -from trx.io import get_trx_tmpdir +from trx.io import get_trx_tmp_dir import trx.trx_file_memmap as tmm from trx.fetcher import (get_testing_files_dict, fetch_data, get_home) fetch_data(get_testing_files_dict(), keys=['memmap_test_data.zip']) -tmp_dir = get_trx_tmpdir() +tmp_dir = get_trx_tmp_dir() @pytest.mark.parametrize( @@ -129,7 +129,7 @@ def test__dichotomic_search(arr, l_bound, r_bound, expected): def test__create_memmap(basename, create, expected): if create: # Need to create array before evaluating - with get_trx_tmpdir() as dirname: + with get_trx_tmp_dir() as dirname: filename = os.path.join(dirname, basename) fp = np.memmap(filename, dtype=np.int16, mode="w+", shape=(3, 4)) fp[:] = expected[:] @@ -138,7 +138,7 @@ def test__create_memmap(basename, create, expected): assert np.array_equal(mmarr, expected) else: - with get_trx_tmpdir() as dirname: + with get_trx_tmp_dir() as dirname: filename = os.path.join(dirname, basename) mmarr = tmm._create_memmap(filename=filename, shape=(0,), dtype=np.int16) diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py index b0775aa..df9e9b7 100644 --- a/trx/trx_file_memmap.py +++ b/trx/trx_file_memmap.py @@ -18,7 +18,7 @@ from nibabel.streamlines.tractogram import Tractogram, LazyTractogram import numpy as np -from trx.io import get_trx_tmpdir +from trx.io import get_trx_tmp_dir from trx.utils import (append_generator_to_dict, close_or_delete_mmap, convert_data_dict_to_tractogram, @@ -229,7 +229,7 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]: break if was_compressed: with zipfile.ZipFile(input_obj, "r") as zf: - tmpdir = get_trx_tmpdir() + tmpdir = get_trx_tmp_dir() zf.extractall(tmpdir.name) trx = load_from_directory(tmpdir.name) trx._uncompressed_folder_handle = tmpdir @@ -740,7 +740,7 @@ def deepcopy(self) -> Type["TrxFile"]: Returns A deepcopied TrxFile of the current TrxFile """ - tmp_dir = get_trx_tmpdir() + tmp_dir = get_trx_tmp_dir() out_json = open(os.path.join(tmp_dir.name, "header.json"), "w") tmp_header = deepcopy(self.header) @@ -917,7 +917,7 @@ def _initialize_empty_trx( An empty TrxFile preallocated with a certain size """ trx = TrxFile() - tmp_dir = get_trx_tmpdir() + tmp_dir = get_trx_tmp_dir() logging.info("Temporary folder for memmaps: {}".format(tmp_dir.name)) trx.header["NB_VERTICES"] = nb_vertices @@ -1582,7 +1582,7 @@ def from_sft(sft, dtype_dict={}): dtype_to_use) # For safety and for RAM, convert the whole object to memmaps - tmpdir = get_trx_tmpdir() + tmpdir = get_trx_tmp_dir() save(trx, tmpdir.name) trx = load_from_directory(tmpdir.name) trx._uncompressed_folder_handle = tmpdir @@ -1651,7 +1651,7 @@ def from_tractogram(tractogram, reference, tractogram.data_per_streamline[key].astype(dtype_to_use) # For safety and for RAM, convert the whole object to memmaps - tmpdir = get_trx_tmpdir() + tmpdir = get_trx_tmp_dir() save(trx, tmpdir.name) trx = load_from_directory(tmpdir.name) trx._uncompressed_folder_handle = tmpdir diff --git a/trx/workflows.py b/trx/workflows.py index 60f06ad..eb3534f 100644 --- a/trx/workflows.py +++ b/trx/workflows.py @@ -18,7 +18,7 @@ except ImportError: dipy_available = False -from trx.io import get_trx_tmpdir, load, load_sft_with_reference, save +from trx.io import get_trx_tmp_dir, load, load_sft_with_reference, save from trx.streamlines_ops import perform_streamlines_operation, intersection import trx.trx_file_memmap as tmm from trx.viz import display @@ -311,7 +311,7 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, space_str='rasmm', origin_str='nifti', verify_invalid=True, dpv=[], dps=[], groups=[], dpg=[]): - with get_trx_tmpdir() as tmpdirname: + with get_trx_tmp_dir() as tmpdirname: if positions_csv: with open(positions_csv, newline='') as f: reader = csv.reader(f) From bf73b91d88bc0b729ec1a4c47b2202b1c0688ebb Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 11 Oct 2023 12:06:07 -0400 Subject: [PATCH 2/5] Investiguate missing header --- trx/trx_file_memmap.py | 36 ++++++++++++++++-------------------- trx/utils.py | 2 +- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py index df9e9b7..a238ca8 100644 --- a/trx/trx_file_memmap.py +++ b/trx/trx_file_memmap.py @@ -229,10 +229,10 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]: break if was_compressed: with zipfile.ZipFile(input_obj, "r") as zf: - tmpdir = get_trx_tmp_dir() - zf.extractall(tmpdir.name) - trx = load_from_directory(tmpdir.name) - trx._uncompressed_folder_handle = tmpdir + tmp_dir = get_trx_tmp_dir() + zf.extractall(tmp_dir.name) + trx = load_from_directory(tmp_dir.name) + trx._uncompressed_folder_handle = tmp_dir logging.info( "File was compressed, call the close() function before" "exiting." @@ -542,20 +542,15 @@ def save( compression_standard -- The compression standard to use, as defined by the ZipFile library """ - if os.path.splitext(filename)[1] and not os.path.splitext(filename)[1] in [ - ".zip", - ".trx", - ]: + _, ext = os.path.splitext(filename) + if ext not in [".zip", ".trx", ""]: raise ValueError("Unsupported extension.") copy_trx = trx.deepcopy() copy_trx.resize() tmp_dir_name = copy_trx._uncompressed_folder_handle.name - if os.path.splitext(filename)[1] and os.path.splitext(filename)[1] in [ - ".zip", - ".trx", - ]: + if ext in [".zip", ".trx"]: zip_from_folder(tmp_dir_name, filename, compression_standard) else: if os.path.isdir(filename): @@ -1582,10 +1577,10 @@ def from_sft(sft, dtype_dict={}): dtype_to_use) # For safety and for RAM, convert the whole object to memmaps - tmpdir = get_trx_tmp_dir() - save(trx, tmpdir.name) - trx = load_from_directory(tmpdir.name) - trx._uncompressed_folder_handle = tmpdir + tmp_dir = get_trx_tmp_dir() + save(trx, tmp_dir.name) + trx.close() + trx = load_from_directory(tmp_dir.name) sft.to_space(old_space) sft.to_origin(old_origin) @@ -1651,10 +1646,11 @@ def from_tractogram(tractogram, reference, tractogram.data_per_streamline[key].astype(dtype_to_use) # For safety and for RAM, convert the whole object to memmaps - tmpdir = get_trx_tmp_dir() - save(trx, tmpdir.name) - trx = load_from_directory(tmpdir.name) - trx._uncompressed_folder_handle = tmpdir + tmp_dir = get_trx_tmp_dir() + save(trx, tmp_dir.name) + trx.close() + + trx = load_from_directory(tmp_dir.name) del tmp_streamlines return trx diff --git a/trx/utils.py b/trx/utils.py index 9d1e378..e07db4b 100644 --- a/trx/utils.py +++ b/trx/utils.py @@ -36,7 +36,7 @@ def close_or_delete_mmap(obj): elif isinstance(obj, np.memmap): del obj else: - logging.warning('Object to be close or deleted must be np.memmap') + logging.debug('Object to be close or deleted must be np.memmap') def split_name_with_gz(filename): From af4e6c1aad9668e3416b219efff55ecc515b83d9 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 11 Oct 2023 12:08:59 -0400 Subject: [PATCH 3/5] tmp_dir --- trx/workflows.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/trx/workflows.py b/trx/workflows.py index eb3534f..11cfcd6 100644 --- a/trx/workflows.py +++ b/trx/workflows.py @@ -311,7 +311,7 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, space_str='rasmm', origin_str='nifti', verify_invalid=True, dpv=[], dps=[], groups=[], dpg=[]): - with get_trx_tmp_dir() as tmpdirname: + with get_trx_tmp_dir() as tmp_dir_name: if positions_csv: with open(positions_csv, newline='') as f: reader = csv.reader(f) @@ -360,20 +360,20 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, raise IOError('To use this script, you need at least 2' 'streamlines.') - with open(os.path.join(tmpdirname, "header.json"), "w") as out_json: + with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json: json.dump(header, out_json) - curr_filename = os.path.join(tmpdirname, 'positions.3.{}'.format( + curr_filename = os.path.join(tmp_dir_name, 'positions.3.{}'.format( positions_dtype)) streamlines._data.astype(positions_dtype).tofile( curr_filename) - curr_filename = os.path.join(tmpdirname, 'offsets.{}'.format( + curr_filename = os.path.join(tmp_dir_name, 'offsets.{}'.format( offsets_dtype)) streamlines._offsets.astype(offsets_dtype).tofile( curr_filename) if dpv: - os.mkdir(os.path.join(tmpdirname, 'dpv')) + os.mkdir(os.path.join(tmp_dir_name, 'dpv')) for arg in dpv: curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype( arg[1])) @@ -383,12 +383,12 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') dim = '' if curr_arr.ndim == 1 else '{}.'.format( curr_arr.shape[-1]) - curr_filename = os.path.join(tmpdirname, 'dpv', '{}.{}{}'.format( + curr_filename = os.path.join(tmp_dir_name, 'dpv', '{}.{}{}'.format( os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1])) curr_arr.tofile(curr_filename) if dps: - os.mkdir(os.path.join(tmpdirname, 'dps')) + os.mkdir(os.path.join(tmp_dir_name, 'dps')) for arg in dps: curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype( arg[1])) @@ -398,12 +398,12 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') dim = '' if curr_arr.ndim == 1 else '{}.'.format( curr_arr.shape[-1]) - curr_filename = os.path.join(tmpdirname, 'dps', '{}.{}{}'.format( + curr_filename = os.path.join(tmp_dir_name, 'dps', '{}.{}{}'.format( os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1])) curr_arr.tofile(curr_filename) if groups: - os.mkdir(os.path.join(tmpdirname, 'groups')) + os.mkdir(os.path.join(tmp_dir_name, 'groups')) for arg in groups: curr_arr = load_matrix_in_any_format(arg[0]).astype(arg[1]) if arg[1] == 'bool': @@ -412,15 +412,15 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.') dim = '' if curr_arr.ndim == 1 else '{}.'.format( curr_arr.shape[-1]) - curr_filename = os.path.join(tmpdirname, 'groups', '{}.{}{}'.format( + curr_filename = os.path.join(tmp_dir_name, 'groups', '{}.{}{}'.format( os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1])) curr_arr.tofile(curr_filename) if dpg: - os.mkdir(os.path.join(tmpdirname, 'dpg')) + os.mkdir(os.path.join(tmp_dir_name, 'dpg')) for arg in dpg: - if not os.path.isdir(os.path.join(tmpdirname, 'dpg', arg[0])): - os.mkdir(os.path.join(tmpdirname, 'dpg', arg[0])) + if not os.path.isdir(os.path.join(tmp_dir_name, 'dpg', arg[0])): + os.mkdir(os.path.join(tmp_dir_name, 'dpg', arg[0])) curr_arr = load_matrix_in_any_format(arg[1]).astype(arg[2]) if arg[1] == 'bool': arg[1] = 'bit' @@ -430,11 +430,11 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False, curr_arr = curr_arr.reshape((1,)) dim = '' if curr_arr.ndim == 1 else '{}.'.format( curr_arr.shape[-1]) - curr_filename = os.path.join(tmpdirname, 'dpg', arg[0], '{}.{}{}'.format( + curr_filename = os.path.join(tmp_dir_name, 'dpg', arg[0], '{}.{}{}'.format( os.path.basename(os.path.splitext(arg[1])[0]), dim, arg[2])) curr_arr.tofile(curr_filename) - trx = tmm.load(tmpdirname) + trx = tmm.load(tmp_dir_name) tmm.save(trx, out_tractogram) trx.close() From 7c227829b15e11f6c2eb994598bdd1e29f09e4c0 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 11 Oct 2023 12:41:14 -0400 Subject: [PATCH 4/5] Extra tests and cleaning --- trx/streamlines_ops.py | 2 +- trx/tests/test_io.py | 46 +++++++++++++++++++++++++++++++++++++++--- trx/trx_file_memmap.py | 4 ++-- trx/utils.py | 2 +- 4 files changed, 47 insertions(+), 7 deletions(-) diff --git a/trx/streamlines_ops.py b/trx/streamlines_ops.py index bb291e1..6057d2d 100644 --- a/trx/streamlines_ops.py +++ b/trx/streamlines_ops.py @@ -29,7 +29,7 @@ def union(left, right): def get_streamline_key(streamline, precision=None): """Produces a key using a hash from a streamline using a few points only and - the desired precision + the desired precision Parameters ---------- diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py index 889bba2..18c4473 100644 --- a/trx/tests/test_io.py +++ b/trx/tests/test_io.py @@ -4,13 +4,13 @@ from copy import deepcopy import os import psutil +import zipfile import pytest import numpy as np from numpy.testing import assert_allclose try: - import dipy from dipy.io.streamline import save_tractogram dipy_available = True except ImportError: @@ -30,11 +30,12 @@ @pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"), ("gs.vtk")]) @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') -def test_load_vox(path): +def test_seq_ops(path): gs_dir = os.path.join(get_home(), 'gold_standard') path = os.path.join(gs_dir, path) - obj = load(path, os.path.join(gs_dir, 'gs.nii')) + obj = load(os.path.join(gs_dir, 'gs.trx'), + os.path.join(gs_dir, 'gs.nii')) sft = obj.to_sft() save_tractogram(sft, path) obj.close() @@ -182,3 +183,42 @@ def test_change_tmp_dir(tmp_path): trx.close() assert not os.path.isdir(tmp_gs_dir) + + +@pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")]) +def test_complete_dir_from_trx(path): + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, path) + + trx = tmm.load(path) + if trx._uncompressed_folder_handle is None: + dir_to_check = path + else: + dir_to_check = trx._uncompressed_folder_handle.name + + file_paths = [] + for dirpath, _, filenames in os.walk(dir_to_check): + for filename in filenames: + full_path = os.path.join(dirpath, filename) + cut_path = full_path.split(dir_to_check)[1][1:] + file_paths.append(cut_path) + + expected_content = ['offsets.uint32', 'positions.3.float32', + 'header.json', 'dps/random_coord.3.float32', + 'dpv/color_y.float32', 'dpv/color_x.float32', + 'dpv/color_z.float32'] + assert set(file_paths) == set(expected_content) + + +def test_complete_zip_from_trx(): + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(gs_dir, 'gs.trx') + + with zipfile.ZipFile(path, mode="r") as zf: + zip_file_list = zf.namelist() + + expected_content = ['offsets.uint32', 'positions.3.float32', + 'header.json', 'dps/random_coord.3.float32', + 'dpv/color_y.float32', 'dpv/color_x.float32', + 'dpv/color_z.float32'] + assert set(zip_file_list) == set(expected_content) diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py index a238ca8..7515dd6 100644 --- a/trx/trx_file_memmap.py +++ b/trx/trx_file_memmap.py @@ -27,7 +27,7 @@ try: import dipy dipy_available = True -except: +except ImportError: dipy_available = False @@ -1752,7 +1752,7 @@ def close(self) -> None: try: self._uncompressed_folder_handle.cleanup() except PermissionError: - logging.error("Windows PermissionError, temporary directory {}" + + logging.error("Windows PermissionError, temporary directory {}" "was not deleted!".format(self._uncompressed_folder_handle.name)) self.__init__() logging.debug("Deleted memmaps and intialized empty TrxFile.") diff --git a/trx/utils.py b/trx/utils.py index e07db4b..8c1051b 100644 --- a/trx/utils.py +++ b/trx/utils.py @@ -417,7 +417,7 @@ def verify_trx_dtype(trx, dict_dtype): trx : Tractogram Tractogram to verify. dict_dtype : dict - Dictionary containing the dtype to verify. + Dictionary containing all elements dtype to verify. Returns ------- output : bool From 71a1c59da1eb9ffce0a9a5499b25b14801446f68 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 11 Oct 2023 12:49:35 -0400 Subject: [PATCH 5/5] Safer save/reload --- trx/tests/test_io.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py index 18c4473..d7fdbb3 100644 --- a/trx/tests/test_io.py +++ b/trx/tests/test_io.py @@ -4,6 +4,7 @@ from copy import deepcopy import os import psutil +from tempfile import TemporaryDirectory import zipfile import pytest @@ -11,7 +12,7 @@ from numpy.testing import assert_allclose try: - from dipy.io.streamline import save_tractogram + from dipy.io.streamline import save_tractogram, load_tractogram dipy_available = True except ImportError: dipy_available = False @@ -31,15 +32,18 @@ ("gs.vtk")]) @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') def test_seq_ops(path): - gs_dir = os.path.join(get_home(), 'gold_standard') - path = os.path.join(gs_dir, path) + with TemporaryDirectory() as tmp_dir: + gs_dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(tmp_dir, path) + + obj = load(os.path.join(gs_dir, 'gs.trx'), + os.path.join(gs_dir, 'gs.nii')) + sft_1 = obj.to_sft() + save_tractogram(sft_1, path) + obj.close() + save_tractogram(sft_1, 'tmp.trx') - obj = load(os.path.join(gs_dir, 'gs.trx'), - os.path.join(gs_dir, 'gs.nii')) - sft = obj.to_sft() - save_tractogram(sft, path) - obj.close() - save_tractogram(sft, path) + sft_2 = load_tractogram('tmp.trx', 'same') @pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"),