diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f718f9..9660ba1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: - id: text-unicode-replacement-char - repo: https://github.com/PyCQA/autoflake - rev: v1.7.0 + rev: v1.7.1 hooks: - id: autoflake args: @@ -116,3 +116,8 @@ repos: - --strict - --ignore-missing-imports - --exclude=/docs/ + + - repo: https://github.com/codespell-project/codespell + rev: v2.2.1 + hooks: + - id: codespell diff --git a/README.rst b/README.rst index eef8abe..5534a1d 100644 --- a/README.rst +++ b/README.rst @@ -54,7 +54,7 @@ Time to check these work: >>> np.concatenate((w1d, w1d)) Wrap1D(x=array([0, 1, 2, 0, 1, 2])) -|ufunc| also have a number of methods: 'at', 'accumulate', etc. The funtion +|ufunc| also have a number of methods: 'at', 'accumulate', etc. The function dispatch mechanism in `NEP13 `_ says that "If one of the input or output arguments implements __array_ufunc__, it is executed instead diff --git a/docs/conf.py b/docs/conf.py index 58b912e..eb3c731 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,6 +55,7 @@ def get_authors() -> set[str]: extensions = [ "sphinx.ext.doctest", "sphinx_automodapi.automodapi", + "numpydoc", "pytest_doctestplus.sphinx.doctestplus", ] @@ -73,6 +74,8 @@ def get_authors() -> set[str]: # This is added to the end of RST files - a good place to put substitutions to # be used globally. rst_epilog = """ +.. |overload_numpy| replace:: :mod:`overload_numpy` + .. |TypeConstraint| replace:: :class:`~overload_numpy.constraints.TypeConstraint` .. |Invariant| replace:: :class:`~overload_numpy.constraints.Invariant` .. |Covariant| replace:: :class:`~overload_numpy.constraints.Covariant` @@ -81,8 +84,10 @@ def get_authors() -> set[str]: .. |NumPyOverloader| replace:: :class:`~overload_numpy.overload.NumPyOverloader` - +.. |Numpy| replace:: :mod:`numpy` +.. |numpy| replace:: :mod:`numpy` .. |ufunc| replace:: :class:`~numpy.ufunc` +.. |ndarray| replace:: :class:`~numpy.ndarray` .. |array_function| replace:: ``__array_function__`` .. _array_function: https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_function__ .. |array_ufunc| replace:: ``__array_ufunc__`` @@ -91,10 +96,13 @@ def get_authors() -> set[str]: # intersphinx intersphinx_mapping = { - "python": ( - "https://docs.python.org/3/", - (None, "http://data.astropy.org/intersphinx/python3.inv"), + "python": ("https://docs.python.org/3/", (None, "http://data.astropy.org/intersphinx/python3.inv")), + "pythonloc": ( + "http://docs.python.org/", + (None, (pathlib.Path(__file__).parent.parent / "local" / "python3_local_links.inv").resolve()), ), + "numpy": ("https://numpy.org/doc/stable/", (None, "http://data.astropy.org/intersphinx/numpy.inv")), + "scipy": ("https://docs.scipy.org/doc/scipy/reference/", (None, "http://data.astropy.org/intersphinx/scipy.inv")), } # Show / hide TODO blocks @@ -146,6 +154,21 @@ def get_authors() -> set[str]: "mapping": ":term:`python:mapping`", } +# # Report warnings for all validation checks minus specified checsk. +# numpydoc_validation_checks = { +# "all", +# "GL01", # Docstring text (summary) should start in the line immediately after ... +# "GL08", # - The object does not have a docstring # TODO! rm when does docstring inheritance +# "SA01", # - See Also section not found +# } + +# numpydoc_validation_exclude = { +# r"docs\." +# r"NumPyOverloader\.get$", +# r"NumPyOverloader\.items$", +# } + + # -- Project information ------------------------------------------------------ # This does not *have* to match the package name, but typically does @@ -161,6 +184,10 @@ def get_authors() -> set[str]: # The full version, including alpha/beta/rc tags. release = get_version("overload_numpy") +# -- automodapi configuration --------------------------------------------------- + +# automodsumm_inherited_members = True + # -- Options for HTML output --------------------------------------------------- diff --git a/docs/contributing.rst b/docs/contributing.rst index 4130552..feb810d 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -1,19 +1,21 @@ -.. _contributing: +.. _overload_numpy-contributing: ***************** -How to contribute +How to Contribute ***************** We welcome contributions from anyone via pull requests on `GitHub `_. If you don't feel comfortable modifying or adding functionality, we also welcome feature requests and bug -reports as `GitHub issues `_. +reports as `GitHub Issues `_. -Developer documentation + +Developer Documentation ======================= .. toctree:: :maxdepth: 1 + install testing docs diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 0000000..392d955 --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,17 @@ +############### +Getting started +############### + + +.. automodule:: overload_numpy.mixin + + +Where to go from here? +====================== + +This page is meant to demonstrate a few initial things you may want to do with +|overload_numpy|. There is much more functionality that you can discover by +perusing the :ref:`user guide `. Some other +commonly-used functionality includes: + +* :ref:`Adding type constraints ` diff --git a/docs/index.rst b/docs/index.rst index f3d175d..be1e828 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,18 +1,60 @@ -############## -Overload NumPy -############## +.. include:: references.txt +############################ +Overload-NumPy Documentation +############################ -.. automodule:: overload_numpy.mixin +|NumPy| offers powerful methods to allow arguments of |NumPy| functions (and +|ufunc| objects) to define how a given function operates on them. The details +are specified in NEP13_ and NEP18_, but in summary: normally |NumPy| only works +on an |ndarray| but with NEP13_/NEP18_ for a custom object, users can register +overrides for a |NumPy| function and then use that function on that object (a +quick example is outlined below). Plugging into the |NumPy| framework is +convenient both for developers -- to let |NumPy| take care of the actual math -- +and users -- who get many things, not least of which is a familiar API. If all +this sounds great, that's because it is. However, if you've read NEP13_/NEP18_ +then you know that making the |NumPy| bridge to your custom object and +registering overrides is non-trivial. That's where |overload_numpy| comes in. +|overload_numpy| offers convenient base classes for the |NumPy| bridge and +powerful methods to register overrides for functions, |ufunc| objects, and even +|ufunc| methods (e.g. ``.accumulate()``). The library is fully typed and +(almost) fully ``c``-transpiled for speed. +.. code-block:: python + + from dataclasses import dataclass + import numpy as np + + @dataclass + class ArrayWrapper: + x: np.ndarray + + ... # lot's of non-trivial implementation details + + aw = ArrayWrapper(np.arange(10)) + + np.add(aw, aw) # returns ArrayWrapper([0, 2, ...]) + + +This package is being actively developed in a `public repository on GitHub +`_, and we are always looking for +new contributors! No contribution is too small, so if you have any trouble with +this code, find a typo, or have requests for new content (tutorials or +features), `open an issue on GitHub +`_. + + +.. --------------------- +.. Nav bar (top of docs) .. toctree:: :maxdepth: 1 :titlesonly: install - contributing + getting_started src/index + contributing Contributors diff --git a/docs/install.rst b/docs/install.rst index 64a86b0..e1810b4 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -1,6 +1,6 @@ .. include:: references.txt -.. overload_numpy-install: +.. _overload_numpy-install: ************ Installation @@ -48,12 +48,20 @@ the cloned ``overload_numpy`` directory) python -m pip install [-e] . +To ``c``-transpile and build wheels with ``mypyc``. + +.. code-block:: bash + + python -m pip install [-e] . --install-option='--use-mypyc" + + Python Dependencies =================== This packages has the following dependencies: * `Python`_ >= 3.8 +* ``mypy_extensions`` >= 0.4.3 : for ``c``-transpilation Explicit version requirements are specified in the project `pyproject.toml `_. ``pip`` diff --git a/docs/references.txt b/docs/references.txt index b49146b..ee25df8 100644 --- a/docs/references.txt +++ b/docs/references.txt @@ -1 +1,3 @@ .. _Python: http://www.python.org +.. _NEP13: https://numpy.org/neps/nep-0013-ufunc-overrides.html +.. _NEP18: https://numpy.org/neps/nep-0018-array-function-protocol.html diff --git a/docs/src/constraints.rst b/docs/src/constraints.rst index bc13ae9..9b3c711 100644 --- a/docs/src/constraints.rst +++ b/docs/src/constraints.rst @@ -1,4 +1,7 @@ -.. _constraints: +.. module:: overload_numpy.constraints + :noindex: + +.. _overload_numpy-constraints: ############################################### Type Constraints (`overload_numpy.constraints`) @@ -6,6 +9,7 @@ Type Constraints (`overload_numpy.constraints`) .. automodule:: overload_numpy.constraints + API === diff --git a/docs/src/conventions.rst b/docs/src/conventions.rst new file mode 100644 index 0000000..df64a65 --- /dev/null +++ b/docs/src/conventions.rst @@ -0,0 +1,13 @@ +.. _overload_numpy-conventions: + +########### +Conventions +########### + +Public API +========== + +Some things in this package are public, others are not. This is how to tell: + +* If a module or package defines ``__all__``, that authoritatively defines the public interface. +* If something begins with a leading underscore, it and its contents are private. diff --git a/docs/src/index.rst b/docs/src/index.rst index 949b9d3..2fd321d 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -1,18 +1,26 @@ -.. _overload_numpy: +.. _overload_numpy-user-guide: -################################# -Overload NumPy (`overload_numpy`) -################################# +########## +User Guide +########## + +The user guide contains exhaustive descriptions of all of the functions and +classes available in ``overload_numpy``, with some inline narrative descriptions and +demonstrations of functionality. .. toctree:: :maxdepth: 1 + conventions + mixin + register constraints -.. _overload_numpy-api: +Recent additions and changes +============================ -API -=== +.. toctree:: + :maxdepth: 2 -.. automodapi:: overload_numpy + whatsnew/index diff --git a/docs/src/mixin.rst b/docs/src/mixin.rst new file mode 100644 index 0000000..547e006 --- /dev/null +++ b/docs/src/mixin.rst @@ -0,0 +1,18 @@ +.. _overload_numpy-mixin: + +###################################### +Mixin Bridges (`overload_numpy.mixin`) +###################################### + +Mixins for adding |array_ufunc|_ &/or |array_function|_ methods. +Choose the one that's best suited for your needs. + + +API +=== + +.. automodapi:: overload_numpy + :no-main-docstr: + :no-heading: + :noindex: + :skip: NumPyOverloader diff --git a/docs/src/register.rst b/docs/src/register.rst new file mode 100644 index 0000000..1c00498 --- /dev/null +++ b/docs/src/register.rst @@ -0,0 +1,31 @@ +.. module:: overload_numpy.overload + +.. _overload_numpy-overload: + +############################################# +Override Registry (`overload_numpy.overload`) +############################################# + +.. automodule:: overload_numpy.overload + + + +A note about |ufunc| +==================== + +When registering a |ufunc| override a wrapper object is returned instead of the +original function. On these objects only the ``__call__`` and ``register`` +methods are public API. + + + +API +=== + +.. automodapi:: overload_numpy + :no-main-docstr: + :no-heading: + :noindex: + :skip: NPArrayOverloadMixin + :skip: NPArrayFuncOverloadMixin + :skip: NPArrayUFuncOverloadMixin diff --git a/docs/testing.rst b/docs/testing.rst index 561a198..fc68e15 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -40,3 +40,6 @@ Then you can run the tests with: :: pytest overload_numpy + + +== diff --git a/py.typed b/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index a6de8bb..7aa0909 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ ] dependencies = [ "mypy_extensions>=0.4.3", - "numpy>=1.18", ] [project.optional-dependencies] diff --git a/src/overload_numpy/constraints.py b/src/overload_numpy/constraints.py index bba025c..5250cf4 100644 --- a/src/overload_numpy/constraints.py +++ b/src/overload_numpy/constraints.py @@ -81,7 +81,7 @@ constraint. There are currently two things you need to do: 1. subclass :class:`overload_numpy.constraints.TypeConstraint` - 2. define a method ``validate_type`` + 2. define a method ``__call__`` As an example, let's define a constraint where the argument must be one of 2 types: @@ -123,19 +123,38 @@ @mypyc_attr(allow_interpreted_subclasses=True) class TypeConstraint(metaclass=ABCMeta): - """ABC for constraining an argument type. + r"""ABC for constraining an argument type. .. warning:: This class will be converted to a runtime-checkable `Protocol` when mypyc behaves nicely with runtime_checkable interpreted subclasses (see https://github.com/mypyc/mypyc/issues/909). + + Examples + -------- + It's very easy to define a custom type constraint. + + >>> from dataclasses import dataclass + >>> from overload_numpy.constraints import TypeConstraint + + >>> @dataclass(frozen=True) + ... class ThisOrThat(TypeConstraint): + ... this: type + ... that: type + ... def validate_type(self, arg_type: type, /) -> bool: + ... return arg_type is self.this or arg_type is self.that """ @abstractmethod def validate_type(self, arg_type: type, /) -> bool: """Validate the argument type. + This is used in :class:`overload_numpy.mixin.NPArrayFuncOverloadMixin` + and subclasses like :class:`overload_numpy.mixin.NPArrayOverloadMixin` + to ensure that the input is of the correct set of types to work + with the |array_function|_ override. + Parameters ---------- arg_type : type, positional-only @@ -145,13 +164,58 @@ def validate_type(self, arg_type: type, /) -> bool: ------- bool Whether the type is valid. + + Examples + -------- + The simplest built-in type constraint is + :class:`overload_numpy.constraints.Invariant`. + + >>> from overload_numpy.constraints import Invariant + >>> constraint = Invariant(int) + >>> constraint.validate_type(int) # exact type + True + >>> constraint.validate_type(bool) # subclass + False """ + def validate_object(self, arg: object, /) -> bool: + """Validate an argument. + + This is used in :class:`overload_numpy.mixin.NPArrayFuncOverloadMixin` + and subclasses like :class:`overload_numpy.mixin.NPArrayOverloadMixin` + to ensure that the input is of the correct set of types to work + with the |array_function|_ override. + + Parameters + ---------- + arg : object, positional-only + The argument that's type must fit the type constraint. + + Returns + ------- + bool + Whether the type is valid. + + Examples + -------- + The simplest built-in type constraint is + :class:`overload_numpy.constraints.Invariant`. + + >>> from overload_numpy.constraints import Invariant + >>> constraint = Invariant(int) + >>> constraint.validate_type(int) # exact type + True + >>> constraint.validate_type(bool) # subclass + False + """ + return self.validate_type(type(arg)) + @mypyc_attr(allow_interpreted_subclasses=True) @dataclass(frozen=True) class Invariant(TypeConstraint): - """Type constraint for invariance -- the exact type. + r""" + Type constraint for invariance -- the exact type. This is equivalent to ``arg_type is bound``. @@ -160,6 +224,11 @@ class Invariant(TypeConstraint): bound : type The exact type of the argument. + Notes + ----- + When compiled this class permits interpreted subclasses, see + https://mypyc.readthedocs.io/en/latest/native_classes.html. + Examples -------- Construct the constraint object: @@ -177,7 +246,6 @@ class Invariant(TypeConstraint): """ bound: type - """The exact type of the argument.""" def validate_type(self, arg_type: type, /) -> bool: return arg_type is self.bound @@ -186,7 +254,8 @@ def validate_type(self, arg_type: type, /) -> bool: @mypyc_attr(allow_interpreted_subclasses=True) @dataclass(frozen=True) class Covariant(TypeConstraint): - """A covariant constraint -- permitting subclasses. + r""" + A covariant constraint -- permitting subclasses. This is the most common constraint, equivalent to ``issubclass(arg_type, bound)``. @@ -196,6 +265,11 @@ class Covariant(TypeConstraint): bound : type The parent type of the argument. + Notes + ----- + When compiled this class permits interpreted subclasses, see + https://mypyc.readthedocs.io/en/latest/native_classes.html. + Examples -------- Construct the constraint object: @@ -213,7 +287,6 @@ class Covariant(TypeConstraint): """ bound: type - """The upper bound type of the argument.""" def validate_type(self, arg_type: type, /) -> bool: return issubclass(arg_type, self.bound) @@ -222,7 +295,8 @@ def validate_type(self, arg_type: type, /) -> bool: @mypyc_attr(allow_interpreted_subclasses=True) @dataclass(frozen=True) class Contravariant(TypeConstraint): - """A contravariant constraint -- permitting superclasses. + r""" + A contravariant constraint -- permitting superclasses. An uncommon constraint. See examples for why. @@ -231,6 +305,11 @@ class Contravariant(TypeConstraint): bound : type The child type of the argument. + Notes + ----- + When compiled this class permits interpreted subclasses, see + https://mypyc.readthedocs.io/en/latest/native_classes.html. + Examples -------- Construct the constraint object: @@ -248,7 +327,6 @@ class Contravariant(TypeConstraint): """ bound: type - """The lower bound type of the argument.""" def validate_type(self, arg_type: type, /) -> bool: return issubclass(self.bound, arg_type) @@ -257,7 +335,12 @@ def validate_type(self, arg_type: type, /) -> bool: @mypyc_attr(allow_interpreted_subclasses=True) @dataclass(frozen=True) class Between(TypeConstraint): - """Type constrained between two types. + r""" + Type constrained between two types. + + This combines the functionality of + :class:`~overload_numpy.constraints.Covariant` and + :class:`~overload_numpy.constraints.Contravariant`. Parameters ---------- @@ -266,9 +349,14 @@ class Between(TypeConstraint): upper_bound : type The parent type of the argument. + Notes + ----- + When compiled this class permits interpreted subclasses, see + https://mypyc.readthedocs.io/en/latest/native_classes.html + Examples -------- - For this example we need a type heirarchy: + For this example we need a type hierarchy: >>> class A: pass >>> class B(A): pass @@ -295,14 +383,27 @@ class Between(TypeConstraint): """ lower_bound: type - """The lower bound type of the argument.""" upper_bound: type - """The upper bound type of the argument.""" def validate_type(self, arg_type: type, /) -> bool: return issubclass(self.lower_bound, arg_type) & issubclass(arg_type, self.upper_bound) @property def bounds(self) -> tuple[type, type]: - """Return tuple of (lower, upper) bounds.""" + """Return tuple of (lower, upper) bounds. + + The lower bound is contravariant, the upper bound covariant. + + Examples + -------- + For this example we need a type hierarchy: + + >>> class A: pass + >>> class B(A): pass + >>> class C(B): pass + + >>> constraint = Between(C, B) + >>> constraint.bounds + (C, B) + """ return (self.lower_bound, self.upper_bound) diff --git a/src/overload_numpy/mixin.py b/src/overload_numpy/mixin.py index c96a9b9..ca30f8a 100644 --- a/src/overload_numpy/mixin.py +++ b/src/overload_numpy/mixin.py @@ -46,7 +46,7 @@ Wrap1D(x=array([0, 1, 2, 0, 1, 2])) |ufunc| also have a number of methods: 'at', 'accumulate', 'outer', etc. The -funtion dispatch mechanism in `NEP13 +function dispatch mechanism in `NEP13 `_ says that "If one of the input or output arguments implements __array_ufunc__, it is executed instead of the ufunc." Currently the overloaded `numpy.add` does not work for any of the @@ -220,17 +220,28 @@ class NPArrayFuncOverloadMixin: """Mixin for adding |array_function|_ to a class. + This mixin adds the method ``__array_function__``. Subclasses must define a + class variable ``NP_OVERLOADS`` and optionally ``NP_FUNC_TYPES``. + Attributes ---------- NP_OVERLOADS : |NumPyOverloader| + How the overrides are registered. A class-attribute of an instance of |NumPyOverloader|. NP_FUNC_TYPES : frozenset[type | TypeConstraint] | None, optional + Default type constraints. A class-attribute of `None` or a `frozenset` of `type` or |TypeConstraint|. If `None`, then the ``types`` argument for overloading functions becomes mandatory. If a `frozenset` (including blank) then the self-type + the contents are used in the types check. See |array_function|_ for details of the ``types`` argument. + Notes + ----- + When compiled this class is a ``mypyc`` :func:`~mypy_extensions.trait` and + permits interpreted subclasses (see + https://mypyc.readthedocs.io/en/latest/native_classes.html#inheritance). + Examples -------- First, some imports: @@ -252,8 +263,6 @@ class NPArrayFuncOverloadMixin: ... x: np.ndarray ... NP_OVERLOADS: ClassVar[NumPyOverloader] = W_FUNCS - Implementing an Overload - ^^^^^^^^^^^^^^^^^^^^^^^^ Now :mod:`numpy` functions can be overloaded and registered for ``Wrap1D``. >>> @W_FUNCS.implements(np.concatenate, Wrap1D) @@ -266,8 +275,6 @@ class NPArrayFuncOverloadMixin: >>> np.concatenate((w1d, w1d)) Wrap1D(x=array([0, 1, 2, 0, 1, 2])) - Dispatching Overloads for Subclasses - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ What if we defined a subclass of ``Wrap1D``? >>> @dataclass @@ -317,9 +324,6 @@ class NPArrayFuncOverloadMixin: >>> np.concatenate((w3d, w3d)) Wrap3D(x=array([0, 1, 0, 1]), y=array([3, 4, 3, 4]), z=array([6, 7, 6, 7])) - Assisting Groups of Overloads - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - In the previous examples we wrote implementations for a single NumPy function. Overloading the full set of NumPy functions this way would take a long time. @@ -344,16 +348,8 @@ class NPArrayFuncOverloadMixin: """ NP_OVERLOADS: ClassVar[NumPyOverloader] - """A class-attribute of an instance of |NumPyOverloader|.""" NP_FUNC_TYPES: ClassVar[frozenset[type | TypeConstraint] | None] = frozenset() - """ - A class-attribute of `None` or a `frozenset` of `type` / |TypeConstraint|. - If `None`, then the ``types`` argument for overloading functions becomes - mandatory. If a `frozenset` (including blank) then the self-type + the - contents are used in the types check. See |array_function|_ for details - of the ``types`` argument. - """ def __array_function__( self, func: Callable[..., Any], types: Collection[type], args: tuple[Any, ...], kwargs: dict[str, Any] @@ -413,13 +409,23 @@ def __array_function__( @mypyc_attr(allow_interpreted_subclasses=True) @trait class NPArrayUFuncOverloadMixin: - """Mixin for adding |array_ufunc|_ to a class. + """ + Mixin for adding |array_ufunc|_ to a class. + + This mixin adds the method ``__array_ufunc__``. Subclasses must define a + class variable ``NP_OVERLOADS``. Attributes ---------- NP_OVERLOADS : |NumPyOverloader| A class-attribute of an instance of |NumPyOverloader|. + Notes + ----- + When compiled this class is a ``mypyc`` :func:`~mypy_extensions.trait` and + permits interpreted subclasses (see + https://mypyc.readthedocs.io/en/latest/native_classes.html#inheritance). + Examples -------- First, some imports: @@ -443,8 +449,6 @@ class NPArrayUFuncOverloadMixin: >>> w1d = Wrap1D(np.arange(3)) - Implementing an Overload - ^^^^^^^^^^^^^^^^^^^^^^^^ Now :class:`numpy.ufunc` can be overloaded and registered for ``Wrap1D``. >>> @W_FUNCS.implements(np.add, Wrap1D) @@ -456,7 +460,7 @@ class NPArrayUFuncOverloadMixin: >>> np.add(w1d, w1d) Wrap1D(x=array([0, 2, 4])) - |ufunc| also have a number of methods: 'at', 'accumulate', etc. The funtion + |ufunc| also have a number of methods: 'at', 'accumulate', etc. The function dispatch mechanism in `NEP13 `_ says that "If one of the input or output arguments implements __array_ufunc__, it is executed @@ -477,8 +481,6 @@ class NPArrayUFuncOverloadMixin: >>> np.add.accumulate(w1d) Wrap1D(x=array([0, 1, 3])) - Dispatching Overloads for Subclasses - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ What if we defined a subclass of ``Wrap1D``? >>> @dataclass @@ -527,9 +529,6 @@ class NPArrayUFuncOverloadMixin: >>> np.add(w3d, w3d) Wrap3D(x=array([0, 2]), y=array([6, 8]), z=array([12, 14])) - Assisting Groups of Overloads - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - In the previous examples we wrote implementations for a single NumPy function. Overloading the full set of NumPy functions this way would take a long time. @@ -561,7 +560,6 @@ class NPArrayUFuncOverloadMixin: """ NP_OVERLOADS: ClassVar[NumPyOverloader] - """A class-attribute of an instance of |NumPyOverloader|.""" def __array_ufunc__(self, ufunc: UFuncLike, method: UFMT, *inputs: Any, **kwargs: Any) -> Any: # Check if can be dispatched. @@ -591,7 +589,12 @@ def __array_ufunc__(self, ufunc: UFuncLike, method: UFMT, *inputs: Any, **kwargs @mypyc_attr(allow_interpreted_subclasses=True) class NPArrayOverloadMixin(NPArrayFuncOverloadMixin, NPArrayUFuncOverloadMixin): - """Mixin for adding |array_ufunc|_ and |array_function|_ to a class. + """ + Mixin for adding |array_ufunc|_ and |array_function|_ to a class. + + This mixin adds the methods ``__array_ufunc__`` and ``__array_function__``. + Subclasses must define a class variable ``NP_OVERLOADS`` and optionally + ``NP_FUNC_TYPES``. Attributes ---------- @@ -601,6 +604,11 @@ class NPArrayOverloadMixin(NPArrayFuncOverloadMixin, NPArrayUFuncOverloadMixin): A class-attribute of `None` or a `frozenset` of `type` or |TypeConstraint|. + Notes + ----- + When compiled this class is permits interpreted subclasses (see + https://mypyc.readthedocs.io/en/latest/native_classes.html#inheritance). + Examples -------- First, some imports: @@ -624,8 +632,6 @@ class NPArrayOverloadMixin(NPArrayFuncOverloadMixin, NPArrayUFuncOverloadMixin): >>> w1d = Wrap1D(np.arange(3)) - Implementing an Overload - ^^^^^^^^^^^^^^^^^^^^^^^^ Now both :class:`numpy.ufunc` (e.g. :obj:`numpy.add`) and :mod:`numpy` functions (e.g. :func:`numpy.concatenate`) can be overloaded and registered for ``Wrap1D``. @@ -646,7 +652,7 @@ class NPArrayOverloadMixin(NPArrayFuncOverloadMixin, NPArrayUFuncOverloadMixin): >>> np.concatenate((w1d, w1d)) Wrap1D(x=array([0, 1, 2, 0, 1, 2])) - |ufunc| also have a number of methods: 'at', 'accumulate', etc. The funtion + |ufunc| also have a number of methods: 'at', 'accumulate', etc. The function dispatch mechanism in `NEP13 `_ says that "If one of the input or output arguments implements __array_ufunc__, it is executed @@ -667,8 +673,6 @@ class NPArrayOverloadMixin(NPArrayFuncOverloadMixin, NPArrayUFuncOverloadMixin): >>> np.add.accumulate(w1d) Wrap1D(x=array([0, 1, 3])) - Dispatching Overloads for Subclasses - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ What if we defined a subclass of ``Wrap1D``? >>> @dataclass @@ -740,9 +744,6 @@ class NPArrayOverloadMixin(NPArrayFuncOverloadMixin, NPArrayUFuncOverloadMixin): >>> np.concatenate((w3d, w3d)) Wrap3D(x=array([0, 1, 0, 1]), y=array([3, 4, 3, 4]), z=array([6, 7, 6, 7])) - Assisting Groups of Overloads - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - In the previous examples we wrote implementations for a single NumPy function. Overloading the full set of NumPy functions this way would take a long time. diff --git a/src/overload_numpy/overload.py b/src/overload_numpy/overload.py index 42805bf..795a5b5 100644 --- a/src/overload_numpy/overload.py +++ b/src/overload_numpy/overload.py @@ -56,7 +56,14 @@ # Dispatcher[AssistsFunc] | Dispatcher[AssistsUFunc] # TODO: when py3.10+ # https://bugs.python.org/issue42233 class NumPyOverloader(Mapping[str, Dispatcher[Any]]): - """Overload :mod:`numpy` functions with |array_function|_. + """ + Register :mod:`numpy` function overrides. + + This mapping works in conjunction with a mixin + (:class:`~overload_numpy.NPArrayFuncOverloadMixin`, + :class:`~overload_numpy.NPArrayUFuncOverloadMixin`, + or :class:`~overload_numpy.NPArrayOverloadMixin`) to register and implement + overrides with |array_function|_ and |array_ufunc|_. Examples -------- @@ -95,7 +102,10 @@ class NumPyOverloader(Mapping[str, Dispatcher[Any]]): """ def __init__(self) -> None: - # `_reg` is initialized here for `dataclasses.dataclass` ssubclasses. + self.__post_init__() # initialize this way for `dataclasses.dataclass` subclasses. + + def __post_init__(self) -> None: + # `_reg` is initialized here for `dataclasses.dataclass` subclasses. self._reg: dict[str, All_Dispatchers] object.__setattr__(self, "_reg", {}) # compatible with frozen dataclass # TODO parametrization of Dispatcher. Use All_Dispatchers @@ -147,7 +157,6 @@ def implements( ) -> ImplementsFuncDecorator: ... - # @singledispatchmethod # TODO! has probs with mypyc, so using overloads def implements( self, numpy_func: UFuncLike | Callable[..., Any], @@ -157,7 +166,8 @@ def implements( types: type | TypeConstraint | Collection[type | TypeConstraint] | None = None, methods: UFMsT = "__call__", ) -> ImplementsUFuncDecorator | ImplementsFuncDecorator: - """Register an |array_function|_ implementation object. + """ + Register an |array_function|_ implementation object. This is a decorator factory, returning ``decorator``, which registers the decorated function as an overload method for :mod:`numpy` function @@ -171,8 +181,7 @@ def implements( The class type for which the overload implementation is being registered. - types : type or TypeConstraint or Collection thereof or None, - keyword-only + types : type or TypeConstraint or Collection thereof or None, keyword-only The types of the arguments of `numpy_func`. See |array_function|_. If `None` then ``dispatch_on`` must have class-level attribute ``NP_FUNC_TYPES`` specifying the types. @@ -282,7 +291,6 @@ def assists( ) -> AssistsManyDecorator: ... - # @singledispatchmethod # TODO! having probs with mypyc, so using overloads def assists( self, numpy_funcs: UFuncLike | Callable[..., Any] | set[Callable[..., Any] | UFuncLike], @@ -292,7 +300,8 @@ def assists( types: type | TypeConstraint | Collection[type | TypeConstraint] | None = None, methods: UFMsT = "__call__", ) -> AssistsUFuncDecorator | AssistsFuncDecorator | AssistsManyDecorator: - """Register an |array_function|_ assistance function. + """ + Register an |array_function|_ assistance function. This is a decorator factory, returning ``decorator``, which registers the decorated function as an overload method for :mod:`numpy` function @@ -300,16 +309,21 @@ def assists( Parameters ---------- - numpy_func : callable[..., Any], positional-only + numpy_funcs : callable[..., Any], positional-only The :mod:`numpy` function that is being overloaded. dispatch_on : type The class type for which the overload implementation is being registered. - types : type or TypeConstraint or Collection thereof or None, - keyword-only - The types of the arguments of `numpy_func`. See |array_function|_ + types : type or TypeConstraint or Collection thereof or None, keyword-only + The types of the arguments of ``numpy_func``. + See |array_function|_ for details. + Only used if a function (not |ufunc|) is being overridden. If `None` then ``dispatch_on`` must have class-level attribute ``NP_FUNC_TYPES`` specifying the types. + methods : {'__call__', 'at', 'accumulate', 'outer', 'reduce', 'reduceat'} or set thereof, keyword-only + The |ufunc| methods for which this override applies. + Default is just "__call__". + Only used if a |ufunc| (not function) is being overridden. Returns ------- diff --git a/src/overload_numpy/utils.py b/src/overload_numpy/utils.py index 2a0499a..865972b 100644 --- a/src/overload_numpy/utils.py +++ b/src/overload_numpy/utils.py @@ -53,7 +53,7 @@ class UFuncMethodOverloadMap(TypedDict): ############################################################################## -def _parse_methods(methods: UFMsT) -> set[UFMT]: +def _parse_methods(methods: UFMsT) -> frozenset[UFMT]: """Parse |ufunc| method. Parameters @@ -75,7 +75,7 @@ def _parse_methods(methods: UFMsT) -> set[UFMT]: # validation that each elt is a UFUNC method type if any(m not in VALID_UFUNC_METHODS for m in ms): raise ValueError(f"methods must be an element or subset of {VALID_UFUNC_METHODS}, not {ms}") - return ms + return frozenset(ms) def _get_key(key: str | UFuncLike | Callable[..., Any] | Any) -> str: diff --git a/src/overload_numpy/wrapper/func.py b/src/overload_numpy/wrapper/func.py index f821dc1..e19c5aa 100644 --- a/src/overload_numpy/wrapper/func.py +++ b/src/overload_numpy/wrapper/func.py @@ -6,16 +6,7 @@ # STDLIB from collections.abc import Collection from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Generic, - Mapping, - TypeVar, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Mapping, TypeVar # THIRDPARTY from mypy_extensions import trait @@ -82,7 +73,7 @@ def validate_types(self, types: Collection[type], /) -> bool: else: # isinstance(self.types, Collection) valid_types = self.types - # Check that each type is considred valid. e.g. `types` is (ndarray, + # Check that each type is considered valid. e.g. `types` is (ndarray, # bool) and valid_types are (int, ndarray). It passes b/c ndarray <- # ndarray and bool <- int. for t in types: @@ -128,14 +119,13 @@ def __init__( self._dispatch_on = dispatch_on self._numpy_func = numpy_func self._overloader = overloader + self.__post_init__() def __post_init__(self) -> None: # Make single-dispatcher for numpy function key = _get_key(self.numpy_func) - if not self.overloader.__contains__(self.numpy_func): + if key not in self.overloader._reg: self.overloader._reg[key] = Dispatcher[FT]() - else: - cast("Dispatcher[FT]", self.overloader._reg[key]) @property def types(self) -> type | TypeConstraint | Collection[type | TypeConstraint] | None: diff --git a/src/overload_numpy/wrapper/many.py b/src/overload_numpy/wrapper/many.py index a068cc9..3add488 100644 --- a/src/overload_numpy/wrapper/many.py +++ b/src/overload_numpy/wrapper/many.py @@ -4,11 +4,12 @@ from __future__ import annotations # STDLIB +import itertools from dataclasses import dataclass from typing import Any, Callable, TypeVar # LOCAL -from overload_numpy.utils import UFMsT, _parse_methods +from overload_numpy.utils import UFMT, UFMsT, _parse_methods from overload_numpy.wrapper.func import AssistsFuncDecorator from overload_numpy.wrapper.ufunc import ( AssistsUFunc, @@ -34,21 +35,20 @@ class AssistsManyDecorator: """Class for registering `~overload_numpy.NumPyOverloader.assists` (u)funcs. + .. warning:: + + Only the ``__call__`` and ``register`` methods are currently public API. + Instances of this class are created as-needed by |NumPyOverloader| + whenever multiple functions and |ufunc| overrides are made with + `~overload_numpy.NumPyOverloader.assists`. + Parameters ---------- - decorators : tuple[AssistsFuncDecorator | AssistsUFuncDecorator, ...] + _decorators : tuple[AssistsFuncDecorator | AssistsUFuncDecorator, ...] `tuple` of ``AssistsFuncDecorator | AssistsUFuncDecorator``. - - ufunc_wrappers : tuple[ImplementsUFunc | AssistsUFunc, ...] | None - `tuple` of the |ufunc| wrappers. ``__call__`` must be used before this - is not `None`. - - __wrapped__ : Callable[..., Any] | None - The assistance function which this object wraps. ``__call__`` must be - used before this is not `None`. """ - decorators: tuple[AssistsUFuncDecorator | AssistsFuncDecorator, ...] + _decorators: tuple[AssistsUFuncDecorator | AssistsFuncDecorator, ...] """`tuple` of ``AssistsFuncDecorator | AssistsUFuncDecorator``.""" def __post_init__(self) -> None: @@ -73,7 +73,7 @@ def __call__(self: Self, assists_func: Callable[..., Any], /) -> Self: Returns ------- - AssistsManyDecorator + `overload_numpy.wrapper.many.AssistsManyDecorator` """ if self._is_set: raise ValueError("AssistsManyDecorator can only be called once") @@ -84,30 +84,71 @@ def __call__(self: Self, assists_func: Callable[..., Any], /) -> Self: # the decorator is enough to activate it. We separate funcs and ufuncs, # because the latter are kept in attr ``ufunc_wrappers``. # for dec in (d for d in self.decorators if isinstance(d, AssistsFuncDecorator)): # NOTE: mypyc incompatible - for dec in self.decorators: + for dec in self._decorators: if not isinstance(dec, AssistsFuncDecorator): continue dec(assists_func) - ufws = tuple(dec(assists_func) for dec in self.decorators if isinstance(dec, AssistsUFuncDecorator)) + ufws = tuple(dec(assists_func) for dec in self._decorators if isinstance(dec, AssistsUFuncDecorator)) object.__setattr__(self, "ufunc_wrappers", ufws) object.__setattr__(self, "_is_set", True) # prevent re-calling return self - def register(self, method: UFMsT) -> Callable[[C], C]: - ms = _parse_methods(method) + def register(self, methods: UFMsT, /) -> RegisterManyUFuncMethodDecorator: + """Register overload for |ufunc| methods. - def decorator(assist_ufunc_method: C, /) -> C: + Parameters + ---------- + methods : {'__call__', 'at', 'accumulate', 'outer', 'reduce', + 'reduceat'} or set thereof. + The names of the methods to overload. Can be a set the names. - ufws = self.ufunc_wrappers - if ufws is None: - raise ValueError("need to call this decorator first") + Returns + ------- + decorator : `RegisterManyUFuncMethodDecorator` + Decorator to register a function as an overload for a |ufunc| method + (or set thereof). + """ + ufws = self.ufunc_wrappers + if ufws is None: + raise ValueError("need to call this decorator first") + + return RegisterManyUFuncMethodDecorator(ufws, _parse_methods(methods)) + + +@dataclass(frozen=True) +class RegisterManyUFuncMethodDecorator: + """Decorator to register a |ufunc| method implementation. - for ufw in ufws: - for m in ms: - ufw.funcs[m] = assist_ufunc_method + Returned by by `~overload_numpy.wrapper.ufunc.OverrideUfuncBase.register`. - return assist_ufunc_method + .. warning:: - return decorator + Only the ``__call__`` method is public API. Instances of this class are + created as-needed by |NumPyOverloader| if many |ufunc| overrides are + registered. Users should not make an instance of this class. + """ + + _ufunc_wrappers: tuple[ImplementsUFunc | AssistsUFunc, ...] + """`tuple` of ``AssistsFuncDecorator | AssistsUFuncDecorator``.""" + + _applicable_methods: frozenset[UFMT] + """|ufunc| methods for which this decorator will register overrides.""" + + def __call__(self, assist_ufunc_method: C, /) -> C: + """Decorator to register an overload funcction for |ufunc| methods. + + Parameters + ---------- + assist_ufunc_method : Callable[..., Any] + The overload function for specified |ufunc| methods. + + Returns + ------- + func : Callable[..., Any] + Unchanged. + """ + for ufw, m in itertools.product(self._ufunc_wrappers, self._applicable_methods): + ufw._funcs[m] = assist_ufunc_method + return assist_ufunc_method diff --git a/src/overload_numpy/wrapper/ufunc.py b/src/overload_numpy/wrapper/ufunc.py index 416c0db..46fe5c2 100644 --- a/src/overload_numpy/wrapper/ufunc.py +++ b/src/overload_numpy/wrapper/ufunc.py @@ -5,6 +5,7 @@ # STDLIB from dataclasses import dataclass +from types import MappingProxyType from typing import ( TYPE_CHECKING, Any, @@ -14,7 +15,7 @@ KeysView, Mapping, TypeVar, - cast, + final, ) # LOCAL @@ -62,7 +63,7 @@ class OverrideUFuncBase: `~overload_numpy.wrapper.dispatch.Dispatcher`. """ - funcs: UFuncMethodOverloadMap + _funcs: UFuncMethodOverloadMap """The overloading function for each :class:`numpy.ufunc` method.""" implements: UFuncLike @@ -72,10 +73,18 @@ class OverrideUFuncBase: """The type dispatched on. See `~overload_numpy.wrapper.dispatch.Dispatcher`.""" + @property + def funcs(self) -> MappingProxyType[str, object]: + return MappingProxyType(self._funcs) + + @property + def __wrapped__(self) -> Callable[..., Any]: + return self._funcs["__call__"] + @property def methods(self) -> KeysView[str]: """All the |ufunc| methods.""" - return self.funcs.keys() + return self._funcs.keys() def validate_method(self, method: UFMT, /) -> bool: """Validates that the method has an overload. @@ -90,9 +99,9 @@ def validate_method(self, method: UFMT, /) -> bool: bool If the ``method`` is in the attr ``funcs``. """ - return (method in self.funcs) if method != "__call__" else True + return method in self._funcs - def register(self, methods: UFMsT) -> RegisterUFuncMethodDecorator: + def register(self, methods: UFMsT, /) -> RegisterUFuncMethodDecorator: """Register overload for |ufunc| methods. Parameters @@ -142,20 +151,34 @@ def register(self, methods: UFMsT) -> RegisterUFuncMethodDecorator: This is a decorator factory. """ # TODO! validation that the function has the right signature. - return RegisterUFuncMethodDecorator(self.funcs, _parse_methods(methods)) + return RegisterUFuncMethodDecorator(self._funcs, _parse_methods(methods)) +@final @dataclass(frozen=True) class RegisterUFuncMethodDecorator: """Decorator to register a |ufunc| method implementation. - Called by `~overload_numpy.wrapper.ufunc.OverrideUfuncBase.register`. + Returned by `~overload_numpy.wrapper.ufunc.OverrideUfuncBase.register`. + + .. warning:: + + Only the ``__call__`` method is public API. Instances of this class are + created as-needed by |NumPyOverloader| if a |ufunc| override is + registered. Users should not make an instance of this class. + + Methods + ------- + __call__ """ - funcs_dict: UFuncMethodOverloadMap - """dict of the |ufunc| method overloads.""" + _funcs_dict: UFuncMethodOverloadMap + """`dict` of the |ufunc| method overloads. + + This is linked to the map on a `OverrideUFuncBase` instance. + """ - applicable_methods: set[UFMT] + _applicable_methods: frozenset[UFMT] """|ufunc| methods for which this decorator will register overrides.""" def __call__(self, func: C, /) -> C: @@ -173,8 +196,8 @@ def __call__(self, func: C, /) -> C: """ # Iterate through the methods, adding as overloads for specified # methods. - for m in self.applicable_methods: - self.funcs_dict[m] = func + for m in self._applicable_methods: + self._funcs_dict[m] = func return func @@ -211,7 +234,7 @@ class OverloadUFuncDecoratorBase(Generic[UT]): numpy_func: UFuncLike """The :mod:`numpy` function that is being overloaded.""" - methods: set[UFMT] + methods: frozenset[UFMT] """Set of names of |ufunc| methods.""" overloader: NumPyOverloader @@ -220,10 +243,8 @@ class OverloadUFuncDecoratorBase(Generic[UT]): def __post_init__(self) -> None: # Make single-dispatcher for numpy function key = _get_key(self.numpy_func) - if not self.overloader.__contains__(self.numpy_func): + if key not in self.overloader._reg: self.overloader._reg[key] = Dispatcher[UT]() - else: - cast("Dispatcher[UT]", self.overloader._reg[key]) def __call__(self, func: Callable[..., Any], /) -> UT: """Register overload on Dispatcher. @@ -242,7 +263,7 @@ def __call__(self, func: Callable[..., Any], /) -> UT: """ # Adding a new numpy function info: UT = self.OverrideCls( - funcs={m: func for m in self.methods}, # includes __call__ + _funcs={m: func for m in self.methods}, # includes __call__ implements=self.numpy_func, dispatch_on=self.dispatch_on, ) @@ -263,15 +284,26 @@ def __call__(self, func: Callable[..., Any], /) -> UT: class ImplementsUFunc(OverrideUFuncBase): """Implements a |ufunc| override. + .. warning:: + + Only the ``register`` method is currently public API. Instances of this + class are created as-needed by |NumPyOverloader| whenever a |ufunc| + override is made with `~overload_numpy.NumPyOverloader.implements`. + Parameters ---------- funcs : dict[str, Callable], keyword-only - The overloading function for each |ufunc| method, including ``__call__``. + The overloading function for each |ufunc| method, including + ``__call__``. implements: |ufunc|, keyword-only The overloaded |ufunc|. dispatch_on : type, keyword-only The type dispatched on. See `~overload_numpy.wrapper.dispatch.Dispatcher`. + + Methods + ------- + register """ def __call__( @@ -299,14 +331,18 @@ def __call__( """ if not self.validate_method(method): return NotImplemented - return self.funcs[method](*args, **kwargs) + return self._funcs[method](*args, **kwargs) @dataclass(frozen=True) class ImplementsUFuncDecorator(OverloadUFuncDecoratorBase[ImplementsUFunc]): """Decorator for registering with |NumPyOverloader|. - Instances of this class should not be used directly. + .. warning:: + + Only the ``__call__`` methods is currently public API. Instances of this + class are created as-needed by |NumPyOverloader| whenever a |ufunc| + override is made with `~overload_numpy.NumPyOverloader.implements`. Parameters ---------- @@ -331,6 +367,12 @@ class ImplementsUFuncDecorator(OverloadUFuncDecoratorBase[ImplementsUFunc]): class AssistsUFunc(OverrideUFuncBase): """Assists a |ufunc| override. + .. warning:: + + Only the ``register`` methods is currently public API. Instances of this + class are created as-needed by |NumPyOverloader| whenever a |ufunc| + override is made with `~overload_numpy.NumPyOverloader.assists`. + Parameters ---------- funcs : dict[str, Callable], keyword-only @@ -366,13 +408,19 @@ def __call__(self, method: UFMT, calling_type: type, /, args: tuple[Any, ...], k """ if not self.validate_method(method): return NotImplemented - return self.funcs[method](calling_type, getattr(self.implements, method), *args, **kwargs) + return self._funcs[method](calling_type, getattr(self.implements, method), *args, **kwargs) @dataclass(frozen=True) class AssistsUFuncDecorator(OverloadUFuncDecoratorBase[AssistsUFunc]): """Decorator to register a |ufunc| override assist with |NumPyOverloader|. + .. warning:: + + Only the ``__call__`` methods is currently public API. Instances of this + class are created as-needed by |NumPyOverloader| whenever a |ufunc| + override is made with `~overload_numpy.NumPyOverloader.assists`. + Parameters ---------- overloader : |NumPyOverloader|, keyword-only