Skip to content

Commit

Permalink
Fix callable protocols (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
MiguelMonteiro authored Nov 7, 2024
1 parent a4b9b3e commit 006d9c4
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Fixed
<https://github.com/omni-us/jsonargparse/pull/608>`__).
- Failure when resolving forward references from dataclass parameter types
(`#611 <https://github.com/omni-us/jsonargparse/pull/611>`__).
- Fix callable protocol inheritance.
(`#599 <https://github.com/omni-us/jsonargparse/pull/599>`__).

Changed
^^^^^^^
Expand Down
21 changes: 19 additions & 2 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,15 +1087,32 @@ def adapt_typehints(
return val


protocol_irrelevant_dunder_methods = {
"__init__",
"__new__",
"__del__",
"__getattr__",
"__getattribute__",
"__setattr__",
"__delattr__",
"__reduce__",
"__reduce_ex__",
"__getstate__",
"__setstate__",
"__subclasshook__",
}


def implements_protocol(value, protocol) -> bool:
from jsonargparse._parameter_resolvers import get_signature_parameters
from jsonargparse._postponed_annotations import get_return_type

if not inspect.isclass(value):
if not inspect.isclass(value) or value is object:
return False
members = 0
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
if name.startswith("_"):
is_dunder = name.startswith("__") and name.endswith("__")
if (not is_dunder and name.startswith("_")) or (is_dunder and name in protocol_irrelevant_dunder_methods):
continue
if not hasattr(value, name):
return False
Expand Down
45 changes: 45 additions & 0 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,51 @@ def test_parse_implements_protocol(parser):
ctx.match("Not a valid subclass of Interface")


# callable protocol tests


class CallableInterface(Protocol):
def __call__(self, items: List[float]) -> List[float]: ...


class ImplementsCallableInterface1:
def __init__(self, batch_size: int):
self.batch_size = batch_size

def __call__(self, items: List[float]) -> List[float]:
return items


class NotImplementsCallableInterface1:
def __call__(self, items: str) -> List[float]:
return []


class NotImplementsCallableInterface2:
def __call__(self, items: List[float], extra: int) -> List[float]:
return items


class NotImplementsCallableInterface3:
def __call__(self, items: List[float]) -> None:
return


@pytest.mark.parametrize(
"expected, value",
[
(True, ImplementsCallableInterface1),
(False, ImplementsCallableInterface1(1)),
(False, NotImplementsCallableInterface1),
(False, NotImplementsCallableInterface2),
(False, NotImplementsCallableInterface3),
(False, object),
],
)
def test_implements_callable_protocol(expected, value):
assert implements_protocol(value, CallableInterface) is expected


# parameter skip tests


Expand Down

0 comments on commit 006d9c4

Please sign in to comment.