diff --git a/scripts/tests/test_workflows.py b/scripts/tests/test_workflows.py
index 740fdd2..7640886 100644
--- a/scripts/tests/test_workflows.py
+++ b/scripts/tests/test_workflows.py
@@ -15,7 +15,7 @@

 from trx.fetcher import (get_testing_files_dict, fetch_data,
                          get_home)
-from trx.io import get_trx_tmpdir
+from trx.io import get_trx_tmp_dir
 import trx.trx_file_memmap as tmm
 from trx.workflows import (convert_dsi_studio,
                            convert_tractogram,
@@ -26,7 +26,7 @@

 # If they already exist, this only takes 5 seconds (check md5sum)
 fetch_data(get_testing_files_dict(), keys=['DSI.zip', 'trx_from_scratch.zip'])
-tmp_dir = get_trx_tmpdir()
+tmp_dir = get_trx_tmp_dir()


 def test_help_option_convert_dsi(script_runner):
diff --git a/trx/io.py b/trx/io.py
index da2e8eb..86ecc27 100644
--- a/trx/io.py
+++ b/trx/io.py
@@ -15,7 +15,7 @@
 from trx.utils import split_name_with_gz


-def get_trx_tmpdir():
+def get_trx_tmp_dir():
     if os.getenv('TRX_TMPDIR') is not None:
         if os.getenv('TRX_TMPDIR') == 'use_working_dir':
             trx_tmp_dir = os.getcwd()
diff --git a/trx/streamlines_ops.py b/trx/streamlines_ops.py
index bb291e1..6057d2d 100644
--- a/trx/streamlines_ops.py
+++ b/trx/streamlines_ops.py
@@ -29,7 +29,7 @@ def union(left, right):

 def get_streamline_key(streamline, precision=None):
     """Produces a key using a hash from a streamline using a few points only and
-    the desired precision
+    the desired precision

     Parameters
     ----------
diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py
index f9628cf..d7fdbb3 100644
--- a/trx/tests/test_io.py
+++ b/trx/tests/test_io.py
@@ -4,37 +4,57 @@
 from copy import deepcopy
 import os

 import psutil
+from tempfile import TemporaryDirectory
+import zipfile
 import pytest
 import numpy as np
 from numpy.testing import assert_allclose

 try:
-    import dipy
+    from dipy.io.streamline import save_tractogram, load_tractogram
     dipy_available = True
 except ImportError:
     dipy_available = False

 import trx.trx_file_memmap as tmm
 from trx.trx_file_memmap import TrxFile
-from trx.io import load, save, get_trx_tmpdir
+from trx.io import load, save, get_trx_tmp_dir
 from trx.fetcher import (get_testing_files_dict, fetch_data,
                          get_home)

 fetch_data(get_testing_files_dict(), keys=['gold_standard.zip'])
-tmp_dir = get_trx_tmpdir()
+tmp_gs_dir = get_trx_tmp_dir()
+
+
+@pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"),
+                                  ("gs.vtk")])
+@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
+def test_seq_ops(path):
+    with TemporaryDirectory() as tmp_dir:
+        gs_dir = os.path.join(get_home(), 'gold_standard')
+        path = os.path.join(tmp_dir, path)
+
+        obj = load(os.path.join(gs_dir, 'gs.trx'),
+                   os.path.join(gs_dir, 'gs.nii'))
+        sft_1 = obj.to_sft()
+        save_tractogram(sft_1, path)
+        obj.close()
+        save_tractogram(sft_1, 'tmp.trx')
+
+        sft_2 = load_tractogram('tmp.trx', 'same')


 @pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"),
                                   ("gs.vtk")])
 @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
 def test_load_vox(path):
-    dir = os.path.join(get_home(), 'gold_standard')
-    path = os.path.join(dir, path)
+    gs_dir = os.path.join(get_home(), 'gold_standard')
+    path = os.path.join(gs_dir, path)
     coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
                                     'gs_vox_space.txt'))
-    obj = load(path, os.path.join(dir, 'gs.nii'))
+    obj = load(path, os.path.join(gs_dir, 'gs.nii'))

     sft = obj.to_sft() if isinstance(obj, TrxFile) else obj
     sft.to_vox()
@@ -48,11 +68,11 @@ def test_load_vox(path):
                                   ("gs.vtk")])
 @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
 def test_load_voxmm(path):
-    dir = os.path.join(get_home(), 'gold_standard')
-    path = os.path.join(dir, path)
+    gs_dir = os.path.join(get_home(), 'gold_standard')
+    path = os.path.join(gs_dir, path)
     coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
                                     'gs_voxmm_space.txt'))
-    obj = load(path, os.path.join(dir, 'gs.nii'))
+    obj = load(path, os.path.join(gs_dir, 'gs.nii'))

     sft = obj.to_sft() if isinstance(obj, TrxFile) else obj
     sft.to_voxmm()
@@ -65,33 +85,33 @@ def test_load_voxmm(path):
 @pytest.mark.parametrize("path", [("gs.trk"), ("gs.trx"), ("gs_fldr.trx")])
 @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
 def test_multi_load_save_rasmm(path):
-    dir = os.path.join(get_home(), 'gold_standard')
+    gs_dir = os.path.join(get_home(), 'gold_standard')
     basename, ext = os.path.splitext(path)
-    out_path = os.path.join(tmp_dir.name, '{}_tmp{}'.format(basename, ext))
-    path = os.path.join(dir, path)
+    out_path = os.path.join(tmp_gs_dir.name, '{}_tmp{}'.format(basename, ext))
+    path = os.path.join(gs_dir, path)
     coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
                                     'gs_rasmm_space.txt'))
-    obj = load(path, os.path.join(dir, 'gs.nii'))
+    obj = load(path, os.path.join(gs_dir, 'gs.nii'))

     for _ in range(100):
         save(obj, out_path)
         if isinstance(obj, TrxFile):
             obj.close()
-        obj = load(out_path, os.path.join(dir, 'gs.nii'))
+        obj = load(out_path, os.path.join(gs_dir, 'gs.nii'))

     assert_allclose(obj.streamlines._data, coord, rtol=1e-04, atol=1e-06)


 @pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")])
 @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
-def test_delete_tmp_dir(path):
-    dir = os.path.join(get_home(), 'gold_standard')
-    path = os.path.join(dir, path)
+def test_delete_tmp_gs_dir(path):
+    gs_dir = os.path.join(get_home(), 'gold_standard')
+    path = os.path.join(gs_dir, path)

     trx1 = tmm.load(path)
     if os.path.isfile(path):
-        tmp_dir = deepcopy(trx1._uncompressed_folder_handle.name)
-        assert os.path.isdir(tmp_dir)
+        tmp_gs_dir = deepcopy(trx1._uncompressed_folder_handle.name)
+        assert os.path.isdir(tmp_gs_dir)

     sft = trx1.to_sft()
     trx1.close()
@@ -102,7 +122,7 @@ def test_delete_tmp_dir(path):

     # The folder trx representation does not need tmp files
     if os.path.isfile(path):
-        assert not os.path.isdir(tmp_dir)
+        assert not os.path.isdir(tmp_gs_dir)

     assert_allclose(sft.streamlines._data, coord_rasmm,
                     rtol=1e-04, atol=1e-06)
@@ -124,8 +144,8 @@ def test_delete_tmp_dir(path):
 @pytest.mark.parametrize("path", [("gs.trx")])
 @pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
 def test_close_tmp_files(path):
-    dir = os.path.join(get_home(), 'gold_standard')
-    path = os.path.join(dir, path)
+    gs_dir = os.path.join(get_home(), 'gold_standard')
+    path = os.path.join(gs_dir, path)

     trx = tmm.load(path)
     process = psutil.Process(os.getpid())
@@ -149,8 +169,8 @@ def test_close_tmp_files(path):

 @pytest.mark.parametrize("tmp_path", [("~"), ("use_working_dir")])
 def test_change_tmp_dir(tmp_path):
-    dir = os.path.join(get_home(), 'gold_standard')
-    path = os.path.join(dir, 'gs.trx')
+    gs_dir = os.path.join(get_home(), 'gold_standard')
+    path = os.path.join(gs_dir, 'gs.trx')

     if tmp_path == 'use_working_dir':
         os.environ['TRX_TMPDIR'] = 'use_working_dir'
@@ -158,12 +178,51 @@
         os.environ['TRX_TMPDIR'] = os.path.expanduser(tmp_path)

     trx = tmm.load(path)
-    tmp_dir = deepcopy(trx._uncompressed_folder_handle.name)
+    tmp_gs_dir = deepcopy(trx._uncompressed_folder_handle.name)

     if tmp_path == 'use_working_dir':
-        assert os.path.dirname(tmp_dir) == os.getcwd()
+        assert os.path.dirname(tmp_gs_dir) == os.getcwd()
     else:
-        assert os.path.dirname(tmp_dir) == os.path.expanduser(tmp_path)
+        assert os.path.dirname(tmp_gs_dir) == os.path.expanduser(tmp_path)

     trx.close()
-    assert not os.path.isdir(tmp_dir)
+    assert not os.path.isdir(tmp_gs_dir)
+
+
+@pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")])
+def test_complete_dir_from_trx(path):
+    gs_dir = os.path.join(get_home(), 'gold_standard')
+    path = os.path.join(gs_dir, path)
+
+    trx = tmm.load(path)
+    if trx._uncompressed_folder_handle is None:
+        dir_to_check = path
+    else:
+        dir_to_check = trx._uncompressed_folder_handle.name
+
+    file_paths = []
+    for dirpath, _, filenames in os.walk(dir_to_check):
+        for filename in filenames:
+            full_path = os.path.join(dirpath, filename)
+            cut_path = full_path.split(dir_to_check)[1][1:]
+            file_paths.append(cut_path)
+
+    expected_content = ['offsets.uint32', 'positions.3.float32',
+                        'header.json', 'dps/random_coord.3.float32',
+                        'dpv/color_y.float32', 'dpv/color_x.float32',
+                        'dpv/color_z.float32']
+    assert set(file_paths) == set(expected_content)
+
+
+def test_complete_zip_from_trx():
+    gs_dir = os.path.join(get_home(), 'gold_standard')
+    path = os.path.join(gs_dir, 'gs.trx')
+
+    with zipfile.ZipFile(path, mode="r") as zf:
+        zip_file_list = zf.namelist()
+
+    expected_content = ['offsets.uint32', 'positions.3.float32',
+                        'header.json', 'dps/random_coord.3.float32',
+                        'dpv/color_y.float32', 'dpv/color_x.float32',
+                        'dpv/color_z.float32']
+    assert set(zip_file_list) == set(expected_content)
diff --git a/trx/tests/test_memmap.py b/trx/tests/test_memmap.py
index a0f2703..7ae189d 100644
--- a/trx/tests/test_memmap.py
+++ b/trx/tests/test_memmap.py
@@ -14,14 +14,14 @@
 except ImportError:
     dipy_available = False

-from trx.io import get_trx_tmpdir
+from trx.io import get_trx_tmp_dir
 import trx.trx_file_memmap as tmm
 from trx.fetcher import (get_testing_files_dict, fetch_data,
                          get_home)

 fetch_data(get_testing_files_dict(), keys=['memmap_test_data.zip'])
-tmp_dir = get_trx_tmpdir()
+tmp_dir = get_trx_tmp_dir()


 @pytest.mark.parametrize(
@@ -129,7 +129,7 @@ def test__dichotomic_search(arr, l_bound, r_bound, expected):
 def test__create_memmap(basename, create, expected):
     if create:
         # Need to create array before evaluating
-        with get_trx_tmpdir() as dirname:
+        with get_trx_tmp_dir() as dirname:
             filename = os.path.join(dirname, basename)
             fp = np.memmap(filename, dtype=np.int16, mode="w+", shape=(3, 4))
             fp[:] = expected[:]
@@ -138,7 +138,7 @@ def test__create_memmap(basename, create, expected):
         assert np.array_equal(mmarr, expected)

     else:
-        with get_trx_tmpdir() as dirname:
+        with get_trx_tmp_dir() as dirname:
             filename = os.path.join(dirname, basename)
             mmarr = tmm._create_memmap(filename=filename, shape=(0,),
                                        dtype=np.int16)
diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py
index b0775aa..7515dd6 100644
--- a/trx/trx_file_memmap.py
+++ b/trx/trx_file_memmap.py
@@ -18,7 +18,7 @@
 from nibabel.streamlines.tractogram import Tractogram, LazyTractogram
 import numpy as np

-from trx.io import get_trx_tmpdir
+from trx.io import get_trx_tmp_dir
 from trx.utils import (append_generator_to_dict,
                        close_or_delete_mmap,
                        convert_data_dict_to_tractogram,
@@ -27,7 +27,7 @@
 try:
     import dipy
     dipy_available = True
-except:
+except ImportError:
     dipy_available = False

@@ -229,10 +229,10 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]:
                 break
     if was_compressed:
         with zipfile.ZipFile(input_obj, "r") as zf:
-            tmpdir = get_trx_tmpdir()
-            zf.extractall(tmpdir.name)
-            trx = load_from_directory(tmpdir.name)
-            trx._uncompressed_folder_handle = tmpdir
+            tmp_dir = get_trx_tmp_dir()
+            zf.extractall(tmp_dir.name)
+            trx = load_from_directory(tmp_dir.name)
+            trx._uncompressed_folder_handle = tmp_dir
         logging.info(
             "File was compressed, call the close() function before"
             "exiting."
         )
@@ -542,20 +542,15 @@
         compression_standard -- The compression standard to use, as defined by
             the ZipFile library
     """
-    if os.path.splitext(filename)[1] and not os.path.splitext(filename)[1] in [
-        ".zip",
-        ".trx",
-    ]:
+    _, ext = os.path.splitext(filename)
+    if ext not in [".zip", ".trx", ""]:
         raise ValueError("Unsupported extension.")

     copy_trx = trx.deepcopy()
     copy_trx.resize()

     tmp_dir_name = copy_trx._uncompressed_folder_handle.name
-    if os.path.splitext(filename)[1] and os.path.splitext(filename)[1] in [
-        ".zip",
-        ".trx",
-    ]:
+    if ext in [".zip", ".trx"]:
         zip_from_folder(tmp_dir_name, filename, compression_standard)
     else:
         if os.path.isdir(filename):
@@ -740,7 +735,7 @@ def deepcopy(self) -> Type["TrxFile"]:
         Returns
             A deepcopied TrxFile of the current TrxFile
         """
-        tmp_dir = get_trx_tmpdir()
+        tmp_dir = get_trx_tmp_dir()
         out_json = open(os.path.join(tmp_dir.name, "header.json"), "w")

         tmp_header = deepcopy(self.header)
@@ -917,7 +912,7 @@ def _initialize_empty_trx(
             An empty TrxFile preallocated with a certain size
         """
         trx = TrxFile()
-        tmp_dir = get_trx_tmpdir()
+        tmp_dir = get_trx_tmp_dir()
         logging.info("Temporary folder for memmaps: {}".format(tmp_dir.name))

         trx.header["NB_VERTICES"] = nb_vertices
@@ -1582,10 +1577,10 @@ def from_sft(sft, dtype_dict={}):
                     dtype_to_use)

         # For safety and for RAM, convert the whole object to memmaps
-        tmpdir = get_trx_tmpdir()
-        save(trx, tmpdir.name)
-        trx = load_from_directory(tmpdir.name)
-        trx._uncompressed_folder_handle = tmpdir
+        tmp_dir = get_trx_tmp_dir()
+        save(trx, tmp_dir.name)
+        trx.close()
+        trx = load_from_directory(tmp_dir.name)

         sft.to_space(old_space)
         sft.to_origin(old_origin)
@@ -1651,10 +1646,11 @@ def from_tractogram(tractogram, reference,
                 tractogram.data_per_streamline[key].astype(dtype_to_use)

         # For safety and for RAM, convert the whole object to memmaps
-        tmpdir = get_trx_tmpdir()
-        save(trx, tmpdir.name)
-        trx = load_from_directory(tmpdir.name)
-        trx._uncompressed_folder_handle = tmpdir
+        tmp_dir = get_trx_tmp_dir()
+        save(trx, tmp_dir.name)
+        trx.close()
+
+        trx = load_from_directory(tmp_dir.name)

         del tmp_streamlines
         return trx
@@ -1756,7 +1752,7 @@ def close(self) -> None:
             try:
                 self._uncompressed_folder_handle.cleanup()
             except PermissionError:
-                logging.error("Windows PermissionError, temporary directory {}" +
+                logging.error("Windows PermissionError, temporary directory {}"
                               "was not deleted!".format(self._uncompressed_folder_handle.name))
         self.__init__()
         logging.debug("Deleted memmaps and intialized empty TrxFile.")
diff --git a/trx/utils.py b/trx/utils.py
index 9d1e378..8c1051b 100644
--- a/trx/utils.py
+++ b/trx/utils.py
@@ -36,7 +36,7 @@ def close_or_delete_mmap(obj):
     elif isinstance(obj, np.memmap):
         del obj
     else:
-        logging.warning('Object to be close or deleted must be np.memmap')
+        logging.debug('Object to be close or deleted must be np.memmap')


 def split_name_with_gz(filename):
@@ -417,7 +417,7 @@
     trx : Tractogram
         Tractogram to verify.
     dict_dtype : dict
-        Dictionary containing the dtype to verify.
+        Dictionary containing all elements dtype to verify.

     Returns
     -------
     output : bool
diff --git a/trx/workflows.py b/trx/workflows.py
index 60f06ad..11cfcd6 100644
--- a/trx/workflows.py
+++ b/trx/workflows.py
@@ -18,7 +18,7 @@
 except ImportError:
     dipy_available = False

-from trx.io import get_trx_tmpdir, load, load_sft_with_reference, save
+from trx.io import get_trx_tmp_dir, load, load_sft_with_reference, save
 from trx.streamlines_ops import perform_streamlines_operation, intersection
 import trx.trx_file_memmap as tmm
 from trx.viz import display
@@ -311,7 +311,7 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False,
                               space_str='rasmm', origin_str='nifti',
                               verify_invalid=True, dpv=[], dps=[],
                               groups=[], dpg=[]):
-    with get_trx_tmpdir() as tmpdirname:
+    with get_trx_tmp_dir() as tmp_dir_name:
         if positions_csv:
             with open(positions_csv, newline='') as f:
                 reader = csv.reader(f)
@@ -360,20 +360,20 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False,
             raise IOError('To use this script, you need at least 2'
                           'streamlines.')

-        with open(os.path.join(tmpdirname, "header.json"), "w") as out_json:
+        with open(os.path.join(tmp_dir_name, "header.json"), "w") as out_json:
             json.dump(header, out_json)

-        curr_filename = os.path.join(tmpdirname, 'positions.3.{}'.format(
-            positions_dtype))
+        curr_filename = os.path.join(tmp_dir_name, 'positions.3.{}'.format(
+            positions_dtype))
         streamlines._data.astype(positions_dtype).tofile(
             curr_filename)
-        curr_filename = os.path.join(tmpdirname, 'offsets.{}'.format(
-            offsets_dtype))
+        curr_filename = os.path.join(tmp_dir_name, 'offsets.{}'.format(
+            offsets_dtype))
         streamlines._offsets.astype(offsets_dtype).tofile(
             curr_filename)

         if dpv:
-            os.mkdir(os.path.join(tmpdirname, 'dpv'))
+            os.mkdir(os.path.join(tmp_dir_name, 'dpv'))
             for arg in dpv:
                 curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype(
                     arg[1]))
@@ -383,12 +383,12 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False,
                     raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.')
                 dim = '' if curr_arr.ndim == 1 else '{}.'.format(
                     curr_arr.shape[-1])
-                curr_filename = os.path.join(tmpdirname, 'dpv', '{}.{}{}'.format(
-                    os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
+                curr_filename = os.path.join(tmp_dir_name, 'dpv', '{}.{}{}'.format(
+                    os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
                 curr_arr.tofile(curr_filename)

         if dps:
-            os.mkdir(os.path.join(tmpdirname, 'dps'))
+            os.mkdir(os.path.join(tmp_dir_name, 'dps'))
             for arg in dps:
                 curr_arr = np.squeeze(load_matrix_in_any_format(arg[0]).astype(
                     arg[1]))
@@ -398,12 +398,12 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False,
                     raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.')
                 dim = '' if curr_arr.ndim == 1 else '{}.'.format(
                     curr_arr.shape[-1])
-                curr_filename = os.path.join(tmpdirname, 'dps', '{}.{}{}'.format(
-                    os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
+                curr_filename = os.path.join(tmp_dir_name, 'dps', '{}.{}{}'.format(
+                    os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
                 curr_arr.tofile(curr_filename)

         if groups:
-            os.mkdir(os.path.join(tmpdirname, 'groups'))
+            os.mkdir(os.path.join(tmp_dir_name, 'groups'))
             for arg in groups:
                 curr_arr = load_matrix_in_any_format(arg[0]).astype(arg[1])
                 if arg[1] == 'bool':
@@ -412,15 +412,15 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False,
                     raise IOError('Maximum of 2 dimensions for dpv/dps/dpg.')
                 dim = '' if curr_arr.ndim == 1 else '{}.'.format(
                     curr_arr.shape[-1])
-                curr_filename = os.path.join(tmpdirname, 'groups', '{}.{}{}'.format(
-                    os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
+                curr_filename = os.path.join(tmp_dir_name, 'groups', '{}.{}{}'.format(
+                    os.path.basename(os.path.splitext(arg[0])[0]), dim, arg[1]))
                 curr_arr.tofile(curr_filename)

         if dpg:
-            os.mkdir(os.path.join(tmpdirname, 'dpg'))
+            os.mkdir(os.path.join(tmp_dir_name, 'dpg'))
             for arg in dpg:
-                if not os.path.isdir(os.path.join(tmpdirname, 'dpg', arg[0])):
-                    os.mkdir(os.path.join(tmpdirname, 'dpg', arg[0]))
+                if not os.path.isdir(os.path.join(tmp_dir_name, 'dpg', arg[0])):
+                    os.mkdir(os.path.join(tmp_dir_name, 'dpg', arg[0]))
                 curr_arr = load_matrix_in_any_format(arg[1]).astype(arg[2])
                 if arg[1] == 'bool':
                     arg[1] = 'bit'
@@ -430,11 +430,11 @@ def generate_trx_from_scratch(reference, out_tractogram, positions_csv=False,
                     curr_arr = curr_arr.reshape((1,))
                 dim = '' if curr_arr.ndim == 1 else '{}.'.format(
                     curr_arr.shape[-1])
-                curr_filename = os.path.join(tmpdirname, 'dpg', arg[0], '{}.{}{}'.format(
-                    os.path.basename(os.path.splitext(arg[1])[0]), dim, arg[2]))
+                curr_filename = os.path.join(tmp_dir_name, 'dpg', arg[0], '{}.{}{}'.format(
+                    os.path.basename(os.path.splitext(arg[1])[0]), dim, arg[2]))
                 curr_arr.tofile(curr_filename)

-        trx = tmm.load(tmpdirname)
+        trx = tmm.load(tmp_dir_name)
         tmm.save(trx, out_tractogram)
         trx.close()