Skip to content

Commit

Permalink
Backport recent improvements to the implementation of Protocol (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood authored Jan 20, 2024
1 parent f84880d commit 004b893
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 27 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Unreleased

- Speedup `issubclass()` checks against simple runtime-checkable protocols by
around 6% (backporting https://github.com/python/cpython/pull/112717, by Alex
Waygood).
- Fix a regression in the implementation of protocols where `typing.Protocol`
classes that were not marked as `@runtime_checkable` would be unnecessarily
introspected, potentially causing exceptions to be raised if the protocol had
problematic members. Patch by Alex Waygood, backporting
https://github.com/python/cpython/pull/113401.

# Release 4.9.0 (December 9, 2023)

This feature release adds `typing_extensions.ReadOnly`, as specified
Expand Down
61 changes: 55 additions & 6 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2817,8 +2817,8 @@ def meth(self): pass # noqa: B027

self.assertNotIn("__protocol_attrs__", vars(NonP))
self.assertNotIn("__protocol_attrs__", vars(NonPR))
self.assertNotIn("__callable_proto_members_only__", vars(NonP))
self.assertNotIn("__callable_proto_members_only__", vars(NonPR))
self.assertNotIn("__non_callable_proto_members__", vars(NonP))
self.assertNotIn("__non_callable_proto_members__", vars(NonPR))

acceptable_extra_attrs = {
'_is_protocol', '_is_runtime_protocol', '__parameters__',
Expand Down Expand Up @@ -2891,11 +2891,26 @@ def __subclasshook__(cls, other):
@skip_if_py312b1
def test_issubclass_fails_correctly(self):
@runtime_checkable
class P(Protocol):
class NonCallableMembers(Protocol):
x = 1

class NotRuntimeCheckable(Protocol):
def callable_member(self) -> int: ...

@runtime_checkable
class RuntimeCheckable(Protocol):
def callable_member(self) -> int: ...

class C: pass
with self.assertRaisesRegex(TypeError, r"issubclass\(\) arg 1 must be a class"):
issubclass(C(), P)

# These three all exercise different code paths,
# but should result in the same error message:
for protocol in NonCallableMembers, NotRuntimeCheckable, RuntimeCheckable:
with self.subTest(proto_name=protocol.__name__):
with self.assertRaisesRegex(
TypeError, r"issubclass\(\) arg 1 must be a class"
):
issubclass(C(), protocol)

def test_defining_generic_protocols(self):
T = TypeVar('T')
Expand Down Expand Up @@ -3456,6 +3471,7 @@ def method(self) -> None: ...

@skip_if_early_py313_alpha
def test_protocol_issubclass_error_message(self):
@runtime_checkable
class Vec2D(Protocol):
x: float
y: float
Expand All @@ -3471,6 +3487,39 @@ def square_norm(self) -> float:
with self.assertRaisesRegex(TypeError, re.escape(expected_error_message)):
issubclass(int, Vec2D)

def test_nonruntime_protocol_interaction_with_evil_classproperty(self):
class classproperty:
def __get__(self, instance, type):
raise RuntimeError("NO")

class Commentable(Protocol):
evil = classproperty()

# recognised as a protocol attr,
# but not actually accessed by the protocol metaclass
# (which would raise RuntimeError) for non-runtime protocols.
# See gh-113320
self.assertEqual(get_protocol_members(Commentable), {"evil"})

def test_runtime_protocol_interaction_with_evil_classproperty(self):
class CustomError(Exception): pass

class classproperty:
def __get__(self, instance, type):
raise CustomError

with self.assertRaises(TypeError) as cm:
@runtime_checkable
class Commentable(Protocol):
evil = classproperty()

exc = cm.exception
self.assertEqual(
exc.args[0],
"Failed to determine whether protocol member 'evil' is a method member"
)
self.assertIs(type(exc.__cause__), CustomError)


class Point2DGeneric(Generic[T], TypedDict):
a: T
Expand Down Expand Up @@ -5263,7 +5312,7 @@ def test_typing_extensions_defers_when_possible(self):
'SupportsRound', 'Unpack',
}
if sys.version_info < (3, 13):
exclude |= {'NamedTuple', 'Protocol'}
exclude |= {'NamedTuple', 'Protocol', 'runtime_checkable'}
if not hasattr(typing, 'ReadOnly'):
exclude |= {'TypedDict', 'is_typeddict'}
for item in typing_extensions.__all__:
Expand Down
101 changes: 80 additions & 21 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def clear_overloads():
"_is_runtime_protocol", "__dict__", "__slots__", "__parameters__",
"__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__",
"__subclasshook__", "__orig_class__", "__init__", "__new__",
"__protocol_attrs__", "__callable_proto_members_only__",
"__protocol_attrs__", "__non_callable_proto_members__",
"__match_args__",
}

Expand Down Expand Up @@ -521,6 +521,22 @@ def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:
raise TypeError('Protocols cannot be instantiated')

def _type_check_issubclass_arg_1(arg):
"""Raise TypeError if `arg` is not an instance of `type`
in `issubclass(arg, <protocol>)`.
In most cases, this is verified by type.__subclasscheck__.
Checking it again unnecessarily would slow down issubclass() checks,
so, we don't perform this check unless we absolutely have to.
For various error paths, however,
we want to ensure that *this* error message is shown to the user
where relevant, rather than a typing.py-specific error message.
"""
if not isinstance(arg, type):
# Same error message as for issubclass(1, int).
raise TypeError('issubclass() arg 1 must be a class')

# Inheriting from typing._ProtocolMeta isn't actually desirable,
# but is necessary to allow typing.Protocol and typing_extensions.Protocol
# to mix without getting TypeErrors about "metaclass conflict"
Expand Down Expand Up @@ -551,11 +567,6 @@ def __init__(cls, *args, **kwargs):
abc.ABCMeta.__init__(cls, *args, **kwargs)
if getattr(cls, "_is_protocol", False):
cls.__protocol_attrs__ = _get_protocol_attrs(cls)
# PEP 544 prohibits using issubclass()
# with protocols that have non-method members.
cls.__callable_proto_members_only__ = all(
callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__
)

def __subclasscheck__(cls, other):
if cls is Protocol:
Expand All @@ -564,26 +575,23 @@ def __subclasscheck__(cls, other):
getattr(cls, '_is_protocol', False)
and not _allow_reckless_class_checks()
):
if not isinstance(other, type):
# Same error message as for issubclass(1, int).
raise TypeError('issubclass() arg 1 must be a class')
if not getattr(cls, '_is_runtime_protocol', False):
_type_check_issubclass_arg_1(other)
raise TypeError(
"Instance and class checks can only be used with "
"@runtime_checkable protocols"
)
if (
not cls.__callable_proto_members_only__
# this attribute is set by @runtime_checkable:
cls.__non_callable_proto_members__
and cls.__dict__.get("__subclasshook__") is _proto_hook
):
non_method_attrs = sorted(
attr for attr in cls.__protocol_attrs__
if not callable(getattr(cls, attr, None))
)
_type_check_issubclass_arg_1(other)
non_method_attrs = sorted(cls.__non_callable_proto_members__)
raise TypeError(
"Protocols with non-method members don't support issubclass()."
f" Non-method members: {str(non_method_attrs)[1:-1]}."
)
if not getattr(cls, '_is_runtime_protocol', False):
raise TypeError(
"Instance and class checks can only be used with "
"@runtime_checkable protocols"
)
return abc.ABCMeta.__subclasscheck__(cls, other)

def __instancecheck__(cls, instance):
Expand All @@ -610,7 +618,8 @@ def __instancecheck__(cls, instance):
val = inspect.getattr_static(instance, attr)
except AttributeError:
break
if val is None and callable(getattr(cls, attr, None)):
# this attribute is set by @runtime_checkable:
if val is None and attr not in cls.__non_callable_proto_members__:
break
else:
return True
Expand Down Expand Up @@ -678,8 +687,58 @@ def __init_subclass__(cls, *args, **kwargs):
cls.__init__ = _no_init


if sys.version_info >= (3, 13):
runtime_checkable = typing.runtime_checkable
else:
def runtime_checkable(cls):
"""Mark a protocol class as a runtime protocol.
Such protocol can be used with isinstance() and issubclass().
Raise TypeError if applied to a non-protocol class.
This allows a simple-minded structural check very similar to
one trick ponies in collections.abc such as Iterable.
For example::
@runtime_checkable
class Closable(Protocol):
def close(self): ...
assert isinstance(open('/some/file'), Closable)
Warning: this will check only the presence of the required methods,
not their type signatures!
"""
if not issubclass(cls, typing.Generic) or not getattr(cls, '_is_protocol', False):
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
' got %r' % cls)
cls._is_runtime_protocol = True

# Only execute the following block if it's a typing_extensions.Protocol class.
# typing.Protocol classes don't need it.
if isinstance(cls, _ProtocolMeta):
# PEP 544 prohibits using issubclass()
# with protocols that have non-method members.
# See gh-113320 for why we compute this attribute here,
# rather than in `_ProtocolMeta.__init__`
cls.__non_callable_proto_members__ = set()
for attr in cls.__protocol_attrs__:
try:
is_callable = callable(getattr(cls, attr, None))
except Exception as e:
raise TypeError(
f"Failed to determine whether protocol member {attr!r} "
"is a method member"
) from e
else:
if not is_callable:
cls.__non_callable_proto_members__.add(attr)

return cls


# The "runtime" alias exists for backwards compatibility.
runtime = runtime_checkable = typing.runtime_checkable
runtime = runtime_checkable


# Our version of runtime-checkable protocols is faster on Python 3.8-3.11
Expand Down

0 comments on commit 004b893

Please sign in to comment.