Skip to content

Commit

Permalink
replace key with str
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarkman@protonmail.com>
  • Loading branch information
nstarman committed Sep 25, 2022
1 parent 6963320 commit 4bf27fd
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/overload_numpy/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __array_function__(
The result of calling the overloaded functions.
"""
# Check if can be dispatched.
if func not in self.NP_FUNC_OVERLOADS:
if not self.NP_FUNC_OVERLOADS.__contains__(func):
return NotImplemented

# Get _NumPyFuncOverloadInfo on function, given type of self.
Expand Down
62 changes: 46 additions & 16 deletions src/overload_numpy/overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
##############################################################################


class NumPyOverloader(Mapping[Callable[..., Any], _Dispatcher]):
class NumPyOverloader(Mapping[str, _Dispatcher]):
"""Overload :mod:`numpy` functions with |__array_function__|.
Parameters
----------
_reg : dict[Callable, |`~overload_numpy.dispatch._Dispatcher`|], optional
_reg : dict[str, |`~overload_numpy.dispatch._Dispatcher`|], optional
Registry of overloaded functions. You probably don't want to pass this
as a parameter.
Expand Down Expand Up @@ -84,7 +84,7 @@ class NumPyOverloader(Mapping[Callable[..., Any], _Dispatcher]):
Vector1D(x=array([0, 1, 2, 0, 1, 2]))
"""

_reg: Final[dict[Callable[..., Any], _Dispatcher]] = {}
_reg: Final[dict[str, _Dispatcher]] = {}
"""Registry of overloaded functions.
You probably don't want to pass this as a parameter.
Expand All @@ -93,19 +93,43 @@ class NumPyOverloader(Mapping[Callable[..., Any], _Dispatcher]):
# ===============================================================
# Mapping

def __getitem__(self, key: Callable[..., Any], /) -> _Dispatcher:
return self._reg[key]
def _get_key(self, key: str | Callable[..., Any] | Any) -> str:
"""Get the key.
Parameters
----------
key : str or Callable[..., Any]
The key.
Returns
-------
str
Raises
------
ValueError
If the key is not one of the known types.
"""
if isinstance(key, str):
return key
elif callable(key):
return f"{key.__module__}.{key.__name__}"
else:
raise ValueError

def __getitem__(self, key: str | Callable[..., Any], /) -> _Dispatcher:
return self._reg[self._get_key(key)]

def __contains__(self, o: object, /) -> bool:
return o in self._reg
return self._get_key(o) in self._reg

def __iter__(self) -> Iterator[Callable[..., Any]]:
def __iter__(self) -> Iterator[str]:
return iter(self._reg)

def __len__(self) -> int:
return len(self._reg)

def keys(self) -> KeysView[Callable[..., Any]]:
def keys(self) -> KeysView[str]:
return self._reg.keys()

def values(self) -> ValuesView[_Dispatcher]:
Expand Down Expand Up @@ -326,7 +350,7 @@ class _OverloadDecoratorBase(metaclass=ABCMeta):
Parameters
----------
overloader : `~overload_numpy.overload.NumPyOverloader`
overloader : |NumPyOverloader|
Overloader instance.
numpy_func : Callable[..., Any]
The :mod:`numpy` function that is being overloaded.
Expand All @@ -345,13 +369,19 @@ class _OverloadDecoratorBase(metaclass=ABCMeta):
dispatch_on: type

def __post_init__(self) -> None:
# Add

# Make single-dispatcher for numpy function
if self.numpy_func not in self.overloader._reg:
self.overloader._reg[self.numpy_func] = _Dispatcher()
if not self.overloader.__contains__(self.numpy_func):
self.overloader._reg[self._reg_key] = _Dispatcher()

# Turn ``types`` into only TypeConstraint
self.overloader._parse_types(self.types, self.dispatch_on)

@property
def _reg_key(self) -> str:
return self.overloader._get_key(self.numpy_func)

# @abstractmethod # TODO: fix when https://github.com/python/mypy/issues/5374 released
def func_hook(self, func: Callable[..., Any], /) -> Callable[..., Any]:
"""Function hook.
Expand Down Expand Up @@ -420,19 +450,19 @@ def __call__(self, func: C, /) -> C:
func=self.func_hook(func), types=tinfo, implements=self.numpy_func, dispatch_on=self.dispatch_on
)
# Register the function
self.overloader._reg[self.numpy_func]._dispatcher.register(self.dispatch_on, info)
self.overloader._reg[self._reg_key]._dispatcher.register(self.dispatch_on, info)
return func


@dataclass(frozen=True)
class _ImplementsDecorator(_OverloadDecoratorBase):
"""Decorator for registering with `~overload_numpy.NumPyOverloader`.
"""Decorator for registering with |NumPyOverloader|.
Instances of this class should not be used directly.
Parameters
----------
overloader : `~overload_numpy.overload.NumPyOverloader`
overloader : |NumPyOverloader|
Overloader instance.
numpy_func : Callable[..., Any]
The :mod:`numpy` function that is being overloaded.
Expand All @@ -456,13 +486,13 @@ def func_hook(self, func: C, /) -> C:

@dataclass(frozen=True)
class _AssistsDecorator(_OverloadDecoratorBase):
"""Decorator for registering with `~overload_numpy.NumPyOverloader`.
"""Decorator for registering with |NumPyOverloader|.
Instances of this class should not be used directly.
Parameters
----------
overloader : `~overload_numpy.overload.NumPyOverloader`
overloader : |NumPyOverloader|
Overloader instance.
numpy_func : Callable[..., Any]
The :mod:`numpy` function that is being overloaded.
Expand Down

0 comments on commit 4bf27fd

Please sign in to comment.