Skip to content

Commit

Permalink
Merge branch 'master' into typeguard_ignore_annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm authored Sep 21, 2024
2 parents d6827dc + 604b08d commit 0b47f77
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 14 deletions.
5 changes: 5 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Version history
This library adheres to
`Semantic Versioning 2.0 <https://semver.org/#semantic-versioning-200>`_.

**UNRELEASED**

- Fixed basic support for intersection protocols
(`#490 <https://github.com/agronholm/typeguard/pull/490>`_; PR by @antonagestam)

**4.3.0** (2024-05-27)

- Added support for checking against static protocols
Expand Down
22 changes: 8 additions & 14 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,19 +654,13 @@ def check_protocol(
else:
return

# Collect a set of methods and non-method attributes present in the protocol
ignored_attrs = set(dir(typing.Protocol)) | {
"__annotations__",
"__non_callable_proto_members__",
}
expected_methods: dict[str, tuple[Any, Any]] = {}
expected_noncallable_members: dict[str, Any] = {}
for attrname in dir(origin_type):
# Skip attributes present in typing.Protocol
if attrname in ignored_attrs:
continue
origin_annotations = typing.get_type_hints(origin_type)

for attrname in typing_extensions.get_protocol_members(origin_type):
member = getattr(origin_type, attrname, None)

member = getattr(origin_type, attrname)
if callable(member):
signature = inspect.signature(member)
argtypes = [
Expand All @@ -681,10 +675,10 @@ def check_protocol(
)
expected_methods[attrname] = argtypes, return_annotation
else:
expected_noncallable_members[attrname] = member

for attrname, annotation in typing.get_type_hints(origin_type).items():
expected_noncallable_members[attrname] = annotation
try:
expected_noncallable_members[attrname] = origin_annotations[attrname]
except KeyError:
expected_noncallable_members[attrname] = member

subject_annotations = typing.get_type_hints(subject)

Expand Down
83 changes: 83 additions & 0 deletions tests/test_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
Dict,
ForwardRef,
FrozenSet,
Iterable,
Iterator,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Protocol,
Sequence,
Set,
Sized,
TextIO,
Tuple,
Type,
Expand Down Expand Up @@ -995,6 +998,86 @@ def test_text_real_file(self, tmp_path: Path):
check_type(f, TextIO)


class TestIntersectingProtocol:
SIT = TypeVar("SIT", covariant=True)

class SizedIterable(
Sized,
Iterable[SIT],
Protocol[SIT],
): ...

@pytest.mark.parametrize(
"subject, predicate_type",
(
pytest.param(
(),
SizedIterable,
id="empty_tuple_unspecialized",
),
pytest.param(
range(2),
SizedIterable,
id="range",
),
pytest.param(
(),
SizedIterable[int],
id="empty_tuple_int_specialized",
),
pytest.param(
(1, 2, 3),
SizedIterable[int],
id="tuple_int_specialized",
),
pytest.param(
("1", "2", "3"),
SizedIterable[str],
id="tuple_str_specialized",
),
),
)
def test_valid_member_passes(self, subject: object, predicate_type: type) -> None:
for _ in range(2): # Makes sure that the cache is also exercised
check_type(subject, predicate_type)

xfail_nested_protocol_checks = pytest.mark.xfail(
reason="false negative due to missing support for nested protocol checks",
)

@pytest.mark.parametrize(
"subject, predicate_type",
(
pytest.param(
(1 for _ in ()),
SizedIterable,
id="generator",
),
pytest.param(
range(2),
SizedIterable[str],
marks=xfail_nested_protocol_checks,
id="range_str_specialized",
),
pytest.param(
(1, 2, 3),
SizedIterable[str],
marks=xfail_nested_protocol_checks,
id="int_tuple_str_specialized",
),
pytest.param(
("1", "2", "3"),
SizedIterable[int],
marks=xfail_nested_protocol_checks,
id="str_tuple_int_specialized",
),
),
)
def test_raises_for_non_member(self, subject: object, predicate_type: type) -> None:
with pytest.raises(TypeCheckError):
check_type(subject, predicate_type)


@pytest.mark.parametrize(
"instantiate, annotation",
[
Expand Down

0 comments on commit 0b47f77

Please sign in to comment.