diff --git a/snakedwi/config/snakebids.yml b/snakedwi/config/snakebids.yml index df72fdb..83ff6ef 100644 --- a/snakedwi/config/snakebids.yml +++ b/snakedwi/config/snakebids.yml @@ -13,11 +13,6 @@ derivatives: False analysis_levels: &analysis_levels - participant -no_topup: False -use_eddy_s2v: False -eddy_no_quad: False -no_bedpost: False -use_syn_sdc: False targets_by_analysis_level: participant: @@ -92,10 +87,16 @@ parse_args: action: 'store_true' default: False - --no_bedpost: - help: 'Disable bedpost (default: %(default)s)' - action: 'store_true' - default: False + --use_bedpost: + help: 'Enable bedpost (disabled by default)' + action: 'store_false' + dest: 'no_bedpost' + default: True + + # --no_bedpost: + # help: 'Disable bedpost (default: %(default)s)' + # action: 'store_true' + # default: True --use_eddy_s2v: @@ -156,6 +157,16 @@ parse_args: action: 'store_true' default: False + # --synthsr_sdc: + #help: "Enable SynthSR+SyN fieldmap-less distortion-correction. This uses SynthSR to transform T1w and b0 images to the same 1mm T1w contrast prior to non-linear registration to obtain a field-map. (default: %(default)s)" + #action: 'store_true' + #default: True + + --no_synthsr_sdc: + help: "Disable SynthSR+SyN fieldmap-less distortion-correction (enabled by default)" + action: 'store_false' + dest: 'use_synthsr_sdc' + default: True #---- to update below this @@ -173,6 +184,9 @@ singularity: python: 'docker://khanlab/pythondeps-snakedwi:v0.2.0' synthstrip: 'docker://freesurfer/synthstrip:1.3' sdcflows: 'docker://nipreps/fmriprep:22.1.1' #can't currently just use sdcflows docker as it is missing freesurfer's mri_robust_template + synthsr: 'docker://akhanf/synthsr:main' + synthmorph: 'docker://freesurfer/synthmorph:1' + @@ -281,6 +295,7 @@ eddy: default_effective_echo_spacing: 0.0001 #if not defined in JSON files #for test data +root: 'results' participant_label: exclude_participant_label: masking_method: b0_synthstrip @@ -292,4 +307,7 @@ use_eddy_gpu: False use_bedpost_gpu: False rigid_dwi_t1_init: 'identity' rigid_dwi_t1_iters: '50x50' -root: 'results' +eddy_no_quad: False +no_topup: False +use_syn_sdc: False +use_synthsr_sdc: True diff --git a/snakedwi/workflow/rules/eddy.smk b/snakedwi/workflow/rules/eddy.smk index 93ee568..4ed7712 100644 --- a/snakedwi/workflow/rules/eddy.smk +++ b/snakedwi/workflow/rules/eddy.smk @@ -101,6 +101,17 @@ def get_eddy_topup_fmap_input(wildcards): **subj_wildcards, ).format(**wildcards), } + elif method == "synthsr": + return { + "fmap": bids( + root=work, + datatype="dwi", + suffix="fmap.nii.gz", + desc="b0", + method="synthSRsdc", + **subj_wildcards, + ).format(**wildcards) + } elif method == "syn": return { "fmap": bids( @@ -130,6 +141,17 @@ def get_eddy_topup_fmap_opt(wildcards, input): root=work, suffix="topup", datatype="dwi", **subj_wildcards ).format(**wildcards) return f"--topup={topup_prefix}" + elif method == "synthsr": + fmap_prefix = bids( + root=work, + datatype="dwi", + suffix="fmap", + desc="b0", + method="synthSRsdc", + **subj_wildcards, + ).format(**wildcards) + return f"--field={fmap_prefix}" + elif method == "syn": fmap_prefix = bids( root=work, diff --git a/snakedwi/workflow/rules/masking_b0_synthstrip.smk b/snakedwi/workflow/rules/masking_b0_synthstrip.smk index 88b21b5..14dee0b 100644 --- a/snakedwi/workflow/rules/masking_b0_synthstrip.smk +++ b/snakedwi/workflow/rules/masking_b0_synthstrip.smk @@ -22,6 +22,8 @@ rule synthstrip_b0: container: config["singularity"]["synthstrip"] threads: 8 + shadow: + "minimal" group: "subj" shell: diff --git a/snakedwi/workflow/rules/reg_dwi_to_t1.smk b/snakedwi/workflow/rules/reg_dwi_to_t1.smk index a92380c..4a8556c 100644 --- a/snakedwi/workflow/rules/reg_dwi_to_t1.smk +++ b/snakedwi/workflow/rules/reg_dwi_to_t1.smk @@ -30,6 +30,8 @@ rule synthstrip_t1: container: config["singularity"]["synthstrip"] threads: 8 + shadow: + "minimal" shell: "python3 /freesurfer/mri_synthstrip -i {input.t1} -m {output.mask}" diff --git a/snakedwi/workflow/rules/reg_t1_to_template.smk b/snakedwi/workflow/rules/reg_t1_to_template.smk index 5a518d6..4d274ac 100644 --- a/snakedwi/workflow/rules/reg_t1_to_template.smk +++ b/snakedwi/workflow/rules/reg_t1_to_template.smk @@ -54,8 +54,8 @@ rule convert_template_xfm_ras2itk: datatype="transforms", **subj_wildcards, suffix="xfm.txt", - from_="subject", - to="{template}", + from_="{from}", + to="{to}", desc="{desc}", type_="ras" ), @@ -65,8 +65,8 @@ rule convert_template_xfm_ras2itk: datatype="transforms", **subj_wildcards, suffix="xfm.txt", - from_="subject", - to="{template}", + from_="{from}", + to="{to}", desc="{desc}", type_="itk" ), diff --git a/snakedwi/workflow/rules/sdc.smk b/snakedwi/workflow/rules/sdc.smk index 72014b1..4905fbe 100644 --- a/snakedwi/workflow/rules/sdc.smk +++ b/snakedwi/workflow/rules/sdc.smk @@ -350,6 +350,252 @@ rule syn_sdc: "../scripts/sdcflows_syn.py" +rule run_synthSR: + input: + "{prefix}.nii.gz", + output: + "{prefix}SynthSR.nii.gz", + threads: 8 + group: + "subj" + container: + config["singularity"]["synthsr"] + shadow: + "minimal" + shell: + "python /SynthSR/scripts/predict_command_line.py --cpu --threads {threads} {input} {output}" + + +rule reslice_synthSR_b0: + input: + ref=bids( + root=work, + suffix="b0.nii.gz", + datatype="dwi", + desc="moco", + **subj_wildcards + ), + synthsr=bids( + root=work, + suffix="b0SynthSR.nii.gz", + datatype="dwi", + desc="moco", + **subj_wildcards + ), + output: + synthsr=bids( + root=work, + suffix="b0SynthSRresliced.nii.gz", + datatype="dwi", + desc="moco", + **subj_wildcards + ), + container: + config["singularity"]["itksnap"] + group: + "subj" + shell: + "c3d {input.ref} {input.synthsr} -reslice-identity -o {output.synthsr}" + + +rule rigid_reg_t1_to_b0_synthsr: + input: + flo=bids( + root=root, + datatype="anat", + **subj_wildcards, + desc="preproc", + suffix="T1wSynthSR.nii.gz" + ), + ref=bids( + root=work, + suffix="b0SynthSRresliced.nii.gz", + datatype="dwi", + desc="moco", + **subj_wildcards + ), + output: + warped_subj=bids( + root=work, + suffix="T1wSynthSRreg.nii.gz", + space="rigidb0", + datatype="dwi", + **subj_wildcards + ), + xfm_ras=bids( + root=work, + suffix="xfm.txt", + from_="T1wSynthSR", + to="b0SynthSR", + type_="ras", + desc="rigid", + datatype="transforms", + **subj_wildcards + ), + container: + config["singularity"]["prepdwi"] + log: + bids(root="logs", suffix="rigid_reg_t1_to_b0_synthsr.txt", **subj_wildcards), + group: + "subj" + shell: + "reg_aladin -flo {input.flo} -ref {input.ref} -res {output.warped_subj} -aff {output.xfm_ras} > {log}" + + +def get_restrict_deformation(wildcards, input, output): + checkpoint_output = checkpoints.check_subj_dwi_metadata.get(**wildcards).output[0] + ([peaxis],) = glob_wildcards(os.path.join(checkpoint_output, "PEaxis-{axis}")) + if peaxis == "i": + return "1x0x0" + elif peaxis == "j": + return "0x1x0" + elif peaxis == "k": + return "0x0x1" # unlikely.. + + +rule reg_b0_to_t1_synthsr: + input: + rules.check_subj_dwi_metadata.output, + t1synth=bids( + root=work, + suffix="T1wSynthSRreg.nii.gz", + space="rigidb0", + datatype="dwi", + **subj_wildcards + ), + b0synth=bids( + root=work, + suffix="b0SynthSRresliced.nii.gz", + datatype="dwi", + desc="moco", + **subj_wildcards + ), + params: + general_opts="-d 3 -v", + metric=lambda wildcards, input: f"MeanSquares[{input.t1synth},{input.b0synth}]", + transform="SyN[0.1]", + convergence="100x50x20", + smoothing_sigmas="2x1x0vox", + shrink_factors="4x2x1", + restrict_deformation=get_restrict_deformation, + # "0x1x0", #should be set to the phase encode dir - can read json for this.. + output: + unwarped=bids( + root=work, + suffix="b0SynthSRunwarped.nii.gz", + desc="Syn", + datatype="dwi", + **subj_wildcards + ), + fwd_xfm=bids( + root=root, + suffix="xfm.nii.gz", + from_="b0SynthSR", + to="T1wSynthSR", + type_="itk", + desc="SyN", + datatype="transforms", + **subj_wildcards + ), + shadow: + "minimal" + container: + config["singularity"]["ants"] + log: + bids(root="logs", suffix="reg_b0_to_t1_synthsr.txt", **subj_wildcards), + group: + "subj" + shell: + "antsRegistration {params.general_opts} --metric {params.metric} --convergence {params.convergence} " + " --smoothing-sigmas {params.smoothing_sigmas} --shrink-factors {params.shrink_factors} " + " --transform {params.transform} " + " --output [ants_,{output.unwarped}] --restrict-deformation {params.restrict_deformation} > {log} && " + " mv ants_0Warp.nii.gz {output.fwd_xfm} " + + +rule displacement_field_to_fmap: + input: + disp_field=bids( + root=root, + suffix="xfm.nii.gz", + from_="b0SynthSR", + to="T1wSynthSR", + type_="itk", + desc="SyN", + datatype="transforms", + **subj_wildcards + ), + dwi_nii=lambda wildcards: expand( + bids( + root=work, + suffix="dwi.nii.gz", + datatype="dwi", + **input_wildcards["dwi"] + ), + zip, + **filter_list(input_zip_lists["dwi"], wildcards) + )[0], + dwi_json=lambda wildcards: expand( + bids( + root=work, suffix="dwi.json", datatype="dwi", **input_wildcards["dwi"] + ), + zip, + **filter_list(input_zip_lists["dwi"], wildcards) + )[0], + params: + demean=True, + output: + fmap=bids( + root=work, + datatype="dwi", + suffix="fmap.nii.gz", + desc="b0", + method="synthSRsdc", + **subj_wildcards + ), + container: + config["singularity"]["python"] + script: + "../scripts/displacement_field_to_fmap.py" + + +rule apply_unwarp_synthsr: + input: + b0=bids( + root=work, + suffix="b0.nii.gz", + datatype="dwi", + desc="moco", + **subj_wildcards + ), + xfm=bids( + root=root, + suffix="xfm.nii.gz", + from_="b0SynthSR", + to="T1wSynthSR", + type_="itk", + desc="SyN", + datatype="transforms", + **subj_wildcards + ), + output: + unwarped=bids( + root=work, + suffix="b0.nii.gz", + datatype="dwi", + desc="unwarped", + method="synthSRsdc", + **subj_wildcards + ), + container: + config["singularity"]["ants"] + group: + "subj" + shell: + "antsApplyTransforms -d 3 -i {input.b0} -o {output.unwarped} " + " -r {input.b0} -t {input.xfm} -e 0 -n NearestNeighbor" + + def get_dwi_ref(wildcards): checkpoint_output = checkpoints.check_subj_dwi_metadata.get(**wildcards).output[0] ([method],) = glob_wildcards(os.path.join(checkpoint_output, "sdc-{method}")) @@ -363,6 +609,16 @@ def get_dwi_ref(wildcards): datatype="dwi", **subj_wildcards ) + elif method == "synthsr": + return bids( + root=work, + datatype="dwi", + suffix="b0.nii.gz", + desc="unwarped", + method="synthSRsdc", + **subj_wildcards + ) + elif method == "syn": return bids( root=work, diff --git a/snakedwi/workflow/scripts/check_subj_dwi_metadata.py b/snakedwi/workflow/scripts/check_subj_dwi_metadata.py index 249f032..446bb3a 100644 --- a/snakedwi/workflow/scripts/check_subj_dwi_metadata.py +++ b/snakedwi/workflow/scripts/check_subj_dwi_metadata.py @@ -65,6 +65,13 @@ f"Opposing phase encoding directions not available, {phase_encoding_directions}, using syn for sdc" ) shell("touch {snakemake.output}/sdc-syn") + elif snakemake.config["use_synthsr_sdc"]: + + print( + f"Opposing phase encoding directions not available, {phase_encoding_directions}, using synthSR+Syn for sdc" + ) + shell("touch {snakemake.output}/sdc-synthsr") + else: print( f"Opposing phase encoding directions not available, {phase_encoding_directions}, skipping sdc" @@ -82,3 +89,6 @@ else: print("Disabling eddy s2v in the workflow") shell("touch {snakemake.output}/eddys2v-no") + +print("Writing phase encoding axis") +shell("touch {snakemake.output}/PEaxis-{phase_encoding_axes[0]}") diff --git a/snakedwi/workflow/scripts/displacement_field_to_fmap.py b/snakedwi/workflow/scripts/displacement_field_to_fmap.py new file mode 100644 index 0000000..7a58656 --- /dev/null +++ b/snakedwi/workflow/scripts/displacement_field_to_fmap.py @@ -0,0 +1,32 @@ +import nibabel as nib +import json +from sdcflows.transform import disp_to_fmap +import numpy as np + + +# adapted from sdcflows DisplacementsField2Fieldmap Interface +with open(snakemake.input.dwi_json, "r") as f: + json_dwi = json.load(f) + +# get phenc dir, and calculate required readout time +pe_dir = json_dwi["PhaseEncodingDirection"] +shape_dwi = nib.load(snakemake.input.dwi_nii).get_fdata().shape +n_pe = shape_dwi["ijk".index(pe_dir[0])] +ro_time = json_dwi["EffectiveEchoSpacing"] * (n_pe - 1) + +# pass this on to the sdcflows function +nib_fmap = disp_to_fmap(nib.load(snakemake.input.disp_field), ro_time, pe_dir) + +# demean the fieldmap +if snakemake.params.demean: + data = np.asanyarray(nib_fmap.dataobj) + data -= np.median(data) + + fmapnii = nib_fmap.__class__( + data.astype("float32"), + nib_fmap.affine, + nib_fmap.header, + ) + +# save it +nib_fmap.to_filename(snakemake.output.fmap)