Skip to content

Commit

Permalink
Merge pull request #124 from ANTsX/ShivaPvs
Browse files Browse the repository at this point in the history
PVS segmentation
  • Loading branch information
ntustison authored Aug 13, 2024
2 parents 196cf39 + 93efce1 commit 70330c7
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 0 deletions.
2 changes: 2 additions & 0 deletions antspynet/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@
from .white_matter_hyperintensity_segmentation import sysu_media_wmh_segmentation
from .white_matter_hyperintensity_segmentation import hypermapp3r_segmentation
from .white_matter_hyperintensity_segmentation import wmh_segmentation
from .white_matter_hyperintensity_segmentation import shiva_pvs_segmentation

from .claustrum_segmentation import claustrum_segmentation
from .hypothalamus_segmentation import hypothalamus_segmentation
from .hippmapp3r_segmentation import hippmapp3r_segmentation
Expand Down
22 changes: 22 additions & 0 deletions antspynet/utilities/get_pretrained_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ def switch_networks(argument):
"mouseT2wBrainParcellation3DNick" : "https://figshare.com/ndownloader/files/44714944",
"mouseT2wBrainParcellation3DTct" : "https://figshare.com/ndownloader/files/47214538",
"mouseSTPTBrainParcellation3DJay" : "https://figshare.com/ndownloader/files/46710592",
"pvs_shiva_t1_0" : "https://figshare.com/ndownloader/files/48363799",
"pvs_shiva_t1_1" : "https://figshare.com/ndownloader/files/48363832",
"pvs_shiva_t1_2" : "https://figshare.com/ndownloader/files/48363814",
"pvs_shiva_t1_3" : "https://figshare.com/ndownloader/files/48363790",
"pvs_shiva_t1_4" : "https://figshare.com/ndownloader/files/48363829",
"pvs_shiva_t1_5" : "https://figshare.com/ndownloader/files/48363823",
"pvs_shiva_t1_flair_0" : "https://figshare.com/ndownloader/files/48363784",
"pvs_shiva_t1_flair_1" : "https://figshare.com/ndownloader/files/48363820",
"pvs_shiva_t1_flair_2" : "https://figshare.com/ndownloader/files/48363796",
"pvs_shiva_t1_flair_3" : "https://figshare.com/ndownloader/files/48363793",
"pvs_shiva_t1_flair_4" : "https://figshare.com/ndownloader/files/48363826",
"functionalLungMri": "https://ndownloader.figshare.com/files/13824167",
"hippMapp3rInitial": "https://ndownloader.figshare.com/files/18068408",
"hippMapp3rRefine": "https://ndownloader.figshare.com/files/18068411",
Expand Down Expand Up @@ -234,6 +245,17 @@ def switch_networks(argument):
"mouseT2wBrainParcellation3DNick",
"mouseT2wBrainParcellation3DTct",
"mouseSTPTBrainParcellation3DJay",
"pvs_shiva_t1_0",
"pvs_shiva_t1_1",
"pvs_shiva_t1_2",
"pvs_shiva_t1_3",
"pvs_shiva_t1_4",
"pvs_shiva_t1_5",
"pvs_shiva_t1_flair_0",
"pvs_shiva_t1_flair_1",
"pvs_shiva_t1_flair_2",
"pvs_shiva_t1_flair_3",
"pvs_shiva_t1_flair_4",
"elBicho",
"functionalLungMri",
"hippMapp3rInitial",
Expand Down
169 changes: 169 additions & 0 deletions antspynet/utilities/white_matter_hyperintensity_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import ants
import numpy as np
import tensorflow as tf
from tensorflow import keras

def sysu_media_wmh_segmentation(flair,
Expand Down Expand Up @@ -633,3 +634,171 @@ def wmh_segmentation(flair,
domain_image_is_mask=True)

return(wmh_probability_image)


def shiva_pvs_segmentation(t1,
flair=None,
which_model="all",
do_preprocessing=True,
antsxnet_cache_directory=None,
verbose=False):

"""
Perform segmentation of perivascular (PVS) or Vircho-Robin spaces (VRS).
https://pubmed.ncbi.nlm.nih.gov/34262443/
with the original implementation available here:
https://github.com/pboutinaud/SHIVA_PVS
Arguments
---------
t1 : ANTsImage
input 3-D T1 brain image (not skull-stripped).
flair : ANTsImage
(Optional) input 3-D FLAIR brain image (not skull-stripped) aligned to the T1 image.
which_model : integer or string
Several models were trained for the case of T1-only or T1/FLAIR image
pairs. One can use a specific single trained model or the average of
the entire ensemble. I.e., options are:
* For T1-only: 0, 1, 2, 3, 4, 5.
* For T1/FLAIR: 0, 1, 2, 3, 4.
* Or "all" for using the entire ensemble.
do_preprocessing : boolean
perform n4 bias correction, intensity truncation, brain extraction.
antsxnet_cache_directory : string
Destination directory for storing the downloaded template and model weights.
Since these can be reused, if is None, these data will be downloaded to a
~/.keras/ANTsXNet/.
verbose : boolean
Print progress to the screen.
Returns
-------
PVS or VRS segmentation probability image
Example
-------
>>> image = ants.image_read("flair.nii.gz")
>>> probability_mask = shiva_pvs_segmentation(image)
"""

from ..utilities import get_pretrained_network
from ..utilities import preprocess_brain_image

################################
#
# Preprocess images
#
################################

t1_preprocessed = None
flair_preprocessed = None

if do_preprocessing:
if verbose:
print("Preprocess image(s).")

t1_preprocessing = preprocess_brain_image(t1,
truncate_intensity=(0.0, 0.99),
brain_extraction_modality="t1",
do_bias_correction=True,
do_denoising=False,
intensity_normalization_type="01",
antsxnet_cache_directory=antsxnet_cache_directory,
verbose=verbose)
brain_mask = ants.threshold_image(t1_preprocessing["brain_mask"], 0.5, 1, 1, 0)
t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask

if flair is not None:
flair_preprocessing = preprocess_brain_image(flair,
truncate_intensity=(0.0, 0.99),
brain_extraction_modality=None,
do_bias_correction=True,
do_denoising=False,
intensity_normalization_type="01",
antsxnet_cache_directory=antsxnet_cache_directory,
verbose=verbose)
flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask

else:
t1_preprocessed = ants.image_clone(t1)
if flair is not None:
flair_preprocessed = ants.image_clone(flair)

image_shape = (160, 214, 176)
reorient_template = ants.from_numpy(np.ones(image_shape), origin=(0, 0, 0),
spacing=(1, 1, 1), direction=np.eye(3))

center_of_mass_template = ants.get_center_of_mass(reorient_template)
center_of_mass_image = ants.get_center_of_mass(t1_preprocessed * 0 + 1)
translation = np.round(np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template))
xfrm = ants.create_ants_transform(transform_type="Euler3DTransform",
center=np.asarray(center_of_mass_template), translation=translation)

t1_preprocessed = ants.apply_ants_transform_to_image(xfrm, t1_preprocessed, reorient_template)
if flair is not None:
flair_preprocessed = ants.apply_ants_transform_to_image(xfrm, flair_preprocessed, reorient_template)

################################
#
# Load models and predict
#
################################

batchY = None
if flair is None:
batchX = np.zeros((1, *image_shape, 1))
batchX[0,:,:,:,0] = t1_preprocessed.numpy()

model_ids = [which_model,]
if which_model == "all":
model_ids = [0, 1, 2, 3, 4, 5]

for i in range(len(model_ids)):
model_file = get_pretrained_network("pvs_shiva_t1_" + str(model_ids[i]),
antsxnet_cache_directory=antsxnet_cache_directory)
if verbose:
print("Loading", model_file)
model = tf.keras.models.load_model(model_file, compile=False, custom_objects={"tf": tf})
if i == 0:
batchY = model.predict(batchX, verbose=verbose)
else:
batchY += model.predict(batchX, verbose=verbose)

batchY /= len(model_ids)

else:
batchX = np.zeros((1, *image_shape, 2))
batchX[0,:,:,:,0] = t1_preprocessed.numpy()
batchX[0,:,:,:,1] = flair_preprocessed.numpy()

model_ids = [which_model,]
if which_model == "all":
model_ids = [0, 1, 2, 3, 4]

for i in range(len(model_ids)):
model_file = get_pretrained_network("pvs_shiva_t1_flair_" + str(model_ids[i]),
antsxnet_cache_directory=antsxnet_cache_directory)
if verbose:
print("Loading", model_file)
model = tf.keras.models.load_model(model_file, compile=False, custom_objects={"tf": tf})
if i == 0:
batchY = model.predict(batchX, verbose=verbose)
else:
batchY += model.predict(batchX, verbose=verbose)

batchY /= len(model_ids)

pvs = ants.from_numpy(np.squeeze(batchY), origin=reorient_template.origin,
spacing=reorient_template.spacing,
direction=reorient_template.direction)
pvs = ants.apply_ants_transform_to_image(xfrm.invert(), pvs, t1)
return pvs

0 comments on commit 70330c7

Please sign in to comment.