Skip to content

Commit

Permalink
Update preprocess.py
Browse files Browse the repository at this point in the history
Code changes to satisfy codacy (white spaces, variable name, overindentation, etc.)
  • Loading branch information
hamzake authored Oct 30, 2020
1 parent d973fa6 commit e1456c1
Showing 1 changed file with 42 additions and 41 deletions.
83 changes: 42 additions & 41 deletions pymialsrtk/interfaces/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
#
# This software is distributed under the open-source license Modified BSD.

""" PyMIALSRTK preprocessing functions including BTK Non-local-mean denoising, slice intensity correction
"""
PyMIALSRTK preprocessing functions including BTK Non-local-mean denoising, slice intensity correction
slice N4 bias field correction, slice-by-slice correct bias field, intensity standardization,
histogram normalization and both manual or deep learning based automatic brain extraction.
"""

import os
Expand Down Expand Up @@ -58,6 +60,7 @@ class BtkNLMDenoisingOutputSpec(TraitedSpec):
out_file = File(desc='Output denoised image')

class BtkNLMDenoising(BaseInterface):

"""
Runs the non-local mean denoising module: implementation by Rousseau et al. [1]_ of the method proposed by Coupé et al. [2]_.
Expand Down Expand Up @@ -93,7 +96,6 @@ class BtkNLMDenoising(BaseInterface):
>>> nlmDenoise.inputs.weight = 0.2
>>> nlmDenoise.run() # doctest: +SKIP
"""

input_spec = BtkNLMDenoisingInputSpec
output_spec = BtkNLMDenoisingOutputSpec

Expand Down Expand Up @@ -128,12 +130,13 @@ class MultipleBtkNLMDenoisingInputSpec(BaseInterfaceInputSpec):
input_masks = InputMultiPath(File(desc='Input mask filenames', mandatory=False))
weight = traits.Float(0.1, desc='NLM smoothing parameter (0.1 by default)', usedefault=True)
out_postfix = traits.Str("_nlm", desc='Suffix to be added to input image filenames to construst denoised output filenames',usedefault=True)
stacksOrder = traits.List(desc='Order of images index. To ensure images are processed with their correct corresponding mask', mandatory=False)
stacks_order = traits.List(desc='Order of images index. To ensure images are processed with their correct corresponding mask', mandatory=False)

class MultipleBtkNLMDenoisingOutputSpec(TraitedSpec):
output_images = OutputMultiPath(File(desc='Output denoised images'))

class MultipleBtkNLMDenoising(BaseInterface):

"""
Apply the non-local mean (NLM) denoising module on multiple inputs.
NLM denoising implementation by Rousseau et al. [1]_ of the method proposed by Coupé et al. [2]_.
Expand All @@ -155,7 +158,7 @@ class MultipleBtkNLMDenoising(BaseInterface):
weight <float>
smoothing parameter (high beta produces smoother result, default is 0.1)
stacksOrder <list<int>>
stacks_order <list<int>>
order of images index. To ensure images are processed with their correct corresponding mask.
References
Expand All @@ -177,7 +180,6 @@ class MultipleBtkNLMDenoising(BaseInterface):
--------
pymialsrtk.interfaces.preprocess.BtkNLMDenoising
"""

input_spec = MultipleBtkNLMDenoisingInputSpec
output_spec = MultipleBtkNLMDenoisingOutputSpec

Expand Down Expand Up @@ -239,6 +241,7 @@ class MialsrtkCorrectSliceIntensityOutputSpec(TraitedSpec):
out_file = File(desc='Output image with corrected slice intensities')

class MialsrtkCorrectSliceIntensity(BaseInterface):

"""
Runs the MIAL SRTK mean slice intensity correction module [to_be_cited].
Expand All @@ -264,9 +267,7 @@ class MialsrtkCorrectSliceIntensity(BaseInterface):
>>> sliceIntensityCorr.inputs.in_file = 'my_image.nii.gz'
>>> sliceIntensityCorr.inputs.in_mask = 'my_mask.nii.gz'
>>> sliceIntensityCorr.run() # doctest: +SKIP
"""

"""
input_spec = MialsrtkCorrectSliceIntensityInputSpec
output_spec = MialsrtkCorrectSliceIntensityOutputSpec

Expand Down Expand Up @@ -301,6 +302,7 @@ class MultipleMialsrtkCorrectSliceIntensityOutputSpec(TraitedSpec):
output_images = OutputMultiPath(File())

class MultipleMialsrtkCorrectSliceIntensity(BaseInterface):

"""
Apply the MIAL SRTK slice intensity correction module [to_be_cited] on multiple images.
Calls MialsrtkCorrectSliceIntensity interface with a list of images/masks.
Expand Down Expand Up @@ -335,7 +337,6 @@ class MultipleMialsrtkCorrectSliceIntensity(BaseInterface):
------------
pymialsrtk.interfaces.preprocess.MialsrtkCorrectSliceIntensity
"""

input_spec = MultipleMialsrtkCorrectSliceIntensityInputSpec
output_spec = MultipleMialsrtkCorrectSliceIntensityOutputSpec

Expand Down Expand Up @@ -395,6 +396,7 @@ class MialsrtkSliceBySliceN4BiasFieldCorrectionOutputSpec(TraitedSpec):
out_fld_file = File(desc='Filename bias field extracted slice by slice from input image.')

class MialsrtkSliceBySliceN4BiasFieldCorrection(BaseInterface):

"""
Runs the MIAL SRTK slice by slice N4 bias field correction module that implements the method proposed by Tustison et al. [1]_.
Expand Down Expand Up @@ -428,7 +430,6 @@ class MialsrtkSliceBySliceN4BiasFieldCorrection(BaseInterface):
>>> N4biasFieldCorr.inputs.in_mask = 'my_mask.nii.gz'
>>> N4biasFieldCorr.run() # doctest: +SKIP
"""

input_spec = MialsrtkSliceBySliceN4BiasFieldCorrectionInputSpec
output_spec = MialsrtkSliceBySliceN4BiasFieldCorrectionOutputSpec

Expand Down Expand Up @@ -475,6 +476,7 @@ class MultipleMialsrtkSliceBySliceN4BiasFieldCorrectionOutputSpec(TraitedSpec):
output_fields = OutputMultiPath(File())

class MultipleMialsrtkSliceBySliceN4BiasFieldCorrection(BaseInterface):

"""
Runs on multiple images the MIAL SRTK slice by slice N4 bias field correction module that implements the method proposed by Tustison et al. [1]_.
Calls MialsrtkSliceBySliceN4BiasFieldCorrection interface with a list of images/masks.
Expand Down Expand Up @@ -518,7 +520,6 @@ class MultipleMialsrtkSliceBySliceN4BiasFieldCorrection(BaseInterface):
.. [1] Tustison et al.; Medical Imaging, IEEE Transactions, 2010. `(link to paper) <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3071855>`_
"""

input_spec = MultipleMialsrtkSliceBySliceN4BiasFieldCorrectionInputSpec
output_spec = MultipleMialsrtkSliceBySliceN4BiasFieldCorrectionOutputSpec

Expand Down Expand Up @@ -574,6 +575,7 @@ class MialsrtkSliceBySliceCorrectBiasFieldOutputSpec(TraitedSpec):
out_im_file = File(desc='Bias field corrected image')

class MialsrtkSliceBySliceCorrectBiasField(BaseInterface):

"""
Runs the MIAL SRTK independant slice by slice bias field correction module [To_be_cited].
Expand Down Expand Up @@ -607,7 +609,6 @@ class MialsrtkSliceBySliceCorrectBiasField(BaseInterface):
>>> biasFieldCorr.inputs.in_field = 'my_field.nii.gz'
>>> biasFieldCorr.run() # doctest: +SKIP
"""

input_spec = MialsrtkSliceBySliceCorrectBiasFieldInputSpec
output_spec = MialsrtkSliceBySliceCorrectBiasFieldOutputSpec

Expand Down Expand Up @@ -643,6 +644,7 @@ class MultipleMialsrtkSliceBySliceCorrectBiasFieldOutputSpec(TraitedSpec):


class MultipleMialsrtkSliceBySliceCorrectBiasField(BaseInterface):

"""
Runs the MIAL SRTK slice by slice bias field correction module [to_be_cited] on multiple images.
Calls MialsrtkSliceBySliceCorrectBiasField interface with a list of images/masks/fields.
Expand Down Expand Up @@ -680,9 +682,7 @@ class MultipleMialsrtkSliceBySliceCorrectBiasField(BaseInterface):
See also
------------
pymialsrtk.interfaces.preprocess.MialsrtkCorrectSliceIntensity
"""

input_spec = MultipleMialsrtkSliceBySliceCorrectBiasFieldInputSpec
output_spec = MultipleMialsrtkSliceBySliceCorrectBiasFieldOutputSpec

Expand Down Expand Up @@ -737,12 +737,13 @@ class MialsrtkIntensityStandardizationInputSpec(BaseInterfaceInputSpec):
input_images = InputMultiPath(File(desc='files to be corrected for intensity', mandatory=True))
out_postfix = traits.Str("", desc='Suffix to be added to intensity corrected input_images', usedefault=True)
in_max = traits.Float(desc='Maximal intensity',usedefault=False)
stacksOrder = traits.List(desc='Order of images index. To ensure images are processed with their correct corresponding mask', mandatory=False) # ToDo: Can be removed -> Also in pymialsrtk.pipelines.anatomical.srr.AnatomicalPipeline !!!
stacks_order = traits.List(desc='Order of images index. To ensure images are processed with their correct corresponding mask', mandatory=False) # ToDo: Can be removed -> Also in pymialsrtk.pipelines.anatomical.srr.AnatomicalPipeline !!!

class MialsrtkIntensityStandardizationOutputSpec(TraitedSpec):
output_images = OutputMultiPath(File())

class MialsrtkIntensityStandardization(BaseInterface):

"""
Runs the MIAL SRTK intensity standardization module [To_be_cited].
Rescale image intensity by linear transformation
Expand All @@ -769,7 +770,6 @@ class MialsrtkIntensityStandardization(BaseInterface):
>>> intensityStandardization.inputs.input_images = ['image1.nii.gz','image2.nii.gz']
>>> intensityStandardization.run() # doctest: +SKIP
"""

input_spec = MialsrtkIntensityStandardizationInputSpec
output_spec = MialsrtkIntensityStandardizationOutputSpec

Expand Down Expand Up @@ -806,7 +806,8 @@ class MialsrtkHistogramNormalizationInputSpec(BaseInterfaceInputSpec):
bids_dir = Directory(desc='BIDS root directory', mandatory=True, exists=True)
input_images = InputMultiPath(File(desc='Input image filenames to be normalized', mandatory=True))
input_masks = InputMultiPath(File(desc='Input mask filenames', mandatory=False))
out_postfix = traits.Str("_histnorm", desc='Suffix to be added to normalized input image filenames to construct ouptut normalized image filenames', usedefault=True)
out_postfix = traits.Str("_histnorm", desc='Suffix to be added to normalized input image filenames to construct ouptut normalized image filenames',
usedefault=True)
stacksOrder = traits.List(desc='Order of images index. To ensure images are processed with their correct corresponding mask', mandatory=False)


Expand All @@ -815,6 +816,7 @@ class MialsrtkHistogramNormalizationOutputSpec(TraitedSpec):


class MialsrtkHistogramNormalization(BaseInterface):

"""
Runs the MIAL SRTK histogram normalizaton module that implements the method proposed by Nyúl et al. [1]_.
Expand Down Expand Up @@ -849,8 +851,7 @@ class MialsrtkHistogramNormalization(BaseInterface):
>>> histNorm.inputs.out_postfix = '_histnorm'
>>> histNorm.inputs.stacksOrder = [0,1]
>>> histNorm.run() # doctest: +SKIP
"""

"""
input_spec = MialsrtkHistogramNormalizationInputSpec
output_spec = MialsrtkHistogramNormalizationOutputSpec

Expand Down Expand Up @@ -913,6 +914,7 @@ class MialsrtkMaskImageOutputSpec(TraitedSpec):
out_im_file = File(desc='Masked image')

class MialsrtkMaskImage(BaseInterface):

"""
Runs the MIAL SRTK mask image module.
Expand Down Expand Up @@ -940,7 +942,6 @@ class MialsrtkMaskImage(BaseInterface):
>>> maskImg.inputs.out_im_postfix = '_masked'
>>> maskImg.run() # doctest: +SKIP
"""

input_spec = MialsrtkMaskImageInputSpec
output_spec = MialsrtkMaskImageOutputSpec

Expand Down Expand Up @@ -975,6 +976,7 @@ class MultipleMialsrtkMaskImageOutputSpec(TraitedSpec):
output_images = OutputMultiPath(File(desc='Output masked image filenames'))

class MultipleMialsrtkMaskImage(BaseInterface):

"""
Runs the MIAL SRTK mask image module on multiple images.
Calls MialsrtkMaskImage interface with a list of images/masks.
Expand Down Expand Up @@ -1013,8 +1015,7 @@ class MultipleMialsrtkMaskImage(BaseInterface):
See also
------------
pymialsrtk.interfaces.preprocess.MialsrtkMaskImage
"""

"""
input_spec = MultipleMialsrtkMaskImageInputSpec
output_spec = MultipleMialsrtkMaskImageOutputSpec

Expand Down Expand Up @@ -1069,6 +1070,7 @@ class BrainExtractionOutputSpec(TraitedSpec):


class BrainExtraction(BaseInterface):

"""
Runs the automatic brain extraction module based on a 2D U-Net (Ronneberger et al. [1]_)
using the pre-trained weights from Salehi et al. [2]_.
Expand Down Expand Up @@ -1114,7 +1116,6 @@ class BrainExtraction(BaseInterface):
>>> brainmask.inputs.out_postfix = '_brainMask.nii.gz'
>>> brainmask.run() # doctest: +SKIP
"""

input_spec = BrainExtractionInputSpec
output_spec = BrainExtractionOutputSpec

Expand All @@ -1141,7 +1142,7 @@ def _extractBrain(self, dataPath, modelCkptLoc, thresholdLoc, modelCkptSeg, thre
"""

# Step1: Main part brain localization
##### Step2: Brain localization #####
normalize = "local_max"
width = 128
height = 128
Expand All @@ -1160,16 +1161,16 @@ def _extractBrain(self, dataPath, modelCkptLoc, thresholdLoc, modelCkptSeg, thre
fy=height)

if normalize:
if normalize == "local_max":
images[slice_counter, :, :, 0] = img_patch / np.max(img_patch)
elif normalize == "global_max":
images[slice_counter, :, :, 0] = img_patch / max_val
elif normalize == "mean_std":
images[slice_counter, :, :, 0] = (img_patch-np.mean(img_patch))/np.std(img_patch)
else:
raise ValueError('Please select a valid normalization')
if normalize == "local_max":
images[slice_counter, :, :, 0] = img_patch / np.max(img_patch)
elif normalize == "global_max":
images[slice_counter, :, :, 0] = img_patch / max_val
elif normalize == "mean_std":
images[slice_counter, :, :, 0] = (img_patch-np.mean(img_patch))/np.std(img_patch)
else:
raise ValueError('Please select a valid normalization')
else:
images[slice_counter, :, :, 0] = img_patch
images[slice_counter, :, :, 0] = img_patch

slice_counter += 1

Expand Down Expand Up @@ -1276,7 +1277,7 @@ def _extractBrain(self, dataPath, modelCkptLoc, thresholdLoc, modelCkptSeg, thre
y_beg = med_y-half_max_y-border_y
y_end = med_y+half_max_y+border_y

# Step2: Brain segmentation
##### Step2: Brain segmentation #####
width = 96
height = 96

Expand All @@ -1287,12 +1288,12 @@ def _extractBrain(self, dataPath, modelCkptLoc, thresholdLoc, modelCkptSeg, thre
img_patch = cv2.resize(image_data[x_beg:x_end, y_beg:y_end, ii], dsize=(width, height))

if normalize:
if normalize == "local_max":
images[slice_counter, :, :, 0] = img_patch / np.max(img_patch)
elif normalize == "mean_std":
images[slice_counter, :, :, 0] = (img_patch-np.mean(img_patch))/np.std(img_patch)
else:
raise ValueError('Please select a valid normalization')
if normalize == "local_max":
images[slice_counter, :, :, 0] = img_patch / np.max(img_patch)
elif normalize == "mean_std":
images[slice_counter, :, :, 0] = (img_patch-np.mean(img_patch))/np.std(img_patch)
else:
raise ValueError('Please select a valid normalization')
else:
images[slice_counter, :, :, 0] = img_patch

Expand Down Expand Up @@ -1346,7 +1347,7 @@ def _extractBrain(self, dataPath, modelCkptLoc, thresholdLoc, modelCkptSeg, thre
pred = conv_2d(conv9, 2, 1, activation='linear', padding='valid')

with tf.Session(graph=g) as sess_test_seg:
# Restore the model
# Restore the model
tf_saver = tf.train.Saver()
tf_saver.restore(sess_test_seg, modelCkptSeg)

Expand Down

0 comments on commit e1456c1

Please sign in to comment.