Skip to content

Commit

Permalink
[DLMED] update according to comments
Browse files Browse the repository at this point in the history
Signed-off-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
Nic-Ma committed Dec 21, 2021
1 parent 649a7c5 commit c6c3a35
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
10 changes: 6 additions & 4 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""

import warnings
from copy import deepcopy
from typing import Any, Callable, Mapping, Optional, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -245,16 +246,17 @@ def inverse(self, data):
if not isinstance(data, Mapping):
raise RuntimeError("Inverse only implemented for Mapping (dictionary) data")

d = deepcopy(dict(data))
# loop until we get an index and then break (since they'll all be the same)
key = self.__class__.__name__
if self.trace_key(key) not in data:
if self.trace_key(key) not in d:
raise RuntimeError("can not find the index of transform have been applied.")

# get the index of the applied OneOf transform
index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
index = self.get_most_recent_transform(d, key)[TraceKeys.EXTRA_INFO]["index"]
# and then remove the OneOf transform
self.pop_transform(data, key)
self.pop_transform(d, key)

_transform = self.transforms[index]
# apply the inverse
return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data
return _transform.inverse(d) if isinstance(_transform, InvertibleTransform) else d
36 changes: 17 additions & 19 deletions tests/test_one_of.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,31 +150,29 @@ def _match(a, b):
@parameterized.expand(TEST_INVERSES)
def test_inverse(self, transform, invertible):
data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
key = OneOf.__name__
fwd_data = transform(data)

if invertible:
for k in KEYS:
t = fwd_data[TraceableTransform.trace_key(k)][-1]
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))
t = fwd_data[TraceableTransform.trace_key(key)][-1]
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], key)
# make sure index exists and is in bounds
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))

# call the inverse
fwd_inv_data = transform.inverse(fwd_data)

if invertible:
for k in KEYS:
# check transform was removed
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
)
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
# check transform was removed
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(key)]) < len(fwd_data[TraceableTransform.trace_key(key)])
)
# check data is same as original (and different from forward)
for k, v in data.items():
if invertible:
self.assertEqual(fwd_inv_data[k], v)
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
else:
# if not invertible, should not change the data
self.assertDictEqual(fwd_data, fwd_inv_data)
else:
# if not invertible, should not change the data
self.assertEqual(fwd_inv_data[k], fwd_data[k])

def test_inverse_compose(self):
transform = Compose(
Expand Down

0 comments on commit c6c3a35

Please sign in to comment.