Skip to content

Commit

Permalink
Initial work toward refreshing the implementation (broken).
Browse files Browse the repository at this point in the history
  • Loading branch information
jaraco committed Aug 27, 2023
1 parent 9d15d16 commit b405e43
Show file tree
Hide file tree
Showing 2 changed files with 589 additions and 232 deletions.
129 changes: 73 additions & 56 deletions singledispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
__all__ = ['singledispatch', 'singledispatchmethod']

from weakref import WeakKeyDictionary
from types import GenericAlias

from .helpers import MappingProxyType, get_cache_token, get_type_hints, update_wrapper
from .helpers import get_cache_token, update_wrapper

################################################################################
# singledispatch() - single-dispatch generic function decorator
Expand All @@ -12,7 +12,7 @@
def _c3_merge(sequences):
"""Merges MROs in *sequences* to a single MRO using the C3 algorithm.
Adapted from http://www.python.org/download/releases/2.3/mro/.
Adapted from https://www.python.org/download/releases/2.3/mro/.
"""
result = []
Expand Down Expand Up @@ -87,7 +87,7 @@ def _c3_mro(cls, abcs=None):
)


def _compose_mro(cls, types): # noqa: C901
def _compose_mro(cls, types):
"""Calculates the method resolution order for a given class *cls*.
Includes relevant abstract base classes (with their respective bases) from
Expand All @@ -98,7 +98,12 @@ def _compose_mro(cls, types): # noqa: C901

# Remove entries which are already present in the __mro__ or unrelated.
def is_related(typ):
return typ not in bases and hasattr(typ, '__mro__') and issubclass(cls, typ)
return (
typ not in bases
and hasattr(typ, '__mro__')
and not isinstance(typ, GenericAlias)
and issubclass(cls, typ)
)

types = [n for n in types if is_related(n)]

Expand All @@ -117,7 +122,7 @@ def is_strict_base(typ):
mro = []
for typ in types:
found = []
for sub in filter(_safe, typ.__subclasses__()):
for sub in typ.__subclasses__():
if sub not in bases and issubclass(cls, sub):
found.append([s for s in sub.__mro__ if s in type_set])
if not found:
Expand All @@ -132,13 +137,6 @@ def is_strict_base(typ):
return _c3_mro(cls, abcs=mro)


def _safe(class_):
"""
Return if the class is safe for testing as subclass. Ref #2.
"""
return not getattr(class_, '__origin__', None)


def _find_impl(cls, registry):
"""Returns the best matching implementation from *registry* for type *cls*.
Expand All @@ -161,33 +159,14 @@ def _find_impl(cls, registry):
and match not in cls.__mro__
and not issubclass(match, t)
):
raise RuntimeError(f"Ambiguous dispatch: {match} or {t}")
raise RuntimeError("Ambiguous dispatch: {} or {}".format(match, t))
break
if t in registry:
match = t
return registry.get(match)


def _validate_annotation(annotation):
"""Determine if an annotation is valid for registration.
An annotation is considered valid for use in registration if it is an
instance of ``type`` and not a generic type from ``typing``.
"""
try:
# In Python earlier than 3.7, the classes in typing are considered
# instances of type, but they invalid for registering single dispatch
# functions so check against GenericMeta instead.
from typing import GenericMeta

valid = not isinstance(annotation, GenericMeta)
except ImportError:
# In Python 3.7+, classes in typing are not instances of type.
valid = isinstance(annotation, type)
return valid


def singledispatch(func): # noqa: C901
def singledispatch(func):
"""Single-dispatch generic function decorator.
Transforms a function into a generic function, which can have different
Expand All @@ -196,13 +175,14 @@ def singledispatch(func): # noqa: C901
implementations can be registered using the register() attribute of the
generic function.
"""
registry = {}
dispatch_cache = WeakKeyDictionary()

def ns():
pass
# There are many programs that use functools without singledispatch, so we
# trade-off making singledispatch marginally slower for the benefit of
# making start-up of such applications slightly faster.
import types, weakref # noqa: E401

ns.cache_token = None
registry = {}
dispatch_cache = weakref.WeakKeyDictionary()
cache_token = None

def dispatch(cls):
"""generic_func.dispatch(cls) -> <function implementation>
Expand All @@ -211,11 +191,12 @@ def dispatch(cls):
for the given *cls* registered on *generic_func*.
"""
if ns.cache_token is not None:
nonlocal cache_token
if cache_token is not None:
current_token = get_cache_token()
if ns.cache_token != current_token:
if cache_token != current_token:
dispatch_cache.clear()
ns.cache_token = current_token
cache_token = current_token
try:
impl = dispatch_cache[cls]
except KeyError:
Expand All @@ -226,15 +207,36 @@ def dispatch(cls):
dispatch_cache[cls] = impl
return impl

def _is_union_type(cls):
from typing import get_origin, Union

return get_origin(cls) in {Union, types.UnionType}

def _is_valid_dispatch_type(cls):
if isinstance(cls, type):
return True
from typing import get_args

return _is_union_type(cls) and all(
isinstance(arg, type) for arg in get_args(cls)
)

def register(cls, func=None):
"""generic_func.register(cls, func) -> func
Registers a new implementation for the given *cls* on a *generic_func*.
"""
if func is None:
if isinstance(cls, type):
nonlocal cache_token
if _is_valid_dispatch_type(cls):
if func is None:
return lambda f: register(cls, f)
else:
if func is not None:
raise TypeError(
f"Invalid first argument to `register()`. "
f"{cls!r} is not a class or union type."
)
ann = getattr(cls, '__annotations__', {})
if not ann:
raise TypeError(
Expand All @@ -244,30 +246,45 @@ def register(cls, func=None):
)
func = cls

# only import typing if annotation parsing is necessary
from typing import get_type_hints

argname, cls = next(iter(get_type_hints(func).items()))
if not _validate_annotation(cls):
raise TypeError(
f"Invalid annotation for {argname!r}. " f"{cls!r} is not a class."
)
registry[cls] = func
if ns.cache_token is None and hasattr(cls, '__abstractmethods__'):
ns.cache_token = get_cache_token()
if not _is_valid_dispatch_type(cls):
if _is_union_type(cls):
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} not all arguments are classes."
)
else:
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} is not a class."
)

if _is_union_type(cls):
from typing import get_args

for arg in get_args(cls):
registry[arg] = func
else:
registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()
return func

def wrapper(*args, **kw):
if not args:
raise TypeError(
'{} requires at least ' '1 positional argument'.format(funcname)
)
raise TypeError(f'{funcname} requires at least ' '1 positional argument')

return dispatch(args[0].__class__)(*args, **kw)

funcname = getattr(func, '__name__', 'singledispatch function')
registry[object] = func
wrapper.register = register
wrapper.dispatch = dispatch
wrapper.registry = MappingProxyType(registry)
wrapper.registry = types.MappingProxyType(registry)
wrapper._clear_cache = dispatch_cache.clear
update_wrapper(wrapper, func)
return wrapper
Expand Down
Loading

0 comments on commit b405e43

Please sign in to comment.