From 94d25f7c6d9de8b8c9119a7d0854670cba4a69c3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 4 Feb 2021 20:43:24 +0800 Subject: [PATCH 1/8] [DLMED] add RandLambdad transform Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 +++++ monai/transforms/utility/dictionary.py | 13 ++++++++- tests/test_rand_lambdad.py | 37 ++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 tests/test_rand_lambdad.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 90d960a6b9..228b73b6c2 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -939,6 +939,12 @@ Utility (Dict) :members: :special-members: __call__ +`RandLambdad` +""""""""""""" +.. autoclass:: RandLambdad + :members: + :special-members: __call__ + `LabelToMaskd` """""""""""""" .. autoclass:: LabelToMaskd diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 951c9dd459..4c74f934bf 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -17,7 +17,7 @@ import copy import logging -from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union, Any import numpy as np import torch @@ -621,6 +621,16 @@ def __call__(self, data): return d +class RandLambdad(Lambdad, Randomizable): + """ + Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. + It's a randomizable transform so `CacheDataset` will not execute it and cache the results. + + """ + def randomize(self, data: Any) -> None: + pass + + class LabelToMaskd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.LabelToMask`. @@ -830,3 +840,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc ) = ConvertToMultiChannelBasedOnBratsClassesd AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond +RandLambdaD = RandLambdaDict = RandLambdad diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py new file mode 100644 index 0000000000..9808965e59 --- /dev/null +++ b/tests/test_rand_lambdad.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from monai.transforms.utility.dictionary import RandLambdad +from tests.utils import NumpyImageTestCase2D + + +class TestRandLambdad(NumpyImageTestCase2D): + def test_rand_lambdad_identity(self): + img = self.imt + data = {"img": img, "prop": 1.0} + + def noise_func(x): + np.random.seed(123) + return x + np.random.randint(0, 10) + np.random.seed(None) + + expected = {"img": noise_func(data["img"]), "prop": 1.0} + ret = RandLambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) + self.assertTrue(np.allclose(expected["img"], ret["img"])) + self.assertTrue(np.allclose(expected["prop"], ret["prop"])) + + +if __name__ == "__main__": + unittest.main() From 8394882ebd0ec7e6448641ecfa2b502b0fb8ee25 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 4 Feb 2021 20:47:08 +0800 Subject: [PATCH 2/8] [DLMED] add doc-strings Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 4c74f934bf..b56a6a0c92 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -626,6 +626,16 @@ class RandLambdad(Lambdad, Randomizable): Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. It's a randomizable transform so `CacheDataset` will not execute it and cache the results. + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + func: Lambda/function to be applied. It also can be a sequence of Callable, + each element corresponds to a key in ``keys``. + overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. + default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. + + For more details, please check :py:class:`monai.transforms.Lambdad`. + """ def randomize(self, data: Any) -> None: pass From b65fe931cbbb4f39e229f5ec159a68c258d8c4c7 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 4 Feb 2021 12:51:27 +0000 Subject: [PATCH 3/8] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/transforms/utility/dictionary.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index b56a6a0c92..2247a1d1c0 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -17,7 +17,7 @@ import copy import logging -from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union, Any +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -637,6 +637,7 @@ class RandLambdad(Lambdad, Randomizable): For more details, please check :py:class:`monai.transforms.Lambdad`. """ + def randomize(self, data: Any) -> None: pass From d418a211bd87461fb7506e0140b696173e80a429 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 4 Feb 2021 21:54:34 +0800 Subject: [PATCH 4/8] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/__init__.py | 3 +++ monai/transforms/utility/dictionary.py | 5 ++++- tests/test_rand_lambdad.py | 25 +++++++++++++++++++------ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9eaedd6b15..ebd21a1c45 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -298,6 +298,9 @@ Lambdad, LambdaD, LambdaDict, + RandLambdad, + RandLambdaD, + RandLambdaDict, RepeatChanneld, RepeatChannelD, RepeatChannelDict, diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 2247a1d1c0..f374b82d76 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -64,10 +64,12 @@ "CopyItemsd", "ConcatItemsd", "Lambdad", + "RandLambdad", "LabelToMaskd", "FgBgToIndicesd", "ConvertToMultiChannelBasedOnBratsClassesd", "AddExtremePointsChanneld", + "TorchVisiond", "IdentityD", "IdentityDict", "AsChannelFirstD", @@ -76,6 +78,8 @@ "AsChannelLastDict", "AddChannelD", "AddChannelDict", + "RandLambdaD", + "RandLambdaDict", "RepeatChannelD", "RepeatChannelDict", "SplitChannelD", @@ -106,7 +110,6 @@ "ConvertToMultiChannelBasedOnBratsClassesDict", "AddExtremePointsChannelD", "AddExtremePointsChannelDict", - "TorchVisiond", "TorchVisionD", "TorchVisionDict", ] diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 9808965e59..47ad1f73a8 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -14,21 +14,34 @@ import numpy as np from monai.transforms.utility.dictionary import RandLambdad +from monai.transforms import Randomizable + from tests.utils import NumpyImageTestCase2D +class RandTest(Randomizable): + """ + randomisable transform for testing. + """ + + def randomize(self, data=None): + self.set_random_state(seed=134) + self._a = self.R.random() + + def __call__(self, data): + self.randomize() + return data + self._a + + class TestRandLambdad(NumpyImageTestCase2D): def test_rand_lambdad_identity(self): img = self.imt data = {"img": img, "prop": 1.0} - def noise_func(x): - np.random.seed(123) - return x + np.random.randint(0, 10) - np.random.seed(None) + test_func = RandTest() - expected = {"img": noise_func(data["img"]), "prop": 1.0} - ret = RandLambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) + expected = {"img": test_func(data["img"]), "prop": 1.0} + ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data) self.assertTrue(np.allclose(expected["img"], ret["img"])) self.assertTrue(np.allclose(expected["prop"], ret["prop"])) From 3c0a2979d8e2216ceae94beb0c80ec57d4a52c30 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 4 Feb 2021 22:31:11 +0800 Subject: [PATCH 5/8] [DLMED] fix typo Signed-off-by: Nic Ma --- tests/test_rand_lambdad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 47ad1f73a8..417b767e0e 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -27,6 +27,7 @@ class RandTest(Randomizable): def randomize(self, data=None): self.set_random_state(seed=134) self._a = self.R.random() + self.set_random_state(seed=None) def __call__(self, data): self.randomize() From 2c2c0e14c9f67c06e11ecf51fe31c64176e7da35 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 4 Feb 2021 23:15:18 +0800 Subject: [PATCH 6/8] [DLMED] change to rtol=1e-05 Signed-off-by: Nic Ma --- tests/test_rand_rotate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 79f3036454..6a72173055 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -52,7 +52,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, rotated[0]) + np.testing.assert_allclose(expected, rotated[0], rtol=1e-05) class TestRandRotate3D(NumpyImageTestCase3D): From b98f016ff02ac4fe08575b7f67fc048fe57352c5 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 4 Feb 2021 15:19:33 +0000 Subject: [PATCH 7/8] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_rand_lambdad.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 417b767e0e..6a921fa86b 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -13,9 +13,8 @@ import numpy as np -from monai.transforms.utility.dictionary import RandLambdad from monai.transforms import Randomizable - +from monai.transforms.utility.dictionary import RandLambdad from tests.utils import NumpyImageTestCase2D From 93a08d52212ac53acc8b484bfae9aa915438d503 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 4 Feb 2021 17:30:44 +0000 Subject: [PATCH 8/8] fixes seeds Signed-off-by: Wenqi Li --- tests/test_rand_lambdad.py | 14 ++++++-------- tests/test_rand_rotate.py | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index 6a921fa86b..359da8857a 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -15,7 +15,6 @@ from monai.transforms import Randomizable from monai.transforms.utility.dictionary import RandLambdad -from tests.utils import NumpyImageTestCase2D class RandTest(Randomizable): @@ -24,26 +23,25 @@ class RandTest(Randomizable): """ def randomize(self, data=None): - self.set_random_state(seed=134) self._a = self.R.random() - self.set_random_state(seed=None) def __call__(self, data): self.randomize() return data + self._a -class TestRandLambdad(NumpyImageTestCase2D): +class TestRandLambdad(unittest.TestCase): def test_rand_lambdad_identity(self): - img = self.imt + img = np.zeros((10, 10)) data = {"img": img, "prop": 1.0} test_func = RandTest() - + test_func.set_random_state(seed=134) expected = {"img": test_func(data["img"]), "prop": 1.0} + test_func.set_random_state(seed=134) ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data) - self.assertTrue(np.allclose(expected["img"], ret["img"])) - self.assertTrue(np.allclose(expected["prop"], ret["prop"])) + np.testing.assert_allclose(expected["img"], ret["img"]) + np.testing.assert_allclose(expected["prop"], ret["prop"]) if __name__ == "__main__": diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 6a72173055..79f3036454 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -52,7 +52,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(expected, rotated[0], rtol=1e-05) + np.testing.assert_allclose(expected, rotated[0]) class TestRandRotate3D(NumpyImageTestCase3D):