Skip to content

Commit bcdee8c

Browse files
Nic-Mamonai-botwyli
authored
1542 Add RandLambdad transform (#1546)
* [DLMED] add RandLambdad transform Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] add doc-strings Signed-off-by: Nic Ma <nma@nvidia.com> * [MONAI] python code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] fix typo Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] change to rtol=1e-05 Signed-off-by: Nic Ma <nma@nvidia.com> * [MONAI] python code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> * fixes seeds Signed-off-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com> Co-authored-by: Wenqi Li <wenqil@nvidia.com>
1 parent 2021f24 commit bcdee8c

File tree

4 files changed

+84
-2
lines changed

4 files changed

+84
-2
lines changed

docs/source/transforms.rst

+6
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,12 @@ Utility (Dict)
939939
:members:
940940
:special-members: __call__
941941

942+
`RandLambdad`
943+
"""""""""""""
944+
.. autoclass:: RandLambdad
945+
:members:
946+
:special-members: __call__
947+
942948
`LabelToMaskd`
943949
""""""""""""""
944950
.. autoclass:: LabelToMaskd

monai/transforms/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@
298298
Lambdad,
299299
LambdaD,
300300
LambdaDict,
301+
RandLambdad,
302+
RandLambdaD,
303+
RandLambdaDict,
301304
RepeatChanneld,
302305
RepeatChannelD,
303306
RepeatChannelDict,

monai/transforms/utility/dictionary.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import copy
1919
import logging
20-
from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
20+
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
2121

2222
import numpy as np
2323
import torch
@@ -64,10 +64,12 @@
6464
"CopyItemsd",
6565
"ConcatItemsd",
6666
"Lambdad",
67+
"RandLambdad",
6768
"LabelToMaskd",
6869
"FgBgToIndicesd",
6970
"ConvertToMultiChannelBasedOnBratsClassesd",
7071
"AddExtremePointsChanneld",
72+
"TorchVisiond",
7173
"IdentityD",
7274
"IdentityDict",
7375
"AsChannelFirstD",
@@ -76,6 +78,8 @@
7678
"AsChannelLastDict",
7779
"AddChannelD",
7880
"AddChannelDict",
81+
"RandLambdaD",
82+
"RandLambdaDict",
7983
"RepeatChannelD",
8084
"RepeatChannelDict",
8185
"SplitChannelD",
@@ -106,7 +110,6 @@
106110
"ConvertToMultiChannelBasedOnBratsClassesDict",
107111
"AddExtremePointsChannelD",
108112
"AddExtremePointsChannelDict",
109-
"TorchVisiond",
110113
"TorchVisionD",
111114
"TorchVisionDict",
112115
]
@@ -621,6 +624,27 @@ def __call__(self, data):
621624
return d
622625

623626

627+
class RandLambdad(Lambdad, Randomizable):
628+
"""
629+
Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic.
630+
It's a randomizable transform so `CacheDataset` will not execute it and cache the results.
631+
632+
Args:
633+
keys: keys of the corresponding items to be transformed.
634+
See also: :py:class:`monai.transforms.compose.MapTransform`
635+
func: Lambda/function to be applied. It also can be a sequence of Callable,
636+
each element corresponds to a key in ``keys``.
637+
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
638+
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
639+
640+
For more details, please check :py:class:`monai.transforms.Lambdad`.
641+
642+
"""
643+
644+
def randomize(self, data: Any) -> None:
645+
pass
646+
647+
624648
class LabelToMaskd(MapTransform):
625649
"""
626650
Dictionary-based wrapper of :py:class:`monai.transforms.LabelToMask`.
@@ -830,3 +854,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
830854
) = ConvertToMultiChannelBasedOnBratsClassesd
831855
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
832856
TorchVisionD = TorchVisionDict = TorchVisiond
857+
RandLambdaD = RandLambdaDict = RandLambdad

tests/test_rand_lambdad.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
16+
from monai.transforms import Randomizable
17+
from monai.transforms.utility.dictionary import RandLambdad
18+
19+
20+
class RandTest(Randomizable):
21+
"""
22+
randomisable transform for testing.
23+
"""
24+
25+
def randomize(self, data=None):
26+
self._a = self.R.random()
27+
28+
def __call__(self, data):
29+
self.randomize()
30+
return data + self._a
31+
32+
33+
class TestRandLambdad(unittest.TestCase):
34+
def test_rand_lambdad_identity(self):
35+
img = np.zeros((10, 10))
36+
data = {"img": img, "prop": 1.0}
37+
38+
test_func = RandTest()
39+
test_func.set_random_state(seed=134)
40+
expected = {"img": test_func(data["img"]), "prop": 1.0}
41+
test_func.set_random_state(seed=134)
42+
ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data)
43+
np.testing.assert_allclose(expected["img"], ret["img"])
44+
np.testing.assert_allclose(expected["prop"], ret["prop"])
45+
46+
47+
if __name__ == "__main__":
48+
unittest.main()

0 commit comments

Comments
 (0)