Skip to content

Commit

Permalink
471- fixes deprecated args (#3447)
Browse files Browse the repository at this point in the history
* fixes deprecated args

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update based on comments

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Dec 7, 2021
1 parent 98c1c43 commit 29e9ab3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
15 changes: 11 additions & 4 deletions monai/utils/deprecate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import inspect
import sys
import warnings
from functools import wraps
from types import FunctionType
Expand Down Expand Up @@ -62,7 +63,7 @@ def deprecated(

# if version_val.startswith("0+"):
# # version unknown, set version_val to a large value (assuming the latest version)
# version_val = "100"
# version_val = f"{sys.maxsize}"
if since is not None and removed is not None and not version_leq(since, removed):
raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
Expand Down Expand Up @@ -144,14 +145,16 @@ def deprecated_arg(
msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
version_val: (used for testing) version to compare since and removed against, default is MONAI version.
new_name: name of position or keyword argument to replace the deprecated argument.
if it is specified and the signature of the decorated function has a `kwargs`, the value to the
deprecated argument `name` will be removed.
Returns:
Decorated callable which warns or raises exception when deprecated argument used.
"""

if version_val.startswith("0+") or not f"{version_val}".strip()[0].isdigit():
# version unknown, set version_val to a large value (assuming the latest version)
version_val = "100"
version_val = f"{sys.maxsize}"
if since is not None and removed is not None and not version_leq(since, removed):
raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
Expand Down Expand Up @@ -197,9 +200,13 @@ def _wrapper(*args, **kwargs):
# multiple values for new_name using both args and kwargs
kwargs.pop(new_name, None)
binding = sig.bind(*args, **kwargs).arguments

positional_found = name in binding
kw_found = "kwargs" in binding and name in binding["kwargs"]
kw_found = False
for k, param in sig.parameters.items():
if param.kind == inspect.Parameter.VAR_KEYWORD and k in binding and name in binding[k]:
kw_found = True
# if the deprecated arg is found in the **kwargs, it should be removed
kwargs.pop(name, None)

if positional_found or kw_found:
if is_removed:
Expand Down
45 changes: 42 additions & 3 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_warning(self):
def foo2():
pass

print(foo2())
foo2() # should not raise any warnings

def test_warning_milestone(self):
"""Test deprecated decorator with `since` and `removed` set for a milestone version"""
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_arg_warn2(self):
"""Test deprecated_arg decorator with just `since` set."""

@deprecated_arg("b", since=self.prev_version, version_val=self.test_version)
def afoo2(a, **kwargs):
def afoo2(a, **kw):
pass

afoo2(1) # ok when no b provided
Expand Down Expand Up @@ -235,6 +235,19 @@ def afoo4(a, b=None):

self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))

def test_arg_except3_unknown(self):
"""
Test deprecated_arg decorator raises exception with `removed` set in the past.
with unknown version and kwargs
"""

@deprecated_arg("b", removed=self.prev_version, version_val="0+untagged.1.g3131155")
def afoo4(a, b=None, **kwargs):
pass

self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3))

def test_replacement_arg(self):
"""
Test deprecated arg being replaced.
Expand All @@ -245,10 +258,36 @@ def afoo4(a, b=None):
return a

self.assertEqual(afoo4(b=2), 2)
# self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
self.assertEqual(afoo4(1, b=2), 1) # new name is in use
self.assertEqual(afoo4(a=1, b=2), 1) # prefers the new arg

def test_replacement_arg1(self):
"""
Test deprecated arg being replaced with kwargs.
"""

@deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version)
def afoo4(a, *args, **kwargs):
return a

self.assertEqual(afoo4(b=2), 2)
self.assertEqual(afoo4(1, b=2, c=3), 1) # new name is in use
self.assertEqual(afoo4(a=1, b=2, c=3), 1) # prefers the new arg

def test_replacement_arg2(self):
"""
Test deprecated arg (with a default value) being replaced.
"""

@deprecated_arg("b", new_name="a", since=self.prev_version, version_val=self.test_version)
def afoo4(a, b=None, **kwargs):
return a, kwargs

self.assertEqual(afoo4(b=2, c=3), (2, {"c": 3}))
self.assertEqual(afoo4(1, b=2, c=3), (1, {"c": 3})) # new name is in use
self.assertEqual(afoo4(a=1, b=2, c=3), (1, {"c": 3})) # prefers the new arg
self.assertEqual(afoo4(1, 2, c=3), (1, {"c": 3})) # prefers the new positional arg


if __name__ == "__main__":
unittest.main()

0 comments on commit 29e9ab3

Please sign in to comment.