Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New test to match Dipy potential issue with sft copy #70

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from trx.fetcher import (get_testing_files_dict,
fetch_data, get_home)
from trx.io import get_trx_tmpdir
from trx.io import get_trx_tmp_dir
import trx.trx_file_memmap as tmm
from trx.workflows import (convert_dsi_studio,
convert_tractogram,
Expand All @@ -26,7 +26,7 @@

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['DSI.zip', 'trx_from_scratch.zip'])
tmp_dir = get_trx_tmpdir()
tmp_dir = get_trx_tmp_dir()


def test_help_option_convert_dsi(script_runner):
Expand Down
2 changes: 1 addition & 1 deletion trx/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from trx.utils import split_name_with_gz


def get_trx_tmpdir():
def get_trx_tmp_dir():
if os.getenv('TRX_TMPDIR') is not None:
if os.getenv('TRX_TMPDIR') == 'use_working_dir':
trx_tmp_dir = os.getcwd()
Expand Down
2 changes: 1 addition & 1 deletion trx/streamlines_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def union(left, right):

def get_streamline_key(streamline, precision=None):
"""Produces a key using a hash from a streamline using a few points only and
the desired precision
the desired precision

Parameters
----------
Expand Down
115 changes: 87 additions & 28 deletions trx/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,57 @@
from copy import deepcopy
import os
import psutil
from tempfile import TemporaryDirectory
import zipfile

import pytest
import numpy as np
from numpy.testing import assert_allclose

try:
import dipy
from dipy.io.streamline import save_tractogram, load_tractogram
dipy_available = True
except ImportError:
dipy_available = False

import trx.trx_file_memmap as tmm
from trx.trx_file_memmap import TrxFile
from trx.io import load, save, get_trx_tmpdir
from trx.io import load, save, get_trx_tmp_dir
from trx.fetcher import (get_testing_files_dict,
fetch_data, get_home)


fetch_data(get_testing_files_dict(), keys=['gold_standard.zip'])
tmp_dir = get_trx_tmpdir()
tmp_gs_dir = get_trx_tmp_dir()


@pytest.mark.parametrize("path", [("gs.trk"), ("gs.tck"),
("gs.vtk")])
@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
def test_seq_ops(path):
with TemporaryDirectory() as tmp_dir:
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(tmp_dir, path)

obj = load(os.path.join(gs_dir, 'gs.trx'),
os.path.join(gs_dir, 'gs.nii'))
sft_1 = obj.to_sft()
save_tractogram(sft_1, path)
obj.close()
save_tractogram(sft_1, 'tmp.trx')

sft_2 = load_tractogram('tmp.trx', 'same')


@pytest.mark.parametrize("path", [("gs.trx"), ("gs.trk"), ("gs.tck"),
("gs.vtk")])
@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
def test_load_vox(path):
dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(dir, path)
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(gs_dir, path)
coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
'gs_vox_space.txt'))
obj = load(path, os.path.join(dir, 'gs.nii'))
obj = load(path, os.path.join(gs_dir, 'gs.nii'))

sft = obj.to_sft() if isinstance(obj, TrxFile) else obj
sft.to_vox()
Expand All @@ -48,11 +68,11 @@ def test_load_vox(path):
("gs.vtk")])
@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
def test_load_voxmm(path):
dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(dir, path)
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(gs_dir, path)
coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
'gs_voxmm_space.txt'))
obj = load(path, os.path.join(dir, 'gs.nii'))
obj = load(path, os.path.join(gs_dir, 'gs.nii'))

sft = obj.to_sft() if isinstance(obj, TrxFile) else obj
sft.to_voxmm()
Expand All @@ -65,33 +85,33 @@ def test_load_voxmm(path):
@pytest.mark.parametrize("path", [("gs.trk"), ("gs.trx"), ("gs_fldr.trx")])
@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
def test_multi_load_save_rasmm(path):
dir = os.path.join(get_home(), 'gold_standard')
gs_dir = os.path.join(get_home(), 'gold_standard')
basename, ext = os.path.splitext(path)
out_path = os.path.join(tmp_dir.name, '{}_tmp{}'.format(basename, ext))
path = os.path.join(dir, path)
out_path = os.path.join(tmp_gs_dir.name, '{}_tmp{}'.format(basename, ext))
path = os.path.join(gs_dir, path)
coord = np.loadtxt(os.path.join(get_home(), 'gold_standard',
'gs_rasmm_space.txt'))

obj = load(path, os.path.join(dir, 'gs.nii'))
obj = load(path, os.path.join(gs_dir, 'gs.nii'))
for _ in range(100):
save(obj, out_path)
if isinstance(obj, TrxFile):
obj.close()
obj = load(out_path, os.path.join(dir, 'gs.nii'))
obj = load(out_path, os.path.join(gs_dir, 'gs.nii'))

assert_allclose(obj.streamlines._data, coord, rtol=1e-04, atol=1e-06)


@pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")])
@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
def test_delete_tmp_dir(path):
dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(dir, path)
def test_delete_tmp_gs_dir(path):
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(gs_dir, path)

trx1 = tmm.load(path)
if os.path.isfile(path):
tmp_dir = deepcopy(trx1._uncompressed_folder_handle.name)
assert os.path.isdir(tmp_dir)
tmp_gs_dir = deepcopy(trx1._uncompressed_folder_handle.name)
assert os.path.isdir(tmp_gs_dir)
sft = trx1.to_sft()
trx1.close()

Expand All @@ -102,7 +122,7 @@ def test_delete_tmp_dir(path):

# The folder trx representation does not need tmp files
if os.path.isfile(path):
assert not os.path.isdir(tmp_dir)
assert not os.path.isdir(tmp_gs_dir)

assert_allclose(sft.streamlines._data, coord_rasmm, rtol=1e-04, atol=1e-06)

Expand All @@ -124,8 +144,8 @@ def test_delete_tmp_dir(path):
@pytest.mark.parametrize("path", [("gs.trx")])
@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.')
def test_close_tmp_files(path):
dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(dir, path)
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(gs_dir, path)

trx = tmm.load(path)
process = psutil.Process(os.getpid())
Expand All @@ -149,21 +169,60 @@ def test_close_tmp_files(path):

@pytest.mark.parametrize("tmp_path", [("~"), ("use_working_dir")])
def test_change_tmp_dir(tmp_path):
dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(dir, 'gs.trx')
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(gs_dir, 'gs.trx')

if tmp_path == 'use_working_dir':
os.environ['TRX_TMPDIR'] = 'use_working_dir'
else:
os.environ['TRX_TMPDIR'] = os.path.expanduser(tmp_path)

trx = tmm.load(path)
tmp_dir = deepcopy(trx._uncompressed_folder_handle.name)
tmp_gs_dir = deepcopy(trx._uncompressed_folder_handle.name)

if tmp_path == 'use_working_dir':
assert os.path.dirname(tmp_dir) == os.getcwd()
assert os.path.dirname(tmp_gs_dir) == os.getcwd()
else:
assert os.path.dirname(tmp_dir) == os.path.expanduser(tmp_path)
assert os.path.dirname(tmp_gs_dir) == os.path.expanduser(tmp_path)

trx.close()
assert not os.path.isdir(tmp_dir)
assert not os.path.isdir(tmp_gs_dir)


@pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")])
def test_complete_dir_from_trx(path):
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(gs_dir, path)

trx = tmm.load(path)
if trx._uncompressed_folder_handle is None:
dir_to_check = path
else:
dir_to_check = trx._uncompressed_folder_handle.name

file_paths = []
for dirpath, _, filenames in os.walk(dir_to_check):
for filename in filenames:
full_path = os.path.join(dirpath, filename)
cut_path = full_path.split(dir_to_check)[1][1:]
file_paths.append(cut_path)

expected_content = ['offsets.uint32', 'positions.3.float32',
'header.json', 'dps/random_coord.3.float32',
'dpv/color_y.float32', 'dpv/color_x.float32',
'dpv/color_z.float32']
assert set(file_paths) == set(expected_content)


def test_complete_zip_from_trx():
gs_dir = os.path.join(get_home(), 'gold_standard')
path = os.path.join(gs_dir, 'gs.trx')

with zipfile.ZipFile(path, mode="r") as zf:
zip_file_list = zf.namelist()

expected_content = ['offsets.uint32', 'positions.3.float32',
'header.json', 'dps/random_coord.3.float32',
'dpv/color_y.float32', 'dpv/color_x.float32',
'dpv/color_z.float32']
assert set(zip_file_list) == set(expected_content)
8 changes: 4 additions & 4 deletions trx/tests/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
except ImportError:
dipy_available = False

from trx.io import get_trx_tmpdir
from trx.io import get_trx_tmp_dir
import trx.trx_file_memmap as tmm
from trx.fetcher import (get_testing_files_dict,
fetch_data, get_home)


fetch_data(get_testing_files_dict(), keys=['memmap_test_data.zip'])
tmp_dir = get_trx_tmpdir()
tmp_dir = get_trx_tmp_dir()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -129,7 +129,7 @@ def test__dichotomic_search(arr, l_bound, r_bound, expected):
def test__create_memmap(basename, create, expected):
if create:
# Need to create array before evaluating
with get_trx_tmpdir() as dirname:
with get_trx_tmp_dir() as dirname:
filename = os.path.join(dirname, basename)
fp = np.memmap(filename, dtype=np.int16, mode="w+", shape=(3, 4))
fp[:] = expected[:]
Expand All @@ -138,7 +138,7 @@ def test__create_memmap(basename, create, expected):
assert np.array_equal(mmarr, expected)

else:
with get_trx_tmpdir() as dirname:
with get_trx_tmp_dir() as dirname:
filename = os.path.join(dirname, basename)
mmarr = tmm._create_memmap(filename=filename, shape=(0,),
dtype=np.int16)
Expand Down
46 changes: 21 additions & 25 deletions trx/trx_file_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from nibabel.streamlines.tractogram import Tractogram, LazyTractogram
import numpy as np

from trx.io import get_trx_tmpdir
from trx.io import get_trx_tmp_dir
from trx.utils import (append_generator_to_dict,
close_or_delete_mmap,
convert_data_dict_to_tractogram,
Expand All @@ -27,7 +27,7 @@
try:
import dipy
dipy_available = True
except:
except ImportError:
dipy_available = False


Expand Down Expand Up @@ -229,10 +229,10 @@ def load(input_obj: str, check_dpg: bool = True) -> Type["TrxFile"]:
break
if was_compressed:
with zipfile.ZipFile(input_obj, "r") as zf:
tmpdir = get_trx_tmpdir()
zf.extractall(tmpdir.name)
trx = load_from_directory(tmpdir.name)
trx._uncompressed_folder_handle = tmpdir
tmp_dir = get_trx_tmp_dir()
zf.extractall(tmp_dir.name)
trx = load_from_directory(tmp_dir.name)
trx._uncompressed_folder_handle = tmp_dir
logging.info(
"File was compressed, call the close() function before"
"exiting."
Expand Down Expand Up @@ -542,20 +542,15 @@ def save(
compression_standard -- The compression standard to use, as defined by
the ZipFile library
"""
if os.path.splitext(filename)[1] and not os.path.splitext(filename)[1] in [
".zip",
".trx",
]:
_, ext = os.path.splitext(filename)
if ext not in [".zip", ".trx", ""]:
raise ValueError("Unsupported extension.")

copy_trx = trx.deepcopy()
copy_trx.resize()

tmp_dir_name = copy_trx._uncompressed_folder_handle.name
if os.path.splitext(filename)[1] and os.path.splitext(filename)[1] in [
".zip",
".trx",
]:
if ext in [".zip", ".trx"]:
zip_from_folder(tmp_dir_name, filename, compression_standard)
else:
if os.path.isdir(filename):
Expand Down Expand Up @@ -740,7 +735,7 @@ def deepcopy(self) -> Type["TrxFile"]:
Returns
A deepcopied TrxFile of the current TrxFile
"""
tmp_dir = get_trx_tmpdir()
tmp_dir = get_trx_tmp_dir()
out_json = open(os.path.join(tmp_dir.name, "header.json"), "w")
tmp_header = deepcopy(self.header)

Expand Down Expand Up @@ -917,7 +912,7 @@ def _initialize_empty_trx(
An empty TrxFile preallocated with a certain size
"""
trx = TrxFile()
tmp_dir = get_trx_tmpdir()
tmp_dir = get_trx_tmp_dir()
logging.info("Temporary folder for memmaps: {}".format(tmp_dir.name))

trx.header["NB_VERTICES"] = nb_vertices
Expand Down Expand Up @@ -1582,10 +1577,10 @@ def from_sft(sft, dtype_dict={}):
dtype_to_use)

# For safety and for RAM, convert the whole object to memmaps
tmpdir = get_trx_tmpdir()
save(trx, tmpdir.name)
trx = load_from_directory(tmpdir.name)
trx._uncompressed_folder_handle = tmpdir
tmp_dir = get_trx_tmp_dir()
save(trx, tmp_dir.name)
trx.close()
trx = load_from_directory(tmp_dir.name)

sft.to_space(old_space)
sft.to_origin(old_origin)
Expand Down Expand Up @@ -1651,10 +1646,11 @@ def from_tractogram(tractogram, reference,
tractogram.data_per_streamline[key].astype(dtype_to_use)

# For safety and for RAM, convert the whole object to memmaps
tmpdir = get_trx_tmpdir()
save(trx, tmpdir.name)
trx = load_from_directory(tmpdir.name)
trx._uncompressed_folder_handle = tmpdir
tmp_dir = get_trx_tmp_dir()
save(trx, tmp_dir.name)
trx.close()

trx = load_from_directory(tmp_dir.name)
del tmp_streamlines

return trx
Expand Down Expand Up @@ -1756,7 +1752,7 @@ def close(self) -> None:
try:
self._uncompressed_folder_handle.cleanup()
except PermissionError:
logging.error("Windows PermissionError, temporary directory {}" +
logging.error("Windows PermissionError, temporary directory {}"
"was not deleted!".format(self._uncompressed_folder_handle.name))
self.__init__()
logging.debug("Deleted memmaps and intialized empty TrxFile.")
Loading
Loading