diff --git a/antspynet/utilities/__init__.py b/antspynet/utilities/__init__.py index bdb231e..78633cf 100644 --- a/antspynet/utilities/__init__.py +++ b/antspynet/utilities/__init__.py @@ -94,6 +94,8 @@ from .white_matter_hyperintensity_segmentation import sysu_media_wmh_segmentation from .white_matter_hyperintensity_segmentation import hypermapp3r_segmentation from .white_matter_hyperintensity_segmentation import wmh_segmentation +from .white_matter_hyperintensity_segmentation import shiva_pvs_segmentation + from .claustrum_segmentation import claustrum_segmentation from .hypothalamus_segmentation import hypothalamus_segmentation from .hippmapp3r_segmentation import hippmapp3r_segmentation diff --git a/antspynet/utilities/get_pretrained_network.py b/antspynet/utilities/get_pretrained_network.py index c6f328b..c95b1f0 100644 --- a/antspynet/utilities/get_pretrained_network.py +++ b/antspynet/utilities/get_pretrained_network.py @@ -113,6 +113,17 @@ def switch_networks(argument): "mouseT2wBrainParcellation3DNick" : "https://figshare.com/ndownloader/files/44714944", "mouseT2wBrainParcellation3DTct" : "https://figshare.com/ndownloader/files/47214538", "mouseSTPTBrainParcellation3DJay" : "https://figshare.com/ndownloader/files/46710592", + "pvs_shiva_t1_0" : "https://figshare.com/ndownloader/files/48363799", + "pvs_shiva_t1_1" : "https://figshare.com/ndownloader/files/48363832", + "pvs_shiva_t1_2" : "https://figshare.com/ndownloader/files/48363814", + "pvs_shiva_t1_3" : "https://figshare.com/ndownloader/files/48363790", + "pvs_shiva_t1_4" : "https://figshare.com/ndownloader/files/48363829", + "pvs_shiva_t1_5" : "https://figshare.com/ndownloader/files/48363823", + "pvs_shiva_t1_flair_0" : "https://figshare.com/ndownloader/files/48363784", + "pvs_shiva_t1_flair_1" : "https://figshare.com/ndownloader/files/48363820", + "pvs_shiva_t1_flair_2" : "https://figshare.com/ndownloader/files/48363796", + "pvs_shiva_t1_flair_3" : "https://figshare.com/ndownloader/files/48363793", + "pvs_shiva_t1_flair_4" : "https://figshare.com/ndownloader/files/48363826", "functionalLungMri": "https://ndownloader.figshare.com/files/13824167", "hippMapp3rInitial": "https://ndownloader.figshare.com/files/18068408", "hippMapp3rRefine": "https://ndownloader.figshare.com/files/18068411", @@ -234,6 +245,17 @@ def switch_networks(argument): "mouseT2wBrainParcellation3DNick", "mouseT2wBrainParcellation3DTct", "mouseSTPTBrainParcellation3DJay", + "pvs_shiva_t1_0", + "pvs_shiva_t1_1", + "pvs_shiva_t1_2", + "pvs_shiva_t1_3", + "pvs_shiva_t1_4", + "pvs_shiva_t1_5", + "pvs_shiva_t1_flair_0", + "pvs_shiva_t1_flair_1", + "pvs_shiva_t1_flair_2", + "pvs_shiva_t1_flair_3", + "pvs_shiva_t1_flair_4", "elBicho", "functionalLungMri", "hippMapp3rInitial", diff --git a/antspynet/utilities/white_matter_hyperintensity_segmentation.py b/antspynet/utilities/white_matter_hyperintensity_segmentation.py index a49ce9c..12bd9d8 100644 --- a/antspynet/utilities/white_matter_hyperintensity_segmentation.py +++ b/antspynet/utilities/white_matter_hyperintensity_segmentation.py @@ -1,6 +1,7 @@ import ants import numpy as np +import tensorflow as tf from tensorflow import keras def sysu_media_wmh_segmentation(flair, @@ -633,3 +634,171 @@ def wmh_segmentation(flair, domain_image_is_mask=True) return(wmh_probability_image) + + +def shiva_pvs_segmentation(t1, + flair=None, + which_model="all", + do_preprocessing=True, + antsxnet_cache_directory=None, + verbose=False): + + """ + Perform segmentation of perivascular (PVS) or Vircho-Robin spaces (VRS). + + https://pubmed.ncbi.nlm.nih.gov/34262443/ + + with the original implementation available here: + + https://github.com/pboutinaud/SHIVA_PVS + + Arguments + --------- + t1 : ANTsImage + input 3-D T1 brain image (not skull-stripped). + + flair : ANTsImage + (Optional) input 3-D FLAIR brain image (not skull-stripped) aligned to the T1 image. + + which_model : integer or string + Several models were trained for the case of T1-only or T1/FLAIR image + pairs. One can use a specific single trained model or the average of + the entire ensemble. I.e., options are: + * For T1-only: 0, 1, 2, 3, 4, 5. + * For T1/FLAIR: 0, 1, 2, 3, 4. + * Or "all" for using the entire ensemble. + + do_preprocessing : boolean + perform n4 bias correction, intensity truncation, brain extraction. + + antsxnet_cache_directory : string + Destination directory for storing the downloaded template and model weights. + Since these can be reused, if is None, these data will be downloaded to a + ~/.keras/ANTsXNet/. + + verbose : boolean + Print progress to the screen. + + Returns + ------- + PVS or VRS segmentation probability image + + Example + ------- + >>> image = ants.image_read("flair.nii.gz") + >>> probability_mask = shiva_pvs_segmentation(image) + """ + + from ..utilities import get_pretrained_network + from ..utilities import preprocess_brain_image + + ################################ + # + # Preprocess images + # + ################################ + + t1_preprocessed = None + flair_preprocessed = None + + if do_preprocessing: + if verbose: + print("Preprocess image(s).") + + t1_preprocessing = preprocess_brain_image(t1, + truncate_intensity=(0.0, 0.99), + brain_extraction_modality="t1", + do_bias_correction=True, + do_denoising=False, + intensity_normalization_type="01", + antsxnet_cache_directory=antsxnet_cache_directory, + verbose=verbose) + brain_mask = ants.threshold_image(t1_preprocessing["brain_mask"], 0.5, 1, 1, 0) + t1_preprocessed = t1_preprocessing["preprocessed_image"] * brain_mask + + if flair is not None: + flair_preprocessing = preprocess_brain_image(flair, + truncate_intensity=(0.0, 0.99), + brain_extraction_modality=None, + do_bias_correction=True, + do_denoising=False, + intensity_normalization_type="01", + antsxnet_cache_directory=antsxnet_cache_directory, + verbose=verbose) + flair_preprocessed = flair_preprocessing["preprocessed_image"] * brain_mask + + else: + t1_preprocessed = ants.image_clone(t1) + if flair is not None: + flair_preprocessed = ants.image_clone(flair) + + image_shape = (160, 214, 176) + reorient_template = ants.from_numpy(np.ones(image_shape), origin=(0, 0, 0), + spacing=(1, 1, 1), direction=np.eye(3)) + + center_of_mass_template = ants.get_center_of_mass(reorient_template) + center_of_mass_image = ants.get_center_of_mass(t1_preprocessed * 0 + 1) + translation = np.round(np.asarray(center_of_mass_image) - np.asarray(center_of_mass_template)) + xfrm = ants.create_ants_transform(transform_type="Euler3DTransform", + center=np.asarray(center_of_mass_template), translation=translation) + + t1_preprocessed = ants.apply_ants_transform_to_image(xfrm, t1_preprocessed, reorient_template) + if flair is not None: + flair_preprocessed = ants.apply_ants_transform_to_image(xfrm, flair_preprocessed, reorient_template) + + ################################ + # + # Load models and predict + # + ################################ + + batchY = None + if flair is None: + batchX = np.zeros((1, *image_shape, 1)) + batchX[0,:,:,:,0] = t1_preprocessed.numpy() + + model_ids = [which_model,] + if which_model == "all": + model_ids = [0, 1, 2, 3, 4, 5] + + for i in range(len(model_ids)): + model_file = get_pretrained_network("pvs_shiva_t1_" + str(model_ids[i]), + antsxnet_cache_directory=antsxnet_cache_directory) + if verbose: + print("Loading", model_file) + model = tf.keras.models.load_model(model_file, compile=False, custom_objects={"tf": tf}) + if i == 0: + batchY = model.predict(batchX, verbose=verbose) + else: + batchY += model.predict(batchX, verbose=verbose) + + batchY /= len(model_ids) + + else: + batchX = np.zeros((1, *image_shape, 2)) + batchX[0,:,:,:,0] = t1_preprocessed.numpy() + batchX[0,:,:,:,1] = flair_preprocessed.numpy() + + model_ids = [which_model,] + if which_model == "all": + model_ids = [0, 1, 2, 3, 4] + + for i in range(len(model_ids)): + model_file = get_pretrained_network("pvs_shiva_t1_flair_" + str(model_ids[i]), + antsxnet_cache_directory=antsxnet_cache_directory) + if verbose: + print("Loading", model_file) + model = tf.keras.models.load_model(model_file, compile=False, custom_objects={"tf": tf}) + if i == 0: + batchY = model.predict(batchX, verbose=verbose) + else: + batchY += model.predict(batchX, verbose=verbose) + + batchY /= len(model_ids) + + pvs = ants.from_numpy(np.squeeze(batchY), origin=reorient_template.origin, + spacing=reorient_template.spacing, + direction=reorient_template.direction) + pvs = ants.apply_ants_transform_to_image(xfrm.invert(), pvs, t1) + return pvs + \ No newline at end of file