Skip to content

Commit

Permalink
Collect associated files at beginning of workflow (#246)
Browse files Browse the repository at this point in the history
* Use m0scan as available.

* Fix.

* Update bids.py

* Release strict.

* Remove "bold" from collected data.

* Fix collect_run_data.

* Fix.

* Whoops.

* Fix the BIDSDataGrabber.

* Don't find aslcontext in ExtractCBForDeltaM.

* Update base.py
  • Loading branch information
tsalo committed Mar 31, 2023
1 parent 075038d commit ea0d0e1
Show file tree
Hide file tree
Showing 10 changed files with 444 additions and 335 deletions.
68 changes: 66 additions & 2 deletions aslprep/interfaces/bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,19 @@
from pathlib import Path

from bids.layout import Config
from nipype import logging
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
OutputMultiObject,
SimpleInterface,
Str,
TraitedSpec,
traits,
)
from niworkflows.interfaces.bids import DerivativesDataSink as BaseDerivativesDataSink
from pkg_resources import resource_filename as pkgrf

from aslprep import config

# NOTE: Modified for aslprep's purposes
aslprep_spec = loads(Path(pkgrf("aslprep", "data/aslprep_bids_config.json")).read_text())
bids_config = Config.load("bids")
Expand All @@ -19,7 +28,62 @@
merged_entities = [{"name": k, "pattern": v} for k, v in merged_entities.items()]
config_entities = frozenset({e["name"] for e in merged_entities})

LOGGER = logging.getLogger("nipype.interface")

class _BIDSDataGrabberInputSpec(BaseInterfaceInputSpec):
subject_data = traits.Dict(Str, traits.Any)
subject_id = Str()


class _BIDSDataGrabberOutputSpec(TraitedSpec):
out_dict = traits.Dict(desc="output data structure")
fmap = OutputMultiObject(desc="output fieldmaps")
bold = OutputMultiObject(desc="output functional images")
sbref = OutputMultiObject(desc="output sbrefs")
t1w = OutputMultiObject(desc="output T1w images")
roi = OutputMultiObject(desc="output ROI images")
t2w = OutputMultiObject(desc="output T2w images")
flair = OutputMultiObject(desc="output FLAIR images")
asl = OutputMultiObject(desc="output ASL images")


class BIDSDataGrabber(SimpleInterface):
"""Collect files from a BIDS directory structure."""

input_spec = _BIDSDataGrabberInputSpec
output_spec = _BIDSDataGrabberOutputSpec
_require_funcs = True

def __init__(self, *args, **kwargs):
anat_only = kwargs.pop("anat_only")
super(BIDSDataGrabber, self).__init__(*args, **kwargs)
if anat_only is not None:
self._require_funcs = not anat_only

def _run_interface(self, runtime):
bids_dict = self.inputs.subject_data

self._results["out_dict"] = bids_dict
self._results.update(bids_dict)

if not bids_dict["t1w"]:
raise FileNotFoundError(
f"No T1w images found for subject sub-{self.inputs.subject_id}"
)

if self._require_funcs and not bids_dict["asl"]:
raise FileNotFoundError(
f"No ASL images found for subject sub-{self.inputs.subject_id}"
)

for imtype in ["t2w", "flair", "fmap", "sbref", "roi", "asl"]:
if not bids_dict[imtype]:
config.loggers.interface.info(
'No "%s" images found for sub-%s',
imtype,
self.inputs.subject_id,
)

return runtime


class DerivativesDataSink(BaseDerivativesDataSink):
Expand Down
59 changes: 29 additions & 30 deletions aslprep/interfaces/cbf_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@
from nipype.utils.filemanip import fname_presuffix

from aslprep import config
from aslprep.utils.misc import (
_getcbfscore,
_scrubcbf,
parcellate_cbf,
pcasl_or_pasl,
readjson,
)
from aslprep.utils.misc import _getcbfscore, _scrubcbf, parcellate_cbf, pcasl_or_pasl
from aslprep.utils.qc import (
cbf_qei,
coverage,
Expand Down Expand Up @@ -82,14 +76,28 @@ def _run_interface(self, runtime):
class _ExtractCBFInputSpec(BaseInterfaceInputSpec):
name_source = File(exists=True, mandatory=True, desc="raw asl file")
asl_file = File(exists=True, mandatory=True, desc="preprocessed asl file")
metadata = traits.Dict(mandatory=True, desc="metadata for ASL file")
aslcontext = File(exists=True, mandatory=True, desc="aslcontext TSV file for run.")
m0scan = traits.Either(
File(exists=True),
None,
mandatory=True,
desc="m0scan file associated with the ASL file. Only defined if M0Type is 'Separate'.",
)
m0scan_metadata = traits.Either(
traits.Dict,
None,
mandatory=True,
desc="metadata for M0 scan. Only defined if M0Type is 'Separate'.",
)
in_mask = File(exists=True, mandatory=True, desc="mask")
dummy_vols = traits.Int(
default_value=0, exit=False, mandatory=False, desc="remove first n volumes"
default_value=0,
use_default=True,
mandatory=False,
desc="remove first n volumes",
)
in_metadata = traits.Dict(exists=True, mandatory=True, desc="metadata for asl or deltam ")
bids_dir = traits.Str(exits=True, mandatory=True, desc=" bids directory")
fwhm = traits.Float(default_value=5, exists=True, mandatory=False, desc="fwhm")
fwhm = traits.Float(default_value=5, use_default=True, mandatory=False, desc="fwhm")


class _ExtractCBFOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -125,7 +133,7 @@ class ExtractCBF(SimpleInterface):

def _run_interface(self, runtime):
aslcontext = pd.read_table(self.inputs.aslcontext)
metadata = self.inputs.in_metadata.copy()
metadata = self.inputs.metadata.copy()

mask_data = nb.load(self.inputs.in_mask).get_fdata()

Expand All @@ -143,17 +151,12 @@ def _run_interface(self, runtime):

# extract m0 file and register it to ASL if separate
if metadata["M0Type"] == "Separate":
m0file = self.inputs.in_file.replace("asl.nii.gz", "m0scan.nii.gz")
m0file_metadata = readjson(m0file.replace("nii.gz", "json"))
aslfile_linkedM0 = os.path.abspath(
os.path.join(self.inputs.bids_dir, m0file_metadata["IntendedFor"])
)
if self.inputs.in_file not in aslfile_linkedM0:
raise RuntimeError("there is no separate m0scan for the asl data")
m0file = self.inputs.m0scan
m0file_metadata = self.inputs.m0scan_metadata

newm0 = fname_presuffix(self.inputs.asl_file, suffix="_m0file")
newm0 = regmotoasl(asl=self.inputs.asl_file, m0file=m0file, m02asl=newm0)
m0data_smooth = smooth_image(nb.load(newm0), fwhm=self.inputs.fwhm).get_fdata()
m0_in_asl = fname_presuffix(self.inputs.asl_file, suffix="_m0file")
m0_in_asl = regmotoasl(asl=self.inputs.asl_file, m0file=m0file, m02asl=m0_in_asl)
m0data_smooth = smooth_image(nb.load(m0_in_asl), fwhm=self.inputs.fwhm).get_fdata()
if len(m0data_smooth.shape) > 3:
m0data = mask_data * np.mean(m0data_smooth, axis=3)
else:
Expand Down Expand Up @@ -205,10 +208,6 @@ def _run_interface(self, runtime):
else:
raise RuntimeError("no pathway to m0scan")

if asl_data.ndim == 5:
# XXX: Why specifically check for 5D data? NIFTIs can go up to 8, I think.
raise RuntimeError("Input image (%s) is 5D.", self.inputs.asl_file)

pld = np.array(metadata["PostLabelingDelay"])
multi_pld = pld.size > 1
if multi_pld and pld.size != asl_data.shape[3]:
Expand Down Expand Up @@ -1174,7 +1173,8 @@ def _run_interface(self, runtime):


class _ExtractCBForDeltaMInputSpec(BaseInterfaceInputSpec):
in_asl = File(exists=True, mandatory=True, desc="raw asl file")
asl_file = File(exists=True, mandatory=True, desc="raw asl file")
aslcontext = File(exists=True, mandatory=True, desc="aslcontext TSV file for run.")
in_aslmask = File(exists=True, mandatory=True, desct="asl mask")
file_type = traits.Str(desc="file type, c for cbf, d for deltam", mandatory=True)

Expand All @@ -1195,11 +1195,10 @@ def _run_interface(self, runtime):
suffix="_cbfdeltam",
newpath=runtime.cwd,
)
asl_img = nb.load(self.inputs.in_asl)
asl_img = nb.load(self.inputs.asl_file)
asl_data = asl_img.get_fdata()

# XXX: Not a good way to find the aslcontext file.
aslcontext = pd.read_csv(self.inputs.in_asl.replace("_asl.nii.gz", "_aslcontext.tsv"))
aslcontext = pd.read_table(self.inputs.aslcontext)
vol_types = aslcontext["volume_type"].tolist()
control_volume_idx = [i for i, vol_type in enumerate(vol_types) if vol_type == "control"]
label_volume_idx = [i for i, vol_type in enumerate(vol_types) if vol_type == "label"]
Expand Down
Loading

0 comments on commit ea0d0e1

Please sign in to comment.