diff --git a/monai/data/transforms/intensity_normalizer.py b/monai/data/transforms/intensity_normalizer.py new file mode 100644 index 00000000000..751c3368e50 --- /dev/null +++ b/monai/data/transforms/intensity_normalizer.py @@ -0,0 +1,59 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from .multi_format_transformer import MultiFormatTransformer +from .shape_format import get_shape_format +from .shape_format import get_channel_axis + + +class IntensityNormalizer(MultiFormatTransformer): + """Normalize input based on provided args, using calculated mean and std if not provided + (shape of subtrahend and divisor must match. if 0, entire volume uses same subtrahend and + divisor, otherwise the shape can have dimension 1 for channels). + + Args: + img - the MedicalImage to be processed + subtrahend (ndarray): the amount to subtract by (usually the mean) + divisor (ndarray): the amount to divide by (usually the standard deviation) + """ + + def __init__(self, dtype=np.float32): + MultiFormatTransformer.__init__(self) + self._dtype = dtype + + def _handle_any(self, img, subtrahend=None, divisor=None): + if subtrahend is not None and divisor is not None: + assert isinstance(subtrahend, np.ndarray) + assert isinstance(divisor, np.ndarray) + if subtrahend.ndim == 0 and divisor.ndim == 0: + img -= subtrahend + img /= divisor + else: # we have array or matrix: current implementation, just handle array for channels + shape_format = get_shape_format(img) + assert shape_format is not None, 'can not support this shape format.' + channel_axis = get_channel_axis(shape_format) + assert len(subtrahend.shape) == 1 + assert len(divisor.shape) == 1 + assert subtrahend.shape[0] == img.shape[channel_axis] + assert divisor.shape[0] == img.shape[channel_axis] + img = np.moveaxis(img, channel_axis, -1) + img -= subtrahend + img /= divisor + img = np.moveaxis(img, -1, channel_axis) + else: + img -= np.mean(img) + img /= np.std(img) + + if self._dtype != img.dtype: + img = img.astype(self._dtype) + + return img diff --git a/monai/data/transforms/shape_format.py b/monai/data/transforms/shape_format.py index 2e374b9757d..a5d1b738614 100644 --- a/monai/data/transforms/shape_format.py +++ b/monai/data/transforms/shape_format.py @@ -43,3 +43,19 @@ def get_shape_format(img: np.ndarray): return ShapeFormat.CHWD else: return None + + +def get_channel_axis(fmt): + """Get the channel axis number + + Args: + fmt: a shape format to analyze channel information + + Returns: the channel axis if the format is channeled, or None if not. + + """ + assert type(fmt) == str, 'format must be string.' + for i in range(len(fmt)): + if fmt[i] == 'C': + return i + return None