diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 90d960a6b9..228b73b6c2 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -939,6 +939,12 @@ Utility (Dict) :members: :special-members: __call__ +`RandLambdad` +""""""""""""" +.. autoclass:: RandLambdad + :members: + :special-members: __call__ + `LabelToMaskd` """""""""""""" .. autoclass:: LabelToMaskd diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9eaedd6b15..ebd21a1c45 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -298,6 +298,9 @@ Lambdad, LambdaD, LambdaDict, + RandLambdad, + RandLambdaD, + RandLambdaDict, RepeatChanneld, RepeatChannelD, RepeatChannelDict, diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 951c9dd459..f374b82d76 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -17,7 +17,7 @@ import copy import logging -from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -64,10 +64,12 @@ "CopyItemsd", "ConcatItemsd", "Lambdad", + "RandLambdad", "LabelToMaskd", "FgBgToIndicesd", "ConvertToMultiChannelBasedOnBratsClassesd", "AddExtremePointsChanneld", + "TorchVisiond", "IdentityD", "IdentityDict", "AsChannelFirstD", @@ -76,6 +78,8 @@ "AsChannelLastDict", "AddChannelD", "AddChannelDict", + "RandLambdaD", + "RandLambdaDict", "RepeatChannelD", "RepeatChannelDict", "SplitChannelD", @@ -106,7 +110,6 @@ "ConvertToMultiChannelBasedOnBratsClassesDict", "AddExtremePointsChannelD", "AddExtremePointsChannelDict", - "TorchVisiond", "TorchVisionD", "TorchVisionDict", ] @@ -621,6 +624,27 @@ def __call__(self, data): return d +class RandLambdad(Lambdad, Randomizable): + """ + Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. + It's a randomizable transform so `CacheDataset` will not execute it and cache the results. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + func: Lambda/function to be applied. It also can be a sequence of Callable, + each element corresponds to a key in ``keys``. + overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. + default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. + + For more details, please check :py:class:`monai.transforms.Lambdad`. + + """ + + def randomize(self, data: Any) -> None: + pass + + class LabelToMaskd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.LabelToMask`. @@ -830,3 +854,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc ) = ConvertToMultiChannelBasedOnBratsClassesd AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond +RandLambdaD = RandLambdaDict = RandLambdad diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py new file mode 100644 index 0000000000..359da8857a --- /dev/null +++ b/tests/test_rand_lambdad.py @@ -0,0 +1,48 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from monai.transforms import Randomizable +from monai.transforms.utility.dictionary import RandLambdad + + +class RandTest(Randomizable): + """ + randomisable transform for testing. + """ + + def randomize(self, data=None): + self._a = self.R.random() + + def __call__(self, data): + self.randomize() + return data + self._a + + +class TestRandLambdad(unittest.TestCase): + def test_rand_lambdad_identity(self): + img = np.zeros((10, 10)) + data = {"img": img, "prop": 1.0} + + test_func = RandTest() + test_func.set_random_state(seed=134) + expected = {"img": test_func(data["img"]), "prop": 1.0} + test_func.set_random_state(seed=134) + ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data) + np.testing.assert_allclose(expected["img"], ret["img"]) + np.testing.assert_allclose(expected["prop"], ret["prop"]) + + +if __name__ == "__main__": + unittest.main()