Skip to content

Commit

Permalink
Fix interaction between typing_extensions and collections.abc (#503)
Browse files Browse the repository at this point in the history
Fixes #501

The idea is straightforward: special classes in typing_extensions
that have __extra__ should use a metaclass that fixes the problem in
GenericMeta.__subclasscheck__ on older versions of typing.

Note that overriding __subclasscheck__ tries to mimic the behaviour
in the new versions of typing. (I can't just use super().__subclasscheck__
on unaffected versions, since this changes call stack depth and therefore
breaks a sys._getframe hack on some other versions of typing.)
  • Loading branch information
ilevkivskyi committed Dec 1, 2017
1 parent 2613161 commit 5911b7e
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
10 changes: 10 additions & 0 deletions typing_extensions/src_py2/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,16 @@ def test_no_isinstance(self):

class CollectionsAbcTests(BaseTestCase):

def test_isinstance_collections(self):
self.assertNotIsInstance(1, collections.Mapping)
self.assertNotIsInstance(1, collections.Iterable)
self.assertNotIsInstance(1, collections.Container)
self.assertNotIsInstance(1, collections.Sized)
with self.assertRaises(TypeError):
isinstance(collections.deque(), typing_extensions.Deque[int])
with self.assertRaises(TypeError):
issubclass(collections.Counter, typing_extensions.Counter[str])

def test_contextmanager(self):
@contextlib.contextmanager
def manager():
Expand Down
11 changes: 11 additions & 0 deletions typing_extensions/src_py3/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,17 @@ def test_get_type_hints_ClassVar(self):

class CollectionsAbcTests(BaseTestCase):

def test_isinstance_collections(self):
self.assertNotIsInstance(1, collections_abc.Mapping)
self.assertNotIsInstance(1, collections_abc.Iterable)
self.assertNotIsInstance(1, collections_abc.Container)
self.assertNotIsInstance(1, collections_abc.Sized)
if SUBCLASS_CHECK_FORBIDDEN:
with self.assertRaises(TypeError):
isinstance(collections.deque(), typing_extensions.Deque[int])
with self.assertRaises(TypeError):
issubclass(collections.Counter, typing_extensions.Counter[str])

@skipUnless(ASYNCIO, 'Python 3.5 and multithreading required')
def test_awaitable(self):
ns = {}
Expand Down
47 changes: 44 additions & 3 deletions typing_extensions/src_py3/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,25 +412,57 @@ def _define_guard(type_name):
return False


class _ExtensionsGenericMeta(GenericMeta):
def __subclasscheck__(self, subclass):
"""This mimics a more modern GenericMeta.__subclasscheck__() logic
(that does not have problems with recursion) to work around interactions
between collections, typing, and typing_extensions on older
versions of Python, see https://github.com/python/typing/issues/501.
"""
if sys.version_info[:3] >= (3, 5, 3) or sys.version_info[:3] < (3, 5, 0):
if self.__origin__ is not None:
if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']:
raise TypeError("Parameterized generics cannot be used with class "
"or instance checks")
return False
if not self.__extra__:
return super().__subclasscheck__(subclass)
res = self.__extra__.__subclasshook__(subclass)
if res is not NotImplemented:
return res
if self.__extra__ in subclass.__mro__:
return True
for scls in self.__extra__.__subclasses__():
if isinstance(scls, GenericMeta):
continue
if issubclass(subclass, scls):
return True
return False


if _define_guard('Awaitable'):
class Awaitable(typing.Generic[T_co], extra=collections_abc.Awaitable):
class Awaitable(typing.Generic[T_co], metaclass=_ExtensionsGenericMeta,
extra=collections_abc.Awaitable):
__slots__ = ()


if _define_guard('Coroutine'):
class Coroutine(Awaitable[V_co], typing.Generic[T_co, T_contra, V_co],
metaclass=_ExtensionsGenericMeta,
extra=collections_abc.Coroutine):
__slots__ = ()


if _define_guard('AsyncIterable'):
class AsyncIterable(typing.Generic[T_co],
metaclass=_ExtensionsGenericMeta,
extra=collections_abc.AsyncIterable):
__slots__ = ()


if _define_guard('AsyncIterator'):
class AsyncIterator(AsyncIterable[T_co],
metaclass=_ExtensionsGenericMeta,
extra=collections_abc.AsyncIterator):
__slots__ = ()

Expand All @@ -439,6 +471,7 @@ class AsyncIterator(AsyncIterable[T_co],
Deque = typing.Deque
elif _geqv_defined:
class Deque(collections.deque, typing.MutableSequence[T],
metaclass=_ExtensionsGenericMeta,
extra=collections.deque):
__slots__ = ()

Expand All @@ -448,6 +481,7 @@ def __new__(cls, *args, **kwds):
return _generic_new(collections.deque, cls, *args, **kwds)
else:
class Deque(collections.deque, typing.MutableSequence[T],
metaclass=_ExtensionsGenericMeta,
extra=collections.deque):
__slots__ = ()

Expand All @@ -461,6 +495,7 @@ def __new__(cls, *args, **kwds):
ContextManager = typing.ContextManager
elif hasattr(contextlib, 'AbstractContextManager'):
class ContextManager(typing.Generic[T_co],
metaclass=_ExtensionsGenericMeta,
extra=contextlib.AbstractContextManager):
__slots__ = ()
else:
Expand Down Expand Up @@ -493,6 +528,7 @@ def __subclasshook__(cls, C):
__all__.append('AsyncContextManager')
elif hasattr(contextlib, 'AbstractAsyncContextManager'):
class AsyncContextManager(typing.Generic[T_co],
metaclass=_ExtensionsGenericMeta,
extra=contextlib.AbstractAsyncContextManager):
__slots__ = ()

Expand Down Expand Up @@ -523,6 +559,7 @@ def __subclasshook__(cls, C):
DefaultDict = typing.DefaultDict
elif _geqv_defined:
class DefaultDict(collections.defaultdict, typing.MutableMapping[KT, VT],
metaclass=_ExtensionsGenericMeta,
extra=collections.defaultdict):

__slots__ = ()
Expand All @@ -533,6 +570,7 @@ def __new__(cls, *args, **kwds):
return _generic_new(collections.defaultdict, cls, *args, **kwds)
else:
class DefaultDict(collections.defaultdict, typing.MutableMapping[KT, VT],
metaclass=_ExtensionsGenericMeta,
extra=collections.defaultdict):

__slots__ = ()
Expand Down Expand Up @@ -569,7 +607,7 @@ def __new__(cls, *args, **kwds):
elif _geqv_defined:
class Counter(collections.Counter,
typing.Dict[T, int],
extra=collections.Counter):
metaclass=_ExtensionsGenericMeta, extra=collections.Counter):

__slots__ = ()

Expand All @@ -581,7 +619,7 @@ def __new__(cls, *args, **kwds):
else:
class Counter(collections.Counter,
typing.Dict[T, int],
extra=collections.Counter):
metaclass=_ExtensionsGenericMeta, extra=collections.Counter):

__slots__ = ()

Expand All @@ -598,6 +636,7 @@ def __new__(cls, *args, **kwds):
# ChainMap only exists in 3.3+
if _geqv_defined:
class ChainMap(collections.ChainMap, typing.MutableMapping[KT, VT],
metaclass=_ExtensionsGenericMeta,
extra=collections.ChainMap):

__slots__ = ()
Expand All @@ -608,6 +647,7 @@ def __new__(cls, *args, **kwds):
return _generic_new(collections.ChainMap, cls, *args, **kwds)
else:
class ChainMap(collections.ChainMap, typing.MutableMapping[KT, VT],
metaclass=_ExtensionsGenericMeta,
extra=collections.ChainMap):

__slots__ = ()
Expand All @@ -622,6 +662,7 @@ def __new__(cls, *args, **kwds):

if _define_guard('AsyncGenerator'):
class AsyncGenerator(AsyncIterator[T_co], typing.Generic[T_co, T_contra],
metaclass=_ExtensionsGenericMeta,
extra=collections_abc.AsyncGenerator):
__slots__ = ()

Expand Down

0 comments on commit 5911b7e

Please sign in to comment.