From b405e43c5fef54f169e7d695d5e6fc243aa96475 Mon Sep 17 00:00:00 2001 From: "Jason R. Coombs" Date: Sun, 27 Aug 2023 16:07:57 -0400 Subject: [PATCH] Initial work toward refreshing the implementation (broken). --- singledispatch/__init__.py | 129 ++++--- test_singledispatch.py | 692 +++++++++++++++++++++++++++---------- 2 files changed, 589 insertions(+), 232 deletions(-) diff --git a/singledispatch/__init__.py b/singledispatch/__init__.py index f0bd42d..3096b95 100644 --- a/singledispatch/__init__.py +++ b/singledispatch/__init__.py @@ -1,8 +1,8 @@ __all__ = ['singledispatch', 'singledispatchmethod'] -from weakref import WeakKeyDictionary +from types import GenericAlias -from .helpers import MappingProxyType, get_cache_token, get_type_hints, update_wrapper +from .helpers import get_cache_token, update_wrapper ################################################################################ # singledispatch() - single-dispatch generic function decorator @@ -12,7 +12,7 @@ def _c3_merge(sequences): """Merges MROs in *sequences* to a single MRO using the C3 algorithm. - Adapted from http://www.python.org/download/releases/2.3/mro/. + Adapted from https://www.python.org/download/releases/2.3/mro/. """ result = [] @@ -87,7 +87,7 @@ def _c3_mro(cls, abcs=None): ) -def _compose_mro(cls, types): # noqa: C901 +def _compose_mro(cls, types): """Calculates the method resolution order for a given class *cls*. Includes relevant abstract base classes (with their respective bases) from @@ -98,7 +98,12 @@ def _compose_mro(cls, types): # noqa: C901 # Remove entries which are already present in the __mro__ or unrelated. def is_related(typ): - return typ not in bases and hasattr(typ, '__mro__') and issubclass(cls, typ) + return ( + typ not in bases + and hasattr(typ, '__mro__') + and not isinstance(typ, GenericAlias) + and issubclass(cls, typ) + ) types = [n for n in types if is_related(n)] @@ -117,7 +122,7 @@ def is_strict_base(typ): mro = [] for typ in types: found = [] - for sub in filter(_safe, typ.__subclasses__()): + for sub in typ.__subclasses__(): if sub not in bases and issubclass(cls, sub): found.append([s for s in sub.__mro__ if s in type_set]) if not found: @@ -132,13 +137,6 @@ def is_strict_base(typ): return _c3_mro(cls, abcs=mro) -def _safe(class_): - """ - Return if the class is safe for testing as subclass. Ref #2. - """ - return not getattr(class_, '__origin__', None) - - def _find_impl(cls, registry): """Returns the best matching implementation from *registry* for type *cls*. @@ -161,33 +159,14 @@ def _find_impl(cls, registry): and match not in cls.__mro__ and not issubclass(match, t) ): - raise RuntimeError(f"Ambiguous dispatch: {match} or {t}") + raise RuntimeError("Ambiguous dispatch: {} or {}".format(match, t)) break if t in registry: match = t return registry.get(match) -def _validate_annotation(annotation): - """Determine if an annotation is valid for registration. - - An annotation is considered valid for use in registration if it is an - instance of ``type`` and not a generic type from ``typing``. - """ - try: - # In Python earlier than 3.7, the classes in typing are considered - # instances of type, but they invalid for registering single dispatch - # functions so check against GenericMeta instead. - from typing import GenericMeta - - valid = not isinstance(annotation, GenericMeta) - except ImportError: - # In Python 3.7+, classes in typing are not instances of type. - valid = isinstance(annotation, type) - return valid - - -def singledispatch(func): # noqa: C901 +def singledispatch(func): """Single-dispatch generic function decorator. Transforms a function into a generic function, which can have different @@ -196,13 +175,14 @@ def singledispatch(func): # noqa: C901 implementations can be registered using the register() attribute of the generic function. """ - registry = {} - dispatch_cache = WeakKeyDictionary() - - def ns(): - pass + # There are many programs that use functools without singledispatch, so we + # trade-off making singledispatch marginally slower for the benefit of + # making start-up of such applications slightly faster. + import types, weakref # noqa: E401 - ns.cache_token = None + registry = {} + dispatch_cache = weakref.WeakKeyDictionary() + cache_token = None def dispatch(cls): """generic_func.dispatch(cls) -> @@ -211,11 +191,12 @@ def dispatch(cls): for the given *cls* registered on *generic_func*. """ - if ns.cache_token is not None: + nonlocal cache_token + if cache_token is not None: current_token = get_cache_token() - if ns.cache_token != current_token: + if cache_token != current_token: dispatch_cache.clear() - ns.cache_token = current_token + cache_token = current_token try: impl = dispatch_cache[cls] except KeyError: @@ -226,15 +207,36 @@ def dispatch(cls): dispatch_cache[cls] = impl return impl + def _is_union_type(cls): + from typing import get_origin, Union + + return get_origin(cls) in {Union, types.UnionType} + + def _is_valid_dispatch_type(cls): + if isinstance(cls, type): + return True + from typing import get_args + + return _is_union_type(cls) and all( + isinstance(arg, type) for arg in get_args(cls) + ) + def register(cls, func=None): """generic_func.register(cls, func) -> func Registers a new implementation for the given *cls* on a *generic_func*. """ - if func is None: - if isinstance(cls, type): + nonlocal cache_token + if _is_valid_dispatch_type(cls): + if func is None: return lambda f: register(cls, f) + else: + if func is not None: + raise TypeError( + f"Invalid first argument to `register()`. " + f"{cls!r} is not a class or union type." + ) ann = getattr(cls, '__annotations__', {}) if not ann: raise TypeError( @@ -244,22 +246,37 @@ def register(cls, func=None): ) func = cls + # only import typing if annotation parsing is necessary + from typing import get_type_hints + argname, cls = next(iter(get_type_hints(func).items())) - if not _validate_annotation(cls): - raise TypeError( - f"Invalid annotation for {argname!r}. " f"{cls!r} is not a class." - ) - registry[cls] = func - if ns.cache_token is None and hasattr(cls, '__abstractmethods__'): - ns.cache_token = get_cache_token() + if not _is_valid_dispatch_type(cls): + if _is_union_type(cls): + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} not all arguments are classes." + ) + else: + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} is not a class." + ) + + if _is_union_type(cls): + from typing import get_args + + for arg in get_args(cls): + registry[arg] = func + else: + registry[cls] = func + if cache_token is None and hasattr(cls, '__abstractmethods__'): + cache_token = get_cache_token() dispatch_cache.clear() return func def wrapper(*args, **kw): if not args: - raise TypeError( - '{} requires at least ' '1 positional argument'.format(funcname) - ) + raise TypeError(f'{funcname} requires at least ' '1 positional argument') return dispatch(args[0].__class__)(*args, **kw) @@ -267,7 +284,7 @@ def wrapper(*args, **kw): registry[object] = func wrapper.register = register wrapper.dispatch = dispatch - wrapper.registry = MappingProxyType(registry) + wrapper.registry = types.MappingProxyType(registry) wrapper._clear_cache = dispatch_cache.clear update_wrapper(wrapper, func) return wrapper diff --git a/test_singledispatch.py b/test_singledispatch.py index 86e3f16..a82cfce 100644 --- a/test_singledispatch.py +++ b/test_singledispatch.py @@ -3,10 +3,12 @@ import collections import decimal from itertools import permutations -import singledispatch as functools +import singledispatch +import functools as functools_orig from singledispatch.helpers import Support import typing import unittest +import contextlib coll_abc = getattr(collections, 'abc', collections) @@ -21,6 +23,21 @@ del _prefix +class MultiModule: + def __init__(self, *modules): + self.modules = modules + + def __getattr__(self, name): + return next( + getattr(mod, name) + for mod in self.modules + if name in mod.__all__ or mod is functools_orig + ) + + +functools = MultiModule(singledispatch, functools_orig) + + class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @functools.singledispatch @@ -72,7 +89,7 @@ def g(obj): @g.register(int) def g_int(i): - return "int {}".format(i) + return "int %s" % (i,) self.assertEqual(g(""), "base") self.assertEqual(g(12), "int 12") @@ -88,7 +105,8 @@ def g(obj): return "Test" self.assertEqual(g.__name__, "g") - self.assertEqual(g.__doc__, "Simple test") + if sys.flags.optimize < 2: + self.assertEqual(g.__doc__, "Simple test") @unittest.skipUnless(decimal, 'requires _decimal') @support.cpython_only @@ -115,38 +133,40 @@ def _(obj): def test_compose_mro(self): # None of the examples in this test depend on haystack ordering. - c = coll_abc + c = collections.abc mro = functools._compose_mro bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] for haystack in permutations(bases): m = mro(dict, haystack) - expected = _mro_compat( + self.assertEqual( + m, [ dict, c.MutableMapping, c.Mapping, + c.Collection, c.Sized, c.Iterable, c.Container, object, - ] + ], ) - self.assertEqual(m, expected) bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] for haystack in permutations(bases): m = mro(collections.ChainMap, haystack) - expected = _mro_compat( + self.assertEqual( + m, [ collections.ChainMap, c.MutableMapping, c.Mapping, + c.Collection, c.Sized, c.Iterable, c.Container, object, - ] + ], ) - self.assertEqual(m, expected) # If there's a generic function with implementations registered for # both Sized and Container, passing a defaultdict to it results in an @@ -160,7 +180,7 @@ def test_compose_mro(self): ) # MutableSequence below is registered directly on D. In other words, it - # preceeds MutableMapping which means single dispatch will always + # precedes MutableMapping which means single dispatch will always # choose MutableSequence here. class D(collections.defaultdict): pass @@ -168,23 +188,25 @@ class D(collections.defaultdict): c.MutableSequence.register(D) bases = [c.MutableSequence, c.MutableMapping] for haystack in permutations(bases): - m = mro(D, bases) - expected = _mro_compat( + m = mro(D, haystack) + self.assertEqual( + m, [ D, c.MutableSequence, c.Sequence, + c.Reversible, collections.defaultdict, dict, c.MutableMapping, c.Mapping, + c.Collection, c.Sized, c.Iterable, c.Container, object, - ] + ], ) - self.assertEqual(m, expected) # Container and Callable are registered on different base classes and # a generic function supporting both should always pick the Callable @@ -196,25 +218,26 @@ def __call__(self): bases = [c.Sized, c.Callable, c.Container, c.Mapping] for haystack in permutations(bases): m = mro(C, haystack) - expected = _mro_compat( + self.assertEqual( + m, [ C, c.Callable, collections.defaultdict, dict, c.Mapping, + c.Collection, c.Sized, c.Iterable, c.Container, object, - ] + ], ) - self.assertEqual(m, expected) def test_register_abc(self): - c = coll_abc + c = collections.abc d = {"a": "b"} - ls = [1, 2, 3] + l = [1, 2, 3] s = {object(), None} f = frozenset(s) t = (1, 2, 3) @@ -224,107 +247,105 @@ def g(obj): return "base" self.assertEqual(g(d), "base") - self.assertEqual(g(ls), "base") + self.assertEqual(g(l), "base") self.assertEqual(g(s), "base") self.assertEqual(g(f), "base") self.assertEqual(g(t), "base") g.register(c.Sized, lambda obj: "sized") self.assertEqual(g(d), "sized") - self.assertEqual(g(ls), "sized") + self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.MutableMapping, lambda obj: "mutablemapping") self.assertEqual(g(d), "mutablemapping") - self.assertEqual(g(ls), "sized") + self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(collections.ChainMap, lambda obj: "chainmap") self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered - self.assertEqual(g(ls), "sized") + self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.MutableSequence, lambda obj: "mutablesequence") self.assertEqual(g(d), "mutablemapping") - self.assertEqual(g(ls), "mutablesequence") + self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.MutableSet, lambda obj: "mutableset") self.assertEqual(g(d), "mutablemapping") - self.assertEqual(g(ls), "mutablesequence") + self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.Mapping, lambda obj: "mapping") self.assertEqual(g(d), "mutablemapping") # not specific enough - self.assertEqual(g(ls), "mutablesequence") + self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.Sequence, lambda obj: "sequence") self.assertEqual(g(d), "mutablemapping") - self.assertEqual(g(ls), "mutablesequence") + self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sequence") g.register(c.Set, lambda obj: "set") self.assertEqual(g(d), "mutablemapping") - self.assertEqual(g(ls), "mutablesequence") + self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(dict, lambda obj: "dict") self.assertEqual(g(d), "dict") - self.assertEqual(g(ls), "mutablesequence") + self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(list, lambda obj: "list") self.assertEqual(g(d), "dict") - self.assertEqual(g(ls), "list") + self.assertEqual(g(l), "list") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(set, lambda obj: "concrete-set") self.assertEqual(g(d), "dict") - self.assertEqual(g(ls), "list") + self.assertEqual(g(l), "list") self.assertEqual(g(s), "concrete-set") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(frozenset, lambda obj: "frozen-set") self.assertEqual(g(d), "dict") - self.assertEqual(g(ls), "list") + self.assertEqual(g(l), "list") self.assertEqual(g(s), "concrete-set") self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "sequence") g.register(tuple, lambda obj: "tuple") self.assertEqual(g(d), "dict") - self.assertEqual(g(ls), "list") + self.assertEqual(g(l), "list") self.assertEqual(g(s), "concrete-set") self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") def test_c3_abc(self): - c = coll_abc + c = collections.abc mro = functools._c3_mro - class A: + class A(object): pass class B(A): def __len__(self): return 0 # implies Sized - # @c.Container.register - class C: + @c.Container.register + class C(object): pass - c.Container.register(C) - - class D: + class D(object): pass # unrelated class X(D, C, B): @@ -344,11 +365,8 @@ class MetaA(type): def __len__(self): return 0 - """ - class A(metaclass=MetaA): - pass - """ - A = MetaA('A', (), {}) + class A(metaclass=MetaA): + pass class AA(A): pass @@ -364,32 +382,32 @@ def _(a): aa = AA() self.assertEqual(fun(aa), 'fun A') - def test_mro_conflicts(self): # noqa: C901 - c = coll_abc + def test_mro_conflicts(self): + c = collections.abc @functools.singledispatch def g(arg): return "base" - class Oh(c.Sized): + class O(c.Sized): def __len__(self): return 0 - o = Oh() + o = O() self.assertEqual(g(o), "base") g.register(c.Iterable, lambda arg: "iterable") g.register(c.Container, lambda arg: "container") g.register(c.Sized, lambda arg: "sized") g.register(c.Set, lambda arg: "set") self.assertEqual(g(o), "sized") - c.Iterable.register(Oh) + c.Iterable.register(O) self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ - c.Container.register(Oh) + c.Container.register(O) self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ - c.Set.register(Oh) + c.Set.register(O) self.assertEqual(g(o), "set") # because c.Set is a subclass of - # c.Sized and c.Container + # c.Sized and c.Container class P: pass @@ -404,13 +422,13 @@ class P: str(re_one.exception), ( ( - "Ambiguous dispatch: " - "or " - ).format(prefix=abcoll_prefix), + "Ambiguous dispatch: " + "or " + ), ( - "Ambiguous dispatch: " - "or " - ).format(prefix=abcoll_prefix), + "Ambiguous dispatch: " + "or " + ), ), ) @@ -424,8 +442,8 @@ def __len__(self): self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ c.Set.register(Q) self.assertEqual(g(q), "set") # because c.Set is a subclass of - # c.Sized and c.Iterable + # c.Sized and c.Iterable @functools.singledispatch def h(arg): return "base" @@ -448,13 +466,13 @@ def _(arg): str(re_two.exception), ( ( - "Ambiguous dispatch: " - "or " - ).format(prefix=abcoll_prefix), + "Ambiguous dispatch: " + "or " + ), ( - "Ambiguous dispatch: " - "or " - ).format(prefix=abcoll_prefix), + "Ambiguous dispatch: " + "or " + ), ), ) @@ -505,13 +523,13 @@ def __len__(self): str(re_three.exception), ( ( - "Ambiguous dispatch: " - "or " - ).format(prefix=abcoll_prefix), + "Ambiguous dispatch: " + "or " + ), ( - "Ambiguous dispatch: " - "or " - ).format(prefix=abcoll_prefix), + "Ambiguous dispatch: " + "or " + ), ), ) @@ -538,14 +556,12 @@ def _(arg): # Sized in the MRO def test_cache_invalidation(self): - try: - from collections import UserDict - except ImportError: - from UserDict import UserDict + from collections import UserDict + import weakref class TracingDict(UserDict): def __init__(self, *args, **kwargs): - UserDict.__init__(self, *args, **kwargs) + super(TracingDict, self).__init__(*args, **kwargs) self.set_ops = [] self.get_ops = [] @@ -561,91 +577,91 @@ def __setitem__(self, key, value): def clear(self): self.data.clear() - _orig_wkd = functools.WeakKeyDictionary td = TracingDict() - functools.WeakKeyDictionary = lambda: td - c = coll_abc + with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): + c = collections.abc - @functools.singledispatch - def g(arg): - return "base" + @functools.singledispatch + def g(arg): + return "base" - d = {} - ls = [] - self.assertEqual(len(td), 0) - self.assertEqual(g(d), "base") - self.assertEqual(len(td), 1) - self.assertEqual(td.get_ops, []) - self.assertEqual(td.set_ops, [dict]) - self.assertEqual(td.data[dict], g.registry[object]) - self.assertEqual(g(ls), "base") - self.assertEqual(len(td), 2) - self.assertEqual(td.get_ops, []) - self.assertEqual(td.set_ops, [dict, list]) - self.assertEqual(td.data[dict], g.registry[object]) - self.assertEqual(td.data[list], g.registry[object]) - self.assertEqual(td.data[dict], td.data[list]) - self.assertEqual(g(ls), "base") - self.assertEqual(g(d), "base") - self.assertEqual(td.get_ops, [list, dict]) - self.assertEqual(td.set_ops, [dict, list]) - g.register(list, lambda arg: "list") - self.assertEqual(td.get_ops, [list, dict]) - self.assertEqual(len(td), 0) - self.assertEqual(g(d), "base") - self.assertEqual(len(td), 1) - self.assertEqual(td.get_ops, [list, dict]) - self.assertEqual(td.set_ops, [dict, list, dict]) - self.assertEqual(td.data[dict], functools._find_impl(dict, g.registry)) - self.assertEqual(g(ls), "list") - self.assertEqual(len(td), 2) - self.assertEqual(td.get_ops, [list, dict]) - self.assertEqual(td.set_ops, [dict, list, dict, list]) - self.assertEqual(td.data[list], functools._find_impl(list, g.registry)) - - class X: - pass + d = {} + l = [] + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "base") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, []) + self.assertEqual(td.set_ops, [dict]) + self.assertEqual(td.data[dict], g.registry[object]) + self.assertEqual(g(l), "base") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, []) + self.assertEqual(td.set_ops, [dict, list]) + self.assertEqual(td.data[dict], g.registry[object]) + self.assertEqual(td.data[list], g.registry[object]) + self.assertEqual(td.data[dict], td.data[list]) + self.assertEqual(g(l), "base") + self.assertEqual(g(d), "base") + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list]) + g.register(list, lambda arg: "list") + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "base") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict]) + self.assertEqual(td.data[dict], functools._find_impl(dict, g.registry)) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list]) + self.assertEqual(td.data[list], functools._find_impl(list, g.registry)) + + class X: + pass - c.MutableMapping.register(X) # Will not invalidate the cache, - # not using ABCs yet. - self.assertEqual(g(d), "base") - self.assertEqual(g(ls), "list") - self.assertEqual(td.get_ops, [list, dict, dict, list]) - self.assertEqual(td.set_ops, [dict, list, dict, list]) - g.register(c.Sized, lambda arg: "sized") - self.assertEqual(len(td), 0) - self.assertEqual(g(d), "sized") - self.assertEqual(len(td), 1) - self.assertEqual(td.get_ops, [list, dict, dict, list]) - self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) - self.assertEqual(g(ls), "list") - self.assertEqual(len(td), 2) - self.assertEqual(td.get_ops, [list, dict, dict, list]) - self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) - self.assertEqual(g(ls), "list") - self.assertEqual(g(d), "sized") - self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) - self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) - g.dispatch(list) - g.dispatch(dict) - self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, list, dict]) - self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) - c.MutableSet.register(X) # Will invalidate the cache. - self.assertEqual(len(td), 2) # Stale cache. - self.assertEqual(g(ls), "list") - self.assertEqual(len(td), 1) - g.register(c.MutableMapping, lambda arg: "mutablemapping") - self.assertEqual(len(td), 0) - self.assertEqual(g(d), "mutablemapping") - self.assertEqual(len(td), 1) - self.assertEqual(g(ls), "list") - self.assertEqual(len(td), 2) - g.register(dict, lambda arg: "dict") - self.assertEqual(g(d), "dict") - self.assertEqual(g(ls), "list") - g._clear_cache() - self.assertEqual(len(td), 0) - functools.WeakKeyDictionary = _orig_wkd + c.MutableMapping.register(X) # Will not invalidate the cache, + # not using ABCs yet. + self.assertEqual(g(d), "base") + self.assertEqual(g(l), "list") + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list]) + g.register(c.Sized, lambda arg: "sized") + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "sized") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + self.assertEqual(g(l), "list") + self.assertEqual(g(d), "sized") + self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + g.dispatch(list) + g.dispatch(dict) + self.assertEqual( + td.get_ops, [list, dict, dict, list, list, dict, list, dict] + ) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + c.MutableSet.register(X) # Will invalidate the cache. + self.assertEqual(len(td), 2) # Stale cache. + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 1) + g.register(c.MutableMapping, lambda arg: "mutablemapping") + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(len(td), 1) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + g.register(dict, lambda arg: "dict") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + g._clear_cache() + self.assertEqual(len(td), 0) def test_annotations(self): @functools.singledispatch @@ -666,10 +682,6 @@ def _(arg: "collections.abc.Sequence"): self.assertEqual(i((1, 2, 3)), "sequence") self.assertEqual(i("str"), "sequence") - if sys.version_info < (3,): - # the rest of this test fails on Python 2 - return - # Registering classes as callables doesn't work with annotations, # you need to pass the type explicitly. @i.register(str) @@ -728,7 +740,7 @@ def _(arg): def _(arg): return isinstance(arg, str) - A() + a = A() self.assertTrue(A.t(0)) self.assertTrue(A.t('')) @@ -783,13 +795,17 @@ def _(cls, arg): self.assertEqual(A.t(0.0).arg, "base") def test_abstractmethod_register(self): - class Abstract(abc.ABCMeta): + class Abstract(metaclass=abc.ABCMeta): @functools.singledispatchmethod @abc.abstractmethod def add(self, x, y): pass self.assertTrue(Abstract.add.__isabstractmethod__) + self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__) + + with self.assertRaises(TypeError): + Abstract() def test_type_ann_register(self): class A: @@ -805,15 +821,194 @@ def _(self, arg: int): def _(self, arg: str): return "str" - _.__annotations__ = dict(arg=str) - t.register(_) - a = A() self.assertEqual(a.t(0), "int") self.assertEqual(a.t(''), "str") self.assertEqual(a.t(0.0), "base") + def test_staticmethod_type_ann_register(self): + class A: + @functools.singledispatchmethod + @staticmethod + def t(arg): + return arg + + @t.register + @staticmethod + def _(arg: int): + return isinstance(arg, int) + + @t.register + @staticmethod + def _(arg: str): + return isinstance(arg, str) + + a = A() + + self.assertTrue(A.t(0)) + self.assertTrue(A.t('')) + self.assertEqual(A.t(0.0), 0.0) + + def test_classmethod_type_ann_register(self): + class A: + def __init__(self, arg): + self.arg = arg + + @functools.singledispatchmethod + @classmethod + def t(cls, arg): + return cls("base") + + @t.register + @classmethod + def _(cls, arg: int): + return cls("int") + + @t.register + @classmethod + def _(cls, arg: str): + return cls("str") + + self.assertEqual(A.t(0).arg, "int") + self.assertEqual(A.t('').arg, "str") + self.assertEqual(A.t(0.0).arg, "base") + + def test_method_wrapping_attributes(self): + class A: + @functools.singledispatchmethod + def func(self, arg: int) -> str: + """My function docstring""" + return str(arg) + + @functools.singledispatchmethod + @classmethod + def cls_func(cls, arg: int) -> str: + """My function docstring""" + return str(arg) + + @functools.singledispatchmethod + @staticmethod + def static_func(arg: int) -> str: + """My function docstring""" + return str(arg) + + for meth in ( + A.func, + A().func, + A.cls_func, + A().cls_func, + A.static_func, + A().static_func, + ): + with self.subTest(meth=meth): + self.assertEqual(meth.__doc__, 'My function docstring') + self.assertEqual(meth.__annotations__['arg'], int) + + self.assertEqual(A.func.__name__, 'func') + self.assertEqual(A().func.__name__, 'func') + self.assertEqual(A.cls_func.__name__, 'cls_func') + self.assertEqual(A().cls_func.__name__, 'cls_func') + self.assertEqual(A.static_func.__name__, 'static_func') + self.assertEqual(A().static_func.__name__, 'static_func') + + def test_double_wrapped_methods(self): + def classmethod_friendly_decorator(func): + wrapped = func.__func__ + + @classmethod + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + return wrapped(*args, **kwargs) + + return wrapper + + class WithoutSingleDispatch: + @classmethod + @contextlib.contextmanager + def cls_context_manager(cls, arg: int) -> str: + try: + yield str(arg) + finally: + return 'Done' + + @classmethod_friendly_decorator + @classmethod + def decorated_classmethod(cls, arg: int) -> str: + return str(arg) + + class WithSingleDispatch: + @functools.singledispatchmethod + @classmethod + @contextlib.contextmanager + def cls_context_manager(cls, arg: int) -> str: + """My function docstring""" + try: + yield str(arg) + finally: + return 'Done' + + @functools.singledispatchmethod + @classmethod_friendly_decorator + @classmethod + def decorated_classmethod(cls, arg: int) -> str: + """My function docstring""" + return str(arg) + + # These are sanity checks + # to test the test itself is working as expected + with WithoutSingleDispatch.cls_context_manager(5) as foo: + without_single_dispatch_foo = foo + + with WithSingleDispatch.cls_context_manager(5) as foo: + single_dispatch_foo = foo + + self.assertEqual(without_single_dispatch_foo, single_dispatch_foo) + self.assertEqual(single_dispatch_foo, '5') + + self.assertEqual( + WithoutSingleDispatch.decorated_classmethod(5), + WithSingleDispatch.decorated_classmethod(5), + ) + + self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5') + + # Behavioural checks now follow + for method_name in ('cls_context_manager', 'decorated_classmethod'): + with self.subTest(method=method_name): + self.assertEqual( + getattr(WithSingleDispatch, method_name).__name__, + getattr(WithoutSingleDispatch, method_name).__name__, + ) + + self.assertEqual( + getattr(WithSingleDispatch(), method_name).__name__, + getattr(WithoutSingleDispatch(), method_name).__name__, + ) + + for meth in ( + WithSingleDispatch.cls_context_manager, + WithSingleDispatch().cls_context_manager, + WithSingleDispatch.decorated_classmethod, + WithSingleDispatch().decorated_classmethod, + ): + with self.subTest(meth=meth): + self.assertEqual(meth.__doc__, 'My function docstring') + self.assertEqual(meth.__annotations__['arg'], int) + + self.assertEqual( + WithSingleDispatch.cls_context_manager.__name__, 'cls_context_manager' + ) + self.assertEqual( + WithSingleDispatch().cls_context_manager.__name__, 'cls_context_manager' + ) + self.assertEqual( + WithSingleDispatch.decorated_classmethod.__name__, 'decorated_classmethod' + ) + self.assertEqual( + WithSingleDispatch().decorated_classmethod.__name__, 'decorated_classmethod' + ) + def test_invalid_registrations(self): msg_prefix = "Invalid first argument to `register()`: " msg_suffix = ( @@ -839,28 +1034,39 @@ def _(arg): def _(arg): return "I forgot to annotate" - scope = "TestSingleDispatch.test_invalid_registrations.." self.assertTrue( - str(exc.exception).startswith(msg_prefix + "._" + ) ) self.assertTrue(str(exc.exception).endswith(msg_suffix)) with self.assertRaises(TypeError) as exc: - # @i.register - # def _(arg: typing.Iterable[str]): - def _(arg): + + @i.register + def _(arg: typing.Iterable[str]): # At runtime, dispatching on generics is impossible. # When registering implementations with singledispatch, avoid # types from `typing`. Instead, annotate with regular types # or ABCs. return "I annotated with a generic collection" - _.__annotations__ = dict(arg=typing.Iterable[str]) - i.register(_) + self.assertTrue(str(exc.exception).startswith("Invalid annotation for 'arg'.")) + self.assertTrue( + str(exc.exception).endswith('typing.Iterable[str] is not a class.') + ) + + with self.assertRaises(TypeError) as exc: + + @i.register + def _(arg: typing.Union[int, typing.Iterable[str]]): + return "Invalid Union" + self.assertTrue(str(exc.exception).startswith("Invalid annotation for 'arg'.")) self.assertTrue( str(exc.exception).endswith( - 'typing.Iterable[' + str.__name__ + '] is not a class.' + 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' ) ) @@ -873,6 +1079,140 @@ def f(*args): with self.assertRaisesRegex(TypeError, msg): f() + def test_union(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | float): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "typing.Union") + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + self.assertEqual(f(1.0), "types.UnionType") + + def test_union_conflict(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | str): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "types.UnionType") # last one wins + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + + def test_union_None(self): + @functools.singledispatch + def typing_union(arg): + return "default" + + @typing_union.register + def _(arg: typing.Union[str, None]): + return "typing.Union" + + self.assertEqual(typing_union(1), "default") + self.assertEqual(typing_union(""), "typing.Union") + self.assertEqual(typing_union(None), "typing.Union") + + @functools.singledispatch + def types_union(arg): + return "default" + + @types_union.register + def _(arg: int | None): + return "types.UnionType" + + self.assertEqual(types_union(""), "default") + self.assertEqual(types_union(1), "types.UnionType") + self.assertEqual(types_union(None), "types.UnionType") + + def test_register_genericalias(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int], lambda arg: "types.GenericAlias") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int], lambda arg: "typing.GenericAlias") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register( + list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)" + ) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register( + typing.List[float] | bytes, + lambda arg: "typing.Union[typing.GenericAlias]", + ) + + self.assertEqual(f([1]), "default") + self.assertEqual(f([1.0]), "default") + self.assertEqual(f(""), "default") + self.assertEqual(f(b""), "default") + + def test_register_genericalias_decorator(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int]) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int]) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int] | str) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int] | str) + + def test_register_genericalias_annotation(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + + @f.register + def _(arg: list[int]): + return "types.GenericAlias" + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + + @f.register + def _(arg: typing.List[float]): + return "typing.GenericAlias" + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + + @f.register + def _(arg: list[int] | str): + return "types.UnionType(types.GenericAlias)" + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + + @f.register + def _(arg: typing.List[float] | bytes): + return "typing.Union[typing.GenericAlias]" + + self.assertEqual(f([1]), "default") + self.assertEqual(f([1.0]), "default") + self.assertEqual(f(""), "default") + self.assertEqual(f(b""), "default") + def _mro_compat(classes): if sys.version_info < (3, 6):