From 2d9c037ad6c26036cf93e723731786ed56279d2a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jan 2020 09:01:00 +0800 Subject: [PATCH] [DLMED] simplify intensity normalization transform for MVP --- monai/data/transforms/intensity_normalizer.py | 45 +++++++------------ monai/data/transforms/transform.py | 27 ----------- 2 files changed, 16 insertions(+), 56 deletions(-) delete mode 100644 monai/data/transforms/transform.py diff --git a/monai/data/transforms/intensity_normalizer.py b/monai/data/transforms/intensity_normalizer.py index 5b66994972d..953498ab3de 100644 --- a/monai/data/transforms/intensity_normalizer.py +++ b/monai/data/transforms/intensity_normalizer.py @@ -9,34 +9,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Hashable - import numpy as np +import monai -from .transform import Transform +export = monai.utils.export("monai.data.transforms") -class IntensityNormalizer(Transform): +@export +class IntensityNormalizer: """Normalize input based on provided args, using calculated mean and std if not provided (shape of subtrahend and divisor must match. if 0, entire volume uses same subtrahend and divisor, otherwise the shape can have dimension 1 for channels). Current implementation can only support 'channel_last' format data. Args: - apply_keys (a hashable key or a tuple/list of hashable keys): run transform on which field of the input data subtrahend (ndarray): the amount to subtract by (usually the mean) divisor (ndarray): the amount to divide by (usually the standard deviation) dtype: output data format """ - def __init__(self, apply_keys, subtrahend=None, divisor=None, dtype=np.float32): - _apply_keys = apply_keys if isinstance(apply_keys, (list, tuple)) else (apply_keys,) - if not _apply_keys: - raise ValueError('must set apply_keys for this transform.') - for key in _apply_keys: - if not isinstance(key, Hashable): - raise ValueError('apply_keys should be a hashable or a sequence of hashables used by data[key]') - self.apply_keys = _apply_keys + def __init__(self, subtrahend=None, divisor=None, dtype=np.float32): if subtrahend is not None or divisor is not None: assert isinstance(subtrahend, np.ndarray) and isinstance(divisor, np.ndarray), \ 'subtrahend and divisor must be set in pair and in numpy array.' @@ -44,19 +36,14 @@ def __init__(self, apply_keys, subtrahend=None, divisor=None, dtype=np.float32): self.divisor = divisor self.dtype = dtype - def __call__(self, data): - assert data is not None and isinstance(data, dict), 'data must be in dict format with keys.' - for key in self.apply_keys: - img = data[key] - assert key in data, 'can not find expected key={} in data.'.format(key) - if self.subtrahend is not None and self.divisor is not None: - img -= self.subtrahend - img /= self.divisor - else: - img -= np.mean(img) - img /= np.std(img) - - if self.dtype != img.dtype: - img = img.astype(self.dtype) - data[key] = img - return data + def __call__(self, img): + if self.subtrahend is not None and self.divisor is not None: + img -= self.subtrahend + img /= self.divisor + else: + img -= np.mean(img) + img /= np.std(img) + + if self.dtype != img.dtype: + img = img.astype(self.dtype) + return img diff --git a/monai/data/transforms/transform.py b/monai/data/transforms/transform.py deleted file mode 100644 index 75cc90926b5..00000000000 --- a/monai/data/transforms/transform.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -class Transform(object): - """An abstract class of a ``Transform`` - A transform is callable that maps data into output data. - """ - - def __call__(self, data): - """This method should return an updated version of ``data``. - One useful case is to create multiple instances of this class and - chain them together to form a more powerful transform: - for transform in transforms: - data = transform(data) - Args: - data (dict): an element which often comes from an iteration over an iterable, - such as``torch.utils.data.Dataset`` - """ - raise NotImplementedError