Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 12, 2024
1 parent 5bb568a commit 6955a13
Show file tree
Hide file tree
Showing 22 changed files with 4,403 additions and 2,538 deletions.
20 changes: 11 additions & 9 deletions nobrainer/ext/SynthSeg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from . import brain_generator
from . import estimate_priors
from . import evaluate
from . import labels_to_image_model
from . import metrics_model
from . import model_inputs
from . import predict
from . import training_supervised
from . import training
from . import (
brain_generator,
estimate_priors,
evaluate,
labels_to_image_model,
metrics_model,
model_inputs,
predict,
training,
training_supervised,
)
127 changes: 84 additions & 43 deletions nobrainer/ext/SynthSeg/estimate_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,24 @@
License.
"""


# python imports
import os

import numpy as np

try:
from scipy.stats import median_absolute_deviation
except ImportError:
from scipy.stats import median_abs_deviation as median_absolute_deviation


# third-party imports
from ext.lab2im import utils
from ext.lab2im import edit_volumes
from ext.lab2im import edit_volumes, utils


def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True):
def sample_intensity_stats_from_image(
image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True
):
"""This function takes an image and corresponding segmentation as inputs. It estimates the mean and std intensity
for all specified label values. Labels can share the same statistics by being regrouped into K classes.
:param image: image from which to evaluate mean intensity and std deviation.
Expand All @@ -48,19 +50,27 @@ def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_
"""

# reformat labels and classes
labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
labels_list = np.array(
utils.reformat_to_list(labels_list, load_as_numpy=True, dtype="int")
)
if classes_list is not None:
classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
classes_list = np.array(
utils.reformat_to_list(classes_list, load_as_numpy=True, dtype="int")
)
else:
classes_list = np.arange(labels_list.shape[0])
assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'
assert len(classes_list) == len(
labels_list
), "labels and classes lists should have the same length"

# get unique classes
unique_classes, unique_indices = np.unique(classes_list, return_index=True)
n_classes = len(unique_classes)
if not np.array_equal(unique_classes, np.arange(n_classes)):
raise ValueError('classes_list should only contain values between 0 and K-1, '
'where K is the total number of classes. Here K = %d' % n_classes)
raise ValueError(
"classes_list should only contain values between 0 and K-1, "
"where K is the total number of classes. Here K = %d" % n_classes
)

# compute mean/std of specified classes
means = np.zeros(n_classes)
Expand All @@ -80,13 +90,14 @@ def sample_intensity_stats_from_image(image, segmentation, labels_list, classes_
# compute stats for class and put them to the location of corresponding label values
if len(intensities) != 0:
means[idx] = np.nanmedian(intensities)
stds[idx] = median_absolute_deviation(intensities, nan_policy='omit')
stds[idx] = median_absolute_deviation(intensities, nan_policy="omit")

return np.stack([means, stds])


def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_list, classes_list=None, max_channel=3,
rescale=True):
def sample_intensity_stats_from_single_dataset(
image_dir, labels_dir, labels_list, classes_list=None, max_channel=3, rescale=True
):
"""This function aims at estimating the intensity distributions of K different structure types from a set of images.
The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation.
Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the
Expand Down Expand Up @@ -116,30 +127,42 @@ def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_lis
# list files
path_images = utils.list_images_in_folder(image_dir)
path_labels = utils.list_images_in_folder(labels_dir)
assert len(path_images) == len(path_labels), 'image and labels folders do not have the same number of files'
assert len(path_images) == len(
path_labels
), "image and labels folders do not have the same number of files"

# reformat list labels and classes
labels_list = np.array(utils.reformat_to_list(labels_list, load_as_numpy=True, dtype='int'))
labels_list = np.array(
utils.reformat_to_list(labels_list, load_as_numpy=True, dtype="int")
)
if classes_list is not None:
classes_list = np.array(utils.reformat_to_list(classes_list, load_as_numpy=True, dtype='int'))
classes_list = np.array(
utils.reformat_to_list(classes_list, load_as_numpy=True, dtype="int")
)
else:
classes_list = np.arange(labels_list.shape[0])
assert len(classes_list) == len(labels_list), 'labels and classes lists should have the same length'
assert len(classes_list) == len(
labels_list
), "labels and classes lists should have the same length"

# get unique classes
unique_classes, unique_indices = np.unique(classes_list, return_index=True)
n_classes = len(unique_classes)
if not np.array_equal(unique_classes, np.arange(n_classes)):
raise ValueError('classes_list should only contain values between 0 and K-1, '
'where K is the total number of classes. Here K = %d' % n_classes)
raise ValueError(
"classes_list should only contain values between 0 and K-1, "
"where K is the total number of classes. Here K = %d" % n_classes
)

# initialise result arrays
n_dims, n_channels = utils.get_dims(utils.load_volume(path_images[0]).shape, max_channels=max_channel)
n_dims, n_channels = utils.get_dims(
utils.load_volume(path_images[0]).shape, max_channels=max_channel
)
means = np.zeros((len(path_images), n_classes, n_channels))
stds = np.zeros((len(path_images), n_classes, n_channels))

# loop over images
loop_info = utils.LoopInfo(len(path_images), 10, 'estimating', print_time=True)
loop_info = utils.LoopInfo(len(path_images), 10, "estimating", print_time=True)
for idx, (path_im, path_la) in enumerate(zip(path_images, path_labels)):
loop_info.update(idx)

Expand All @@ -154,7 +177,9 @@ def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_lis
im = image[..., channel]
if rescale:
im = edit_volumes.rescale_volume(im)
stats = sample_intensity_stats_from_image(im, la, labels_list, classes_list=classes_list)
stats = sample_intensity_stats_from_image(
im, la, labels_list, classes_list=classes_list
)
means[idx, :, channel] = stats[0, :]
stds[idx, :, channel] = stats[1, :]

Expand All @@ -176,13 +201,15 @@ def sample_intensity_stats_from_single_dataset(image_dir, labels_dir, labels_lis
return prior_means, prior_stds


def build_intensity_stats(list_image_dir,
list_labels_dir,
result_dir,
estimation_labels,
estimation_classes=None,
max_channel=3,
rescale=True):
def build_intensity_stats(
list_image_dir,
list_labels_dir,
result_dir,
estimation_labels,
estimation_classes=None,
max_channel=3,
rescale=True,
):
"""This function aims at estimating the intensity distributions of K different structure types from a set of images.
The distribution of each structure type is modelled as a Gaussian, parametrised by a mean and a standard deviation.
Because the intensity distribution of structures can vary across images, we additionally use Gaussian priors for the
Expand Down Expand Up @@ -219,35 +246,49 @@ def build_intensity_stats(list_image_dir,

# reformat image/labels dir into lists
list_image_dir = utils.reformat_to_list(list_image_dir)
list_labels_dir = utils.reformat_to_list(list_labels_dir, length=len(list_image_dir))
list_labels_dir = utils.reformat_to_list(
list_labels_dir, length=len(list_image_dir)
)

# reformat list estimation labels and classes
estimation_labels = np.array(utils.reformat_to_list(estimation_labels, load_as_numpy=True, dtype='int'))
estimation_labels = np.array(
utils.reformat_to_list(estimation_labels, load_as_numpy=True, dtype="int")
)
if estimation_classes is not None:
estimation_classes = np.array(utils.reformat_to_list(estimation_classes, load_as_numpy=True, dtype='int'))
estimation_classes = np.array(
utils.reformat_to_list(estimation_classes, load_as_numpy=True, dtype="int")
)
else:
estimation_classes = np.arange(estimation_labels.shape[0])
assert len(estimation_classes) == len(estimation_labels), 'estimation labels and classes should be of same length'
assert len(estimation_classes) == len(
estimation_labels
), "estimation labels and classes should be of same length"

# get unique classes
unique_estimation_classes, unique_indices = np.unique(estimation_classes, return_index=True)
unique_estimation_classes, unique_indices = np.unique(
estimation_classes, return_index=True
)
n_classes = len(unique_estimation_classes)
if not np.array_equal(unique_estimation_classes, np.arange(n_classes)):
raise ValueError('estimation_classes should only contain values between 0 and N-1, '
'where K is the total number of classes. Here N = %d' % n_classes)
raise ValueError(
"estimation_classes should only contain values between 0 and N-1, "
"where K is the total number of classes. Here N = %d" % n_classes
)

# loop over dataset
list_datasets_prior_means = list()
list_datasets_prior_stds = list()
for image_dir, labels_dir in zip(list_image_dir, list_labels_dir):

# get prior stats for dataset
tmp_prior_means, tmp_prior_stds = sample_intensity_stats_from_single_dataset(image_dir,
labels_dir,
estimation_labels,
estimation_classes,
max_channel=max_channel,
rescale=rescale)
tmp_prior_means, tmp_prior_stds = sample_intensity_stats_from_single_dataset(
image_dir,
labels_dir,
estimation_labels,
estimation_classes,
max_channel=max_channel,
rescale=rescale,
)

# add stats arrays to list of datasets-wise statistics
list_datasets_prior_means.append(tmp_prior_means)
Expand All @@ -258,7 +299,7 @@ def build_intensity_stats(list_image_dir,
prior_stds = np.concatenate(list_datasets_prior_stds, axis=0)

# save files
np.save(os.path.join(result_dir, 'prior_means.npy'), prior_means)
np.save(os.path.join(result_dir, 'prior_stds.npy'), prior_stds)
np.save(os.path.join(result_dir, "prior_means.npy"), prior_means)
np.save(os.path.join(result_dir, "prior_stds.npy"), prior_stds)

return prior_means, prior_stds
Loading

0 comments on commit 6955a13

Please sign in to comment.