From 9f4f568be8dbfe1c5a5011dc58368d338ce684ef Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 7 Apr 2024 00:53:50 +0800 Subject: [PATCH] feat(functools): add submodule `optree.functools` (#134) --- .flake8 | 6 +- .github/workflows/tests.yml | 10 ++ CHANGELOG.md | 5 +- conda-recipe.yaml | 2 +- docs/conda-recipe.yaml | 2 +- docs/source/functools.rst | 12 ++ docs/source/index.rst | 1 + docs/source/registry.rst | 2 - docs/source/spelling_wordlist.txt | 2 + optree/__init__.py | 5 +- optree/functools.py | 171 +++++++++++++++++++++++++ optree/integration/__init__.py | 35 +++--- optree/integration/jax.py | 6 +- optree/integration/torch.py | 6 +- optree/ops.py | 47 +++---- optree/registry.py | 200 ++++++++++-------------------- optree/typing.py | 10 +- optree/utils.py | 12 +- pyproject.toml | 9 +- requirements.txt | 2 +- tests/.coveragerc | 2 + tests/helpers.py | 83 +++++++++++-- tests/test_functools.py | 108 ++++++++++++++++ tests/test_ops.py | 131 ++++++------------- tests/test_prefix_errors.py | 11 +- tests/test_treespec.py | 63 +++++++--- 26 files changed, 601 insertions(+), 342 deletions(-) create mode 100644 docs/source/functools.rst create mode 100644 optree/functools.py create mode 100644 tests/test_functools.py diff --git a/.flake8 b/.flake8 index 75570420..36f6b948 100644 --- a/.flake8 +++ b/.flake8 @@ -6,8 +6,9 @@ ignore = # E203: whitespace before ':' # W503: line break before binary operator # W504: line break after binary operator + # E704: multiple statements on one line (def) # format by black - E203,W503,W504, + E203,W503,W504,E704 # E501: line too long # W505: doc line too long # too long docstring due to long example blocks @@ -25,9 +26,8 @@ per-file-ignores = # E302: expected 2 blank lines # E305: expected 2 blank lines after class or function definition # E701: multiple statements on one line (colon) - # E704: multiple statements on one line (def) # format by black - *.pyi: E301,E302,E305,E701,E704 + *.pyi: E301,E302,E305,E701 exclude = .git, .vscode, diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0efcfd90..099829c1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -116,3 +116,13 @@ jobs: - name: Test with pytest run: | make pytest + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + if: ${{ matrix.os == 'ubuntu-latest' }} + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./tests/coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false diff --git a/CHANGELOG.md b/CHANGELOG.md index dcad9123..0c7a5aaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Add submodule `optree.functools` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134). ### Changed +- Update minimal version of `typing-extensions` to 4.5.0 for `typing_extensions.deprecated` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134). - Update string representation for `OrderedDict` by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/optree/pull/133). ### Fixed @@ -25,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed -- +- Deprecate `optree.Partial` and replace with `optree.functools.partial` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134). ------ diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 46bff7ac..533808bf 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -30,7 +30,7 @@ dependencies: - pip # Dependency - - typing-extensions >= 4.0.0 + - typing-extensions >= 4.5.0 # Build toolchain - cmake >= 3.11 diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index fbec54ec..f858e189 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -29,7 +29,7 @@ dependencies: - pip # Dependency - - typing-extensions >= 4.0.0 + - typing-extensions >= 4.5.0 # Build toolchain - cmake >= 3.11 diff --git a/docs/source/functools.rst b/docs/source/functools.rst new file mode 100644 index 00000000..f8dc1f54 --- /dev/null +++ b/docs/source/functools.rst @@ -0,0 +1,12 @@ +Integration with :mod:`functools` +================================= + +.. currentmodule:: optree.functools + +.. autosummary:: + + partial + reduce + +.. autoclass:: partial +.. autofunction:: reduce diff --git a/docs/source/index.rst b/docs/source/index.rst index e8489228..7978d7b4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -12,6 +12,7 @@ OpTree: Optimized PyTree Utilities :maxdepth: 1 registry.rst + functools.rst typing.rst api.rst diff --git a/docs/source/registry.rst b/docs/source/registry.rst index 9d080bdc..59bbd565 100644 --- a/docs/source/registry.rst +++ b/docs/source/registry.rst @@ -8,7 +8,6 @@ PyTree Node Registry register_pytree_node register_pytree_node_class unregister_pytree_node - Partial register_keypaths AttributeKeyPathEntry GetitemKeyPathEntry @@ -16,7 +15,6 @@ PyTree Node Registry .. autofunction:: register_pytree_node .. autofunction:: register_pytree_node_class .. autofunction:: unregister_pytree_node -.. autofunction:: Partial .. autofunction:: register_keypaths .. autofunction:: AttributeKeyPathEntry .. autofunction:: GetitemKeyPathEntry diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index b0db3597..78693300 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -75,7 +75,9 @@ jax numpy torch dtype +cuda getattr setattr delattr typecheck +subclassed diff --git a/optree/__init__.py b/optree/__init__.py index aafbc609..7ae7d3ce 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -14,7 +14,8 @@ # ============================================================================== """OpTree: Optimized PyTree Utilities.""" -from optree import integration, typing +from optree import functools, integration, typing +from optree.functools import Partial from optree.ops import ( MAX_RECURSION_DEPTH, NONE_IS_LEAF, @@ -74,7 +75,6 @@ from optree.registry import ( AttributeKeyPathEntry, GetitemKeyPathEntry, - Partial, register_keypaths, register_pytree_node, register_pytree_node_class, @@ -161,7 +161,6 @@ 'register_pytree_node', 'register_pytree_node_class', 'unregister_pytree_node', - 'Partial', 'register_keypaths', 'AttributeKeyPathEntry', 'GetitemKeyPathEntry', diff --git a/optree/functools.py b/optree/functools.py new file mode 100644 index 00000000..c40951d9 --- /dev/null +++ b/optree/functools.py @@ -0,0 +1,171 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""PyTree integration with :mod:`functools.partial`.""" + +from __future__ import annotations + +import functools +from typing import Any, Callable, ClassVar +from typing_extensions import Self # Python 3.11+ +from typing_extensions import deprecated # Python 3.13+ + +from optree import registry +from optree.ops import tree_reduce as reduce +from optree.typing import CustomTreeNode, T + + +__all__ = [ + 'partial', + 'reduce', +] + + +class _HashablePartialShim: + """Object that delegates :meth:`__call__`, :meth:`__eq__`, and :meth:`__hash__` to another object.""" + + __slots__: ClassVar[tuple[str, ...]] = ('partial_func', 'func', 'args', 'keywords') + + func: Callable[..., Any] + args: tuple[Any, ...] + keywords: dict[str, Any] + + def __init__(self, partial_func: functools.partial) -> None: + self.partial_func: functools.partial = partial_func + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.partial_func(*args, **kwargs) + + def __eq__(self, other: object) -> bool: + if isinstance(other, _HashablePartialShim): + return self.partial_func == other.partial_func + return self.partial_func == other + + def __hash__(self) -> int: + return hash(self.partial_func) + + def __repr__(self) -> str: + return repr(self.partial_func) + + +# pylint: disable-next=protected-access +@registry.register_pytree_node_class(namespace=registry.__GLOBAL_NAMESPACE) +class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-methods + functools.partial, + CustomTreeNode[T], +): + """A version of :func:`functools.partial` that works in pytrees. + + Use it for partial function evaluation in a way that is compatible with transformations, + e.g., ``partial(func, *args, **kwargs)``. + + (You need to explicitly opt-in to this behavior because we did not want to give + :func:`functools.partial` different semantics than normal function closures.) + + For example, here is a basic usage of :class:`partial` in a manner similar to + :func:`functools.partial`: + + >>> import operator + >>> import torch + >>> add_one = partial(operator.add, torch.ones(())) + >>> add_one(torch.tensor([[1, 2], [3, 4]])) + tensor([[2., 3.], + [4., 5.]]) + + Pytree compatibility means that the resulting partial function can be passed as an argument + within tree-map functions, which is not possible with a standard :func:`functools.partial` + function: + + >>> def call_func_on_cuda(f, *args, **kwargs): + ... f, args, kwargs = tree_map(lambda t: t.cuda(), (f, args, kwargs)) + ... return f(*args, **kwargs) + ... + >>> # doctest: +SKIP + >>> tree_map(lambda t: t.cuda(), add_one) + optree.functools.partial(, tensor(1., device='cuda:0')) + >>> call_func_on_cuda(add_one, torch.tensor([[1, 2], [3, 4]])) + tensor([[2., 3.], + [4., 5.]], device='cuda:0') + + Passing zero arguments to :class:`partial` effectively wraps the original function, making it a + valid argument in tree-map functions: + + >>> # doctest: +SKIP + >>> call_func_on_cuda(partial(torch.add), torch.tensor(1), torch.tensor(2)) + tensor(3, device='cuda:0') + + Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in + a :class:`TypeError` or :class:`AttributeError`. + """ + + __slots__: ClassVar[tuple[()]] = () + + func: Callable[..., Any] + args: tuple[T, ...] + keywords: dict[str, T] + + def __new__(cls, func: Callable[..., Any], *args: T, **keywords: T) -> Self: + """Create a new :class:`partial` instance.""" + # In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__ + # would merge the arguments of this partial instance with the arguments of the func. We box + # func in a class that does not (yet) have a `func` attribute to defeat this optimization, + # since we care exactly which arguments are considered part of the pytree. + if isinstance(func, functools.partial): + original_func = func + func = _HashablePartialShim(original_func) + assert not hasattr(func, 'func'), 'shimmed function should not have a `func` attribute' + out = super().__new__(cls, func, *args, **keywords) + func.func = original_func.func + func.args = original_func.args + func.keywords = original_func.keywords + return out + + return super().__new__(cls, func, *args, **keywords) + + def __repr__(self) -> str: + """Return a string representation of the :class:`partial` instance.""" + args = [repr(self.func)] + args.extend(repr(x) for x in self.args) + args.extend(f'{k}={v!r}' for (k, v) in self.keywords.items()) + return f'{self.__class__.__module__}.{self.__class__.__qualname__}({", ".join(args)})' + + def tree_flatten(self) -> tuple[ # type: ignore[override] + tuple[tuple[T, ...], dict[str, T]], + Callable[..., Any], + tuple[str, str], + ]: + """Flatten the :class:`partial` instance to children and auxiliary data.""" + return (self.args, self.keywords), self.func, ('args', 'keywords') + + @classmethod + def tree_unflatten( # type: ignore[override] + cls, + metadata: Callable[..., Any], + children: tuple[tuple[T, ...], dict[str, T]], + ) -> Self: + """Unflatten the children and auxiliary data into a :class:`partial` instance.""" + args, keywords = children + return cls(metadata, *args, **keywords) + + +# pylint: disable-next=protected-access +@registry.register_pytree_node_class(namespace=registry.__GLOBAL_NAMESPACE) +@deprecated( + 'The class `optree.Partial` is deprecated and will be removed in a future version. ' + 'Please use `optree.functools.partial` instead.', +) +class Partial(partial): + """Deprecated alias for :class:`partial`.""" + + __slots__: ClassVar[tuple[()]] = () diff --git a/optree/integration/__init__.py b/optree/integration/__init__.py index d00549bd..609109f0 100644 --- a/optree/integration/__init__.py +++ b/optree/integration/__init__.py @@ -14,31 +14,30 @@ # ============================================================================== """Integration with third-party libraries.""" -import sys -from typing import Any +from __future__ import annotations +from typing import TYPE_CHECKING -current_module = sys.modules[__name__] +if TYPE_CHECKING: + from types import ModuleType -SUBMODULES = frozenset({'jax', 'numpy', 'torch'}) +SUBMODULES: frozenset[str] = frozenset({'jax', 'numpy', 'torch'}) -# pylint: disable-next=too-few-public-methods -class _LazyModule(type(current_module)): # type: ignore[misc] - def __getattribute__(self, name: str) -> Any: # noqa: N804 - try: - return super().__getattribute__(name) - except AttributeError: - if name in SUBMODULES: - import importlib # pylint: disable=import-outside-toplevel - submodule = importlib.import_module(f'{__name__}.{name}') - setattr(self, name, submodule) - return submodule - raise +def __getattr__(name: str) -> ModuleType: + if name in SUBMODULES: + import importlib # pylint: disable=import-outside-toplevel + import sys # pylint: disable=import-outside-toplevel + module = sys.modules[__name__] -current_module.__class__ = _LazyModule + submodule = importlib.import_module(f'{__name__}.{name}') + setattr(module, name, submodule) + return submodule -del sys, Any, current_module, _LazyModule + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') + + +del TYPE_CHECKING diff --git a/optree/integration/jax.py b/optree/integration/jax.py index 1b6d0cd7..7ff34e57 100644 --- a/optree/integration/jax.py +++ b/optree/integration/jax.py @@ -84,8 +84,7 @@ def __eq__(self, other: object) -> bool: return ( type(other) is HashablePartial # pylint: disable=unidiomatic-typecheck and self.func.__code__ == other.func.__code__ # type: ignore[attr-defined] - and self.args == other.args - and self.kwargs == other.kwargs + and (self.args, self.kwargs) == (other.args, other.kwargs) ) def __hash__(self) -> int: @@ -98,7 +97,8 @@ def __hash__(self) -> int: ) def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self.func(*self.args, *args, **self.kwargs, **kwargs) + kwargs = {**self.kwargs, **kwargs} + return self.func(*self.args, *args, **kwargs) try: # noqa: SIM105 # pragma: no cover diff --git a/optree/integration/torch.py b/optree/integration/torch.py index eaaa24ff..d228fab2 100644 --- a/optree/integration/torch.py +++ b/optree/integration/torch.py @@ -157,7 +157,7 @@ def _ravel_leaves( def _unravel_empty(flat: torch.Tensor) -> list[torch.Tensor]: if not torch.is_tensor(flat): - raise ValueError(f'Expected a tensor to unravel, got {type(flat)}.') + raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.') if flat.shape != (0,): raise ValueError( f'The unravel function expected a tensor of shape {(0,)}, got shape {flat.shape}.', @@ -171,7 +171,7 @@ def _unravel_leaves_single_dtype( flat: torch.Tensor, ) -> list[torch.Tensor]: if not torch.is_tensor(flat): - raise ValueError(f'Expected a tensor to unravel, got {type(flat)}.') + raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.') if flat.shape != (sum(sizes),): raise ValueError( f'The unravel function expected a tensor of shape {(sum(sizes),)}, ' @@ -190,7 +190,7 @@ def _unravel_leaves( flat: torch.Tensor, ) -> list[torch.Tensor]: if not torch.is_tensor(flat): - raise ValueError(f'Expected a tensor to unravel, got {type(flat)}.') + raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.') if flat.shape != (sum(sizes),): raise ValueError( f'The unravel function expected a tensor of shape {(sum(sizes),)}, ' diff --git a/optree/ops.py b/optree/ops.py index 787ae63a..0566b7aa 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -23,7 +23,7 @@ import itertools import textwrap from collections import OrderedDict, defaultdict, deque -from typing import Any, Callable, Iterable, Mapping, overload +from typing import Any, Callable, ClassVar, Iterable, Mapping, overload from optree import _C from optree.registry import ( @@ -954,7 +954,11 @@ def tree_transpose_map_with_path( ... inner_treespec=tree_structure({'path': 0, 'value': 0}), ... ) { - 'path': {'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]), 'a': ('a',), 'c': (('c', 0), ('c', 1))}, + 'path': { + 'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]), + 'a': ('a',), + 'c': (('c', 0), ('c', 1)) + }, 'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)} } @@ -1460,10 +1464,16 @@ def tree_broadcast_map_with_path( If multiple inputs are given, all input trees will be broadcasted to the common suffix structure of all inputs: - >>> tree_broadcast_map_with_path(lambda p, x, y: (p, x * y), [5, 6, (3, 4)], [{'a': 7, 'b': 9}, [1, 2], 8]) - [{'a': ((0, 'a'), 35), 'b': ((0, 'b'), 45)}, - [((1, 0), 6), ((1, 1), 12)], - (((2, 0), 24), ((2, 1), 32))] + >>> tree_broadcast_map_with_path( # doctest: +IGNORE_WHITESPACE + ... lambda p, x, y: (p, x * y), + ... [5, 6, (3, 4)], + ... [{'a': 7, 'b': 9}, [1, 2], 8], + ... ) + [ + {'a': ((0, 'a'), 35), 'b': ((0, 'b'), 45)}, + [((1, 0), 6), ((1, 1), 12)], + (((2, 0), 24), ((2, 1), 32)) + ] Args: func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the @@ -1522,6 +1532,8 @@ def tree_broadcast_map_with_path( # pylint: disable-next=missing-class-docstring,too-few-public-methods class MissingSentinel: # pragma: no cover + __slots__: ClassVar[tuple[()]] = () + def __repr__(self) -> str: return '' @@ -1538,8 +1550,7 @@ def tree_reduce( is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, namespace: str = '', -) -> T: # pragma: no cover - ... +) -> T: ... @overload @@ -1551,8 +1562,7 @@ def tree_reduce( is_leaf: Callable[[S], bool] | None = None, none_is_leaf: bool = False, namespace: str = '', -) -> T: # pragma: no cover - ... +) -> T: ... def tree_reduce( @@ -1661,8 +1671,7 @@ def tree_max( key: Callable[[T], Any] | None = None, none_is_leaf: bool = False, namespace: str = '', -) -> T: # pragma: no cover - ... +) -> T: ... @overload @@ -1674,8 +1683,7 @@ def tree_max( is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, namespace: str = '', -) -> T: # pragma: no cover - ... +) -> T: ... def tree_max( @@ -1756,8 +1764,7 @@ def tree_min( is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, namespace: str = '', -) -> T: # pragma: no cover - ... +) -> T: ... @overload @@ -1769,8 +1776,7 @@ def tree_min( is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, namespace: str = '', -) -> T: # pragma: no cover - ... +) -> T: ... def tree_min( @@ -2689,7 +2695,7 @@ def prefix_errors( ) -STANDARD_DICT_TYPES = frozenset({dict, OrderedDict, defaultdict}) +STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict}) # pylint: disable-next=too-many-locals @@ -2710,8 +2716,7 @@ def _prefix_error( prefix_tree_type = type(prefix_tree) full_tree_type = type(full_tree) both_standard_dict = ( - prefix_tree_type in STANDARD_DICT_TYPES # type: ignore[comparison-overlap] - and full_tree_type in STANDARD_DICT_TYPES # type: ignore[comparison-overlap] + prefix_tree_type in STANDARD_DICT_TYPES and full_tree_type in STANDARD_DICT_TYPES ) both_deque = prefix_tree_type is deque and full_tree_type is deque # type: ignore[comparison-overlap] if prefix_tree_type is not full_tree_type and ( diff --git a/optree/registry.py b/optree/registry.py index e557430d..1fc5b929 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""OpTree: Optimized PyTree Utilities.""" +"""Registry for custom pytree node types.""" from __future__ import annotations @@ -23,8 +23,18 @@ from collections import OrderedDict, defaultdict, deque, namedtuple from operator import methodcaller from threading import Lock -from typing import TYPE_CHECKING, Any, Callable, Iterable, NamedTuple, Sequence, overload -from typing_extensions import Self # Python 3.11+ +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Iterable, + NamedTuple, + Sequence, + Type, + overload, +) +from typing_extensions import TypeAlias # Python 3.10+ from optree import _C from optree.typing import ( @@ -50,7 +60,6 @@ 'register_pytree_node', 'register_pytree_node_class', 'unregister_pytree_node', - 'Partial', 'register_keypaths', 'AttributeKeyPathEntry', 'GetitemKeyPathEntry', @@ -73,6 +82,8 @@ class PyTreeNodeRegistryEntry: # pylint: disable-next=missing-class-docstring,too-few-public-methods class GlobalNamespace: # pragma: no cover + __slots__: ClassVar[tuple[()]] = () + def __repr__(self) -> str: return '' @@ -82,12 +93,16 @@ def __repr__(self) -> str: del GlobalNamespace +CustomTreeNodeT: TypeAlias = Type[CustomTreeNode[T]] + + def register_pytree_node( - cls: type[CustomTreeNode[T]], + cls: CustomTreeNodeT, flatten_func: FlattenFunc, unflatten_func: UnflattenFunc, + *, namespace: str, -) -> type[CustomTreeNode[T]]: +) -> CustomTreeNodeT: """Extend the set of types that are considered internal nodes in pytrees. See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`. @@ -203,9 +218,9 @@ def register_pytree_node( ) """ if not inspect.isclass(cls): - raise TypeError(f'Expected a class, got {cls}.') + raise TypeError(f'Expected a class, got {cls!r}.') if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): - raise TypeError(f'The namespace must be a string, got {namespace}.') + raise TypeError(f'The namespace must be a string, got {namespace!r}.') if namespace == '': raise ValueError('The namespace cannot be an empty string.') @@ -217,7 +232,12 @@ def register_pytree_node( registration_key = (namespace, cls) with __REGISTRY_LOCK: - _C.register_node(cls, flatten_func, unflatten_func, namespace) + _C.register_node( + cls, + flatten_func, + unflatten_func, + namespace, + ) _NODETYPE_REGISTRY[registration_key] = PyTreeNodeRegistryEntry( cls, flatten_func, @@ -232,24 +252,22 @@ def register_pytree_node_class( cls: str | None = None, *, namespace: str | None = None, -) -> Callable[[type[CustomTreeNode[T]]], type[CustomTreeNode[T]]]: # pragma: no cover - ... +) -> Callable[[CustomTreeNodeT], CustomTreeNodeT]: ... @overload def register_pytree_node_class( - cls: type[CustomTreeNode[T]], + cls: CustomTreeNodeT, *, namespace: str, -) -> type[CustomTreeNode[T]]: # pragma: no cover - ... +) -> CustomTreeNodeT: ... def register_pytree_node_class( - cls: type[CustomTreeNode[T]] | str | None = None, + cls: CustomTreeNodeT | str | None = None, *, namespace: str | None = None, -) -> type[CustomTreeNode[T]] | Callable[[type[CustomTreeNode[T]]], type[CustomTreeNode[T]]]: +) -> CustomTreeNodeT | Callable[[CustomTreeNodeT], CustomTreeNodeT]: """Extend the set of types that are considered internal nodes in pytrees. See also :func:`register_pytree_node` and :func:`unregister_pytree_node`. @@ -309,20 +327,31 @@ def tree_unflatten(cls, metadata, children): raise ValueError('Cannot specify `namespace` when the first argument is a string.') if cls == '': raise ValueError('The namespace cannot be an empty string.') - return functools.partial(register_pytree_node_class, namespace=cls) # type: ignore[return-value] + return functools.partial( + register_pytree_node_class, + namespace=cls, + ) # type: ignore[return-value] if namespace is None: raise ValueError('Must specify `namespace` when the first argument is a class.') if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): - raise TypeError(f'The namespace must be a string, got {namespace}') + raise TypeError(f'The namespace must be a string, got {namespace!r}') if namespace == '': raise ValueError('The namespace cannot be an empty string.') if cls is None: - return functools.partial(register_pytree_node_class, namespace=namespace) # type: ignore[return-value] + return functools.partial( + register_pytree_node_class, + namespace=namespace, + ) # type: ignore[return-value] if not inspect.isclass(cls): - raise TypeError(f'Expected a class, got {cls}.') - register_pytree_node(cls, methodcaller('tree_flatten'), cls.tree_unflatten, namespace) + raise TypeError(f'Expected a class, got {cls!r}.') + register_pytree_node( + cls, + methodcaller('tree_flatten'), + cls.tree_unflatten, + namespace=namespace, + ) return cls @@ -365,9 +394,9 @@ def unregister_pytree_node( >>> unregister_pytree_node(set, namespace='temp') """ if not inspect.isclass(cls): - raise TypeError(f'Expected a class, got {cls}.') + raise TypeError(f'Expected a class, got {cls!r}.') if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): - raise TypeError(f'The namespace must be a string, got {namespace}.') + raise TypeError(f'The namespace must be a string, got {namespace!r}.') if namespace == '': raise ValueError('The namespace cannot be an empty string.') @@ -442,7 +471,7 @@ def _defaultdict_flatten( dct: defaultdict[KT, VT], ) -> tuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]: values, keys, entries = _dict_flatten(dct) - return values, (dct.default_factory, list(keys)), entries + return values, (dct.default_factory, keys), entries def _defaultdict_unflatten( @@ -450,7 +479,7 @@ def _defaultdict_unflatten( values: Iterable[VT], ) -> defaultdict[KT, VT]: default_factory, keys = metadata - return defaultdict(default_factory, safe_zip(keys, values)) + return defaultdict(default_factory, _dict_unflatten(keys, values)) def _deque_flatten(deq: deque[T]) -> tuple[deque[T], int | None]: @@ -495,12 +524,14 @@ def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T]) -> stru def _pytree_node_registry_get( cls: type, *, - namespace: str = __GLOBAL_NAMESPACE, + namespace: str = '', ) -> PyTreeNodeRegistryEntry | None: - handler: PyTreeNodeRegistryEntry | None = _NODETYPE_REGISTRY.get(cls) - if handler is not None: - return handler - handler = _NODETYPE_REGISTRY.get((namespace, cls)) + handler: PyTreeNodeRegistryEntry | None = None + if namespace is not __GLOBAL_NAMESPACE and namespace != '': + handler = _NODETYPE_REGISTRY.get((namespace, cls)) + if handler is not None: + return handler + handler = _NODETYPE_REGISTRY.get(cls) if handler is not None: return handler if is_structseq_class(cls): @@ -514,111 +545,6 @@ def _pytree_node_registry_get( del _pytree_node_registry_get -class _HashablePartialShim: - """Object that delegates :meth:`__call__`, :meth:`__hash__`, and :meth:`__eq__` to another object.""" - - func: Callable[..., Any] - args: tuple[Any, ...] - keywords: dict[str, Any] - - def __init__(self, partial_func: functools.partial) -> None: - self.partial_func: functools.partial = partial_func - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self.partial_func(*args, **kwargs) - - def __hash__(self) -> int: - return hash(self.partial_func) - - def __eq__(self, other: object) -> bool: - if isinstance(other, _HashablePartialShim): - return self.partial_func == other.partial_func - return self.partial_func == other - - -@register_pytree_node_class(namespace=__GLOBAL_NAMESPACE) -class Partial(functools.partial, CustomTreeNode[Any]): # pylint: disable=too-few-public-methods - """A version of :func:`functools.partial` that works in pytrees. - - Use it for partial function evaluation in a way that is compatible with transformations, - e.g., ``Partial(func, *args, **kwargs)``. - - (You need to explicitly opt-in to this behavior because we did not want to give - :func:`functools.partial` different semantics than normal function closures.) - - For example, here is a basic usage of :class:`Partial` in a manner similar to - :func:`functools.partial`: - - >>> import operator - >>> import torch - >>> add_one = Partial(operator.add, torch.ones(())) - >>> add_one(torch.tensor([[1, 2], [3, 4]])) - tensor([[2., 3.], - [4., 5.]]) - - Pytree compatibility means that the resulting partial function can be passed as an argument - within tree-map functions, which is not possible with a standard :func:`functools.partial` - function: - - >>> def call_func_on_cuda(f, *args, **kwargs): - ... f, args, kwargs = tree_map(lambda t: t.cuda(), (f, args, kwargs)) - ... return f(*args, **kwargs) - ... - >>> # doctest: +SKIP - >>> tree_map(lambda t: t.cuda(), add_one) - Partial(, tensor(1., device='cuda:0')) - >>> call_func_on_cuda(add_one, torch.tensor([[1, 2], [3, 4]])) - tensor([[2., 3.], - [4., 5.]], device='cuda:0') - - Passing zero arguments to :class:`Partial` effectively wraps the original function, making it a - valid argument in tree-map functions: - - >>> # doctest: +SKIP - >>> call_func_on_cuda(Partial(torch.add), torch.tensor(1), torch.tensor(2)) - tensor(3, device='cuda:0') - - Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in - a :class:`TypeError` or :class:`AttributeError`. - """ - - func: Callable[..., Any] - args: tuple[Any, ...] - keywords: dict[str, Any] - - def __new__(cls, func: Callable[..., Any], *args: Any, **keywords: Any) -> Self: - """Create a new :class:`Partial` instance.""" - # In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__ - # would merge the arguments of this Partial instance with the arguments of the func. We box - # func in a class that does not (yet) have a `func` attribute to defeat this optimization, - # since we care exactly which arguments are considered part of the pytree. - if isinstance(func, functools.partial): - original_func = func - func = _HashablePartialShim(original_func) - assert not hasattr(func, 'func'), 'shimmed function should not have a `func` attribute' - out = super().__new__(cls, func, *args, **keywords) - func.func = original_func.func - func.args = original_func.args - func.keywords = original_func.keywords - return out - - return super().__new__(cls, func, *args, **keywords) - - def tree_flatten(self) -> tuple[tuple[tuple[Any, ...], dict[str, Any]], Callable[..., Any]]: - """Flatten the :class:`Partial` instance to children and auxiliary data.""" - return (self.args, self.keywords), self.func - - @classmethod - def tree_unflatten( # type: ignore[override] - cls, - metadata: Callable[..., Any], - children: tuple[tuple[Any, ...], dict[str, Any]], - ) -> Self: - """Unflatten the children and auxiliary data into a :class:`Partial` instance.""" - args, keywords = children - return cls(metadata, *args, **keywords) - - class KeyPathEntry(NamedTuple): key: Any @@ -691,9 +617,9 @@ def register_keypaths( ) -> KeyPathHandler: """Register a key path handler for a custom pytree node type.""" if not inspect.isclass(cls): - raise TypeError(f'Expected a class, got {cls}.') + raise TypeError(f'Expected a class, got {cls!r}.') if cls in _KEYPATH_REGISTRY: - raise ValueError(f'Key path handler for {cls} has already been registered.') + raise ValueError(f'Key path handler for {cls!r} has already been registered.') _KEYPATH_REGISTRY[cls] = handler return handler diff --git a/optree/typing.py b/optree/typing.py index ec13907a..1f25c31f 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -90,9 +90,9 @@ VT = TypeVar('VT') -Children = Iterable[T] +Children: TypeAlias = Iterable[T] _MetaData = TypeVar('_MetaData', bound=Hashable) -MetaData = Optional[_MetaData] +MetaData: TypeAlias = Optional[_MetaData] @runtime_checkable @@ -255,14 +255,14 @@ def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias: return self -FlattenFunc = Callable[ +FlattenFunc: TypeAlias = Callable[ [CustomTreeNode[T]], Union[ Tuple[Children[T], MetaData], Tuple[Children[T], MetaData, Optional[Iterable[Any]]], ], ] -UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]] +UnflattenFunc: TypeAlias = Callable[[MetaData, Children[T]], CustomTreeNode[T]] def is_namedtuple(obj: object | type) -> bool: @@ -370,7 +370,7 @@ def is_structseq_instance(obj: object) -> bool: # Set if the type allows subclassing (see CPython's Include/object.h) -Py_TPFLAGS_BASETYPE = _C.Py_TPFLAGS_BASETYPE # (1UL << 10) +Py_TPFLAGS_BASETYPE: int = _C.Py_TPFLAGS_BASETYPE # (1UL << 10) def is_structseq_class(cls: type) -> bool: diff --git a/optree/utils.py b/optree/utils.py index ad96457c..0e877295 100644 --- a/optree/utils.py +++ b/optree/utils.py @@ -62,16 +62,14 @@ def key_fn(x: T) -> tuple[str, Any]: @overload def safe_zip( __iter1: Iterable[T], -) -> zip[tuple[T]]: # pragma: no cover - ... +) -> zip[tuple[T]]: ... @overload def safe_zip( __iter1: Iterable[T], __iter2: Iterable[S], -) -> zip[tuple[T, S]]: # pragma: no cover - ... +) -> zip[tuple[T, S]]: ... @overload @@ -79,8 +77,7 @@ def safe_zip( __iter1: Iterable[T], __iter2: Iterable[S], __iter3: Iterable[U], -) -> zip[tuple[T, S, U]]: # pragma: no cover - ... +) -> zip[tuple[T, S, U]]: ... @overload @@ -90,8 +87,7 @@ def safe_zip( __iter3: Iterable[Any], __iter4: Iterable[Any], *__iters: Iterable[Any], -) -> zip[tuple[Any, ...]]: # pragma: no cover - ... +) -> zip[tuple[Any, ...]]: ... def safe_zip(*args: Iterable[Any]) -> zip[tuple[Any, ...]]: diff --git a/pyproject.toml b/pyproject.toml index 69037ee1..ea04789f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ "Intended Audience :: Science/Research", "Topic :: Utilities", ] -dependencies = ["typing-extensions >= 4.0.0"] +dependencies = ["typing-extensions >= 4.5.0"] dynamic = ["version"] [project.urls] @@ -143,7 +143,7 @@ show_traceback = true allow_redefinition = true check_untyped_defs = true disallow_incomplete_defs = true -disallow_untyped_defs = true +disallow_untyped_defs = true ignore_missing_imports = true no_implicit_optional = true strict_equality = true @@ -260,4 +260,7 @@ inline-quotes = "single" ban-relative-imports = "all" [tool.pytest.ini_options] -filterwarnings = ["error"] +filterwarnings = [ + "error", + "ignore:The class `optree.Partial` is deprecated and will be removed in a future version. Please use `optree.functools.partial` instead.:DeprecationWarning" +] diff --git a/requirements.txt b/requirements.txt index 255bf2ab..a007f75b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -typing-extensions >= 4.0.0 +typing-extensions >= 4.5.0 diff --git a/tests/.coveragerc b/tests/.coveragerc index 810ca85e..0b6fec35 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -9,5 +9,7 @@ exclude_lines = raise NotImplementedError class .*\bProtocol\): @(abc\.)?abstractmethod + @(typing\.)?overload + @(warnings\.)?deprecated\(.* if __name__ == ('__main__'|"__main__"): if TYPE_CHECKING: diff --git a/tests/helpers.py b/tests/helpers.py index e024ffde..e2b23a00 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,6 +20,7 @@ import sys import time from collections import OrderedDict, UserDict, defaultdict, deque, namedtuple +from typing import NamedTuple import pytest @@ -43,6 +44,54 @@ def parametrize(**argvalues) -> pytest.mark.parametrize: return pytest.mark.parametrize(arguments, argvalues, ids=ids) +def is_tuple(tup): + return isinstance(tup, tuple) + + +def is_list(lst): + return isinstance(lst, list) + + +def is_dict(dct): + return isinstance(dct, dict) + + +def is_primitive_collection(obj): + if type(obj) in {tuple, list, deque}: + return all(isinstance(item, (int, float, str, bool, type(None))) for item in obj) + if type(obj) in {dict, OrderedDict, defaultdict}: + return all(isinstance(value, (int, float, str, bool, type(None))) for value in obj.values()) + return False + + +def is_none(none): + return none is None + + +def is_not_none(none): + return none is not None + + +def always(obj): # pylint: disable=unused-argument + return True + + +def never(obj): # pylint: disable=unused-argument + return False + + +IS_LEAF_FUNCTIONS = ( + is_tuple, + is_list, + is_dict, + is_primitive_collection, + is_none, + is_not_none, + always, + never, +) + + CustomTuple = namedtuple('CustomTuple', ('foo', 'bar')) # noqa: PYI024 @@ -50,10 +99,14 @@ class CustomNamedTupleSubclass(CustomTuple): pass +class EmptyTuple(NamedTuple): + pass + + # sys.float_info(max=*, max_exp=*, max_10_exp=*, min=*, min_exp=*, min_10_exp=*, dig=*, mant_dig=*, epsilon=*, radix=*, rounds=*) -SysFloatInfo = type(sys.float_info) +SysFloatInfoType = type(sys.float_info) # time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*) -TimeStructTime = time.struct_time +TimeStructTimeType = time.struct_time class Vector3D: @@ -88,6 +141,12 @@ def __init__(self, x, y): self.x = x self.y = y + def __eq__(self, other): + return isinstance(other, Vector2D) and (self.x, self.y) == (other.x, other.y) + + def __hash__(self): + return hash((self.x, self.y)) + def __repr__(self): return f'{self.__class__.__name__}(x={self.x}, y={self.y})' @@ -98,9 +157,6 @@ def tree_flatten(self): def tree_unflatten(cls, metadata, children): # pylint: disable=unused-argument return cls(*children) - def __eq__(self, other): - return isinstance(other, Vector2D) and (self.x, self.y) == (other.x, other.y) - # pylint: disable-next=protected-access @optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) @@ -112,12 +168,12 @@ def __init__(self, structured, *, leaves=None, treespec=None): self.treespec = treespec self.leaves = leaves - def __hash__(self): - return hash(self.structured) - def __eq__(self, other): return isinstance(other, FlatCache) and self.structured == other.structured + def __hash__(self): + return hash(self.structured) + def __repr__(self): return f'{self.__class__.__name__}({self.structured!r})' @@ -194,9 +250,12 @@ def __next__(self): (1, 2), ((1, 'foo'), ['bar', (3, None, 7)]), [3], + EmptyTuple(), [3, CustomTuple(foo=(3, CustomTuple(foo=3, bar=None)), bar={'baz': 34})], - TimeStructTime((*range(1, 3), None, *range(3, 9))), - SysFloatInfo((*range(1, 10), None, TimeStructTime((*range(10, 15), None, *range(15, 20))))), + TimeStructTimeType((*range(1, 3), None, *range(3, 9))), + SysFloatInfoType( + (*range(1, 10), None, TimeStructTimeType((*range(10, 15), None, *range(15, 20)))), + ), [Vector3D(3, None, [4, 'foo'])], Vector2D(2, 3.0), {}, @@ -234,6 +293,7 @@ def __next__(self): [(0,), (1,)], [(0, 0), (0, 1), (1, 0), (1, 1, 0), (1, 1, 2)], [(0,)], + [], [(0,), (1, 0, 0), (1, 0, 1, 0), (1, 1, 'baz')], [(0,), (1,), (3,), (4,), (5,), (6,), (7,), (8,)], [ @@ -291,6 +351,7 @@ def __next__(self): [(0,), (1,)], [(0, 0), (0, 1), (1, 0), (1, 1, 0), (1, 1, 1), (1, 1, 2)], [(0,)], + [], [(0,), (1, 0, 0), (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 'baz')], [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)], [ @@ -356,6 +417,7 @@ def __next__(self): 'PyTreeSpec((*, *))', 'PyTreeSpec(((*, *), [*, (*, None, *)]))', 'PyTreeSpec([*])', + 'PyTreeSpec(EmptyTuple())', "PyTreeSpec([*, CustomTuple(foo=(*, CustomTuple(foo=*, bar=None)), bar={'baz': *})])", 'PyTreeSpec(time.struct_time(tm_year=*, tm_mon=*, tm_mday=None, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*))', 'PyTreeSpec(sys.float_info(max=*, max_exp=*, max_10_exp=*, min=*, min_exp=*, min_10_exp=*, dig=*, mant_dig=*, epsilon=*, radix=None, rounds=time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=None, tm_wday=*, tm_yday=*, tm_isdst=*)))', @@ -395,6 +457,7 @@ def __next__(self): 'PyTreeSpec((*, *), NoneIsLeaf)', 'PyTreeSpec(((*, *), [*, (*, *, *)]), NoneIsLeaf)', 'PyTreeSpec([*], NoneIsLeaf)', + 'PyTreeSpec(EmptyTuple(), NoneIsLeaf)', "PyTreeSpec([*, CustomTuple(foo=(*, CustomTuple(foo=*, bar=*)), bar={'baz': *})], NoneIsLeaf)", 'PyTreeSpec(time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*), NoneIsLeaf)', 'PyTreeSpec(sys.float_info(max=*, max_exp=*, max_10_exp=*, min=*, min_exp=*, min_10_exp=*, dig=*, mant_dig=*, epsilon=*, radix=*, rounds=time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*)), NoneIsLeaf)', diff --git a/tests/test_functools.py b/tests/test_functools.py new file mode 100644 index 00000000..6e79e956 --- /dev/null +++ b/tests/test_functools.py @@ -0,0 +1,108 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# pylint: disable=missing-function-docstring,invalid-name,wrong-import-order + +import functools + +import optree +from helpers import parametrize + + +def dummy_func(*args, **kwargs): # pylint: disable=unused-argument + return + + +dummy_partial_func = functools.partial(dummy_func, a=1) + + +@parametrize( + tree=[ + optree.functools.partial(dummy_func), + optree.functools.partial(dummy_func, 1, 2), + optree.functools.partial(dummy_func, x='a'), + optree.functools.partial(dummy_func, 1, 2, 3, x=4, y=5), + optree.functools.partial(dummy_func, 1, None, x=4, y=5, z=None), + optree.functools.partial(dummy_partial_func, 1, 2, 3, x=4, y=5), + ], + none_is_leaf=[False, True], +) +def test_partial_round_trip(tree, none_is_leaf): + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) + actual = optree.tree_unflatten(treespec, leaves) + assert actual.func == tree.func + assert actual.args == tree.args + assert actual.keywords == tree.keywords + + +def test_partial_does_not_merge_with_other_partials(): + def f(a=None, b=None, c=None): + return a, b, c + + g = functools.partial(f, 2) + h = optree.functools.partial(g, 3) + assert h.args == (3,) + assert g() == (2, None, None) + assert h() == (2, 3, None) + + +def test_partial_func_attribute_has_stable_hash(): + fn = functools.partial(print, 1) + p1 = optree.functools.partial(fn, 2) + p2 = optree.functools.partial(fn, 2) + assert p1.func == fn # pylint: disable=comparison-with-callable + assert fn == p1.func # pylint: disable=comparison-with-callable + assert p1.func == p2.func + assert hash(p1.func) == hash(p2.func) + + +@parametrize( + tree=[ + optree.Partial(dummy_func), + optree.Partial(dummy_func, 1, 2), + optree.Partial(dummy_func, x='a'), + optree.Partial(dummy_func, 1, 2, 3, x=4, y=5), + optree.Partial(dummy_func, 1, None, x=4, y=5, z=None), + optree.Partial(dummy_partial_func, 1, 2, 3, x=4, y=5), + ], + none_is_leaf=[False, True], +) +def test_Partial_round_trip(tree, none_is_leaf): # noqa: N802 + leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) + actual = optree.tree_unflatten(treespec, leaves) + assert actual.func == tree.func + assert actual.args == tree.args + assert actual.keywords == tree.keywords + + +def test_Partial_does_not_merge_with_other_partials(): # noqa: N802 + def f(a=None, b=None, c=None): + return a, b, c + + g = functools.partial(f, 2) + h = optree.Partial(g, 3) + assert h.args == (3,) + assert g() == (2, None, None) + assert h() == (2, 3, None) + + +def test_Partial_func_attribute_has_stable_hash(): # noqa: N802 + fn = functools.partial(print, 1) + p1 = optree.Partial(fn, 2) + p2 = optree.Partial(fn, 2) + assert p1.func == fn # pylint: disable=comparison-with-callable + assert fn == p1.func # pylint: disable=comparison-with-callable + assert p1.func == p2.func + assert hash(p1.func) == hash(p2.func) diff --git a/tests/test_ops.py b/tests/test_ops.py index 7bfe5d66..4afbcc93 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -13,11 +13,12 @@ # limitations under the License. # ============================================================================== -# pylint: disable=missing-function-docstring,invalid-name +# pylint: disable=missing-function-docstring,invalid-name,wrong-import-order import copy import functools import itertools +import operator import pickle import re from collections import OrderedDict, defaultdict, deque @@ -25,9 +26,8 @@ import pytest import optree - -# pylint: disable-next=wrong-import-order from helpers import ( + IS_LEAF_FUNCTIONS, LEAVES, TREE_PATHS, TREES, @@ -39,33 +39,6 @@ ) -def dummy_func(*args, **kwargs): # pylint: disable=unused-argument - return - - -dummy_partial_func = functools.partial(dummy_func, a=1) - - -def is_tuple(tup): - return isinstance(tup, tuple) - - -def is_list(lst): - return isinstance(lst, list) - - -def is_none(none): - return none is None - - -def always(obj): # pylint: disable=unused-argument - return True - - -def never(obj): # pylint: disable=unused-argument - return False - - def test_max_depth(): lst = [1] for _ in range(optree.MAX_RECURSION_DEPTH - 1): @@ -360,13 +333,7 @@ def test_paths(data): @parametrize( tree=TREES, - is_leaf=[ - is_tuple, - is_list, - is_none, - always, - never, - ], + is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], ) @@ -396,13 +363,7 @@ def test_paths_with_is_leaf(tree, is_leaf, none_is_leaf, namespace): @parametrize( tree=TREES, - is_leaf=[ - is_tuple, - is_list, - is_none, - always, - never, - ], + is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], ) @@ -443,12 +404,7 @@ def test_tree_is_leaf_with_leaves(leaf, none_is_leaf, namespace): @parametrize( tree=TREES, - is_leaf=[ - is_tuple, - is_none, - always, - never, - ], + is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], ) @@ -505,12 +461,7 @@ def test_all_leaves_with_leaves(leaf, none_is_leaf, namespace): @parametrize( tree=TREES, - is_leaf=[ - is_tuple, - is_none, - always, - never, - ], + is_leaf=IS_LEAF_FUNCTIONS, none_is_leaf=[False, True], namespace=['', 'undefined', 'namespace'], ) @@ -573,6 +524,34 @@ def test_tree_broadcast_map(): ({'foo': (3, 'bar')}, ((4, 9), [(5, 9)]), [(6, 10), (6, 11)]), ) + tree1 = [(1, 2, 3), 4, 5, OrderedDict([('y', 7), ('x', 6)])] + tree2 = [8, [9, 10, 11], 12, {'x': 13, 'y': 14}] + tree3 = 15 + tree4 = [16, 17, {'a': 18, 'b': 19, 'c': 20}, 21] + out = optree.tree_broadcast_map( + lambda *args: functools.reduce(operator.mul, args, 1), + tree1, + tree2, + tree3, + tree4, + ) + assert out == [ + (1920, 3840, 5760), + [9180, 10200, 11220], + {'a': 16200, 'b': 17100, 'c': 18000}, + OrderedDict([('y', 30870), ('x', 24570)]), + ] + for trees in itertools.permutations([tree1, tree2, tree3, tree4], 4): + new_out = optree.tree_broadcast_map( + lambda *args: functools.reduce(operator.mul, args, 1), + *trees, + ) + assert new_out == out + if trees.index(tree1) < trees.index(tree2): + assert type(new_out[-1]) is OrderedDict + else: + assert type(new_out[-1]) is dict # noqa: E721 + def test_tree_broadcast_map_with_path(): x = ((1, 2, None), (3, (4, [5]), 6)) @@ -1852,43 +1831,3 @@ def test_tree_flatten_one_level(tree, none_is_leaf, namespace): # noqa: C901 stack.extend(reversed(children)) assert actual_leaves == expected_leaves - - -@parametrize( - tree=[ - optree.Partial(dummy_func), - optree.Partial(dummy_func, 1, 2), - optree.Partial(dummy_func, x='a'), - optree.Partial(dummy_func, 1, 2, 3, x=4, y=5), - optree.Partial(dummy_func, 1, None, x=4, y=5, z=None), - optree.Partial(dummy_partial_func, 1, 2, 3, x=4, y=5), - ], - none_is_leaf=[False, True], -) -def test_partial_round_trip(tree, none_is_leaf): - leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) - actual = optree.tree_unflatten(treespec, leaves) - assert actual.func == tree.func - assert actual.args == tree.args - assert actual.keywords == tree.keywords - - -def test_partial_does_not_merge_with_other_partials(): - def f(a=None, b=None, c=None): - return a, b, c - - g = functools.partial(f, 2) - h = optree.Partial(g, 3) - assert h.args == (3,) - assert g() == (2, None, None) - assert h() == (2, 3, None) - - -def test_partial_func_attribute_has_stable_hash(): - fun = functools.partial(print, 1) - p1 = optree.Partial(fun, 2) - p2 = optree.Partial(fun, 2) - assert p1.func == fun # pylint: disable=comparison-with-callable - assert fun == p1.func # pylint: disable=comparison-with-callable - assert p1.func == p2.func - assert hash(p1.func) == hash(p2.func) diff --git a/tests/test_prefix_errors.py b/tests/test_prefix_errors.py index ce93620d..df2a74d2 100644 --- a/tests/test_prefix_errors.py +++ b/tests/test_prefix_errors.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -# pylint: disable=missing-function-docstring,invalid-name,implicit-str-concat +# pylint: disable=missing-function-docstring,invalid-name,wrong-import-order import random import re @@ -23,9 +23,7 @@ import pytest import optree - -# pylint: disable-next=wrong-import-order -from helpers import TREES, CustomTuple, FlatCache, TimeStructTime, Vector2D, parametrize +from helpers import TREES, CustomTuple, FlatCache, TimeStructTimeType, Vector2D, parametrize from optree.registry import ( AttributeKeyPathEntry, FlattenedKeyPathEntry, @@ -503,7 +501,10 @@ def test_namedtuple(): def test_structseq(): - lhs, rhs = TimeStructTime((1, [2, [3]], *range(7))), TimeStructTime((4, [5, 6], *range(7))) + lhs, rhs = ( + TimeStructTimeType((1, [2, [3]], *range(7))), + TimeStructTimeType((4, [5, 6], *range(7))), + ) lhs_treespec, rhs_treespec = optree.tree_structure(lhs), optree.tree_structure(rhs) with pytest.raises(ValueError): optree.tree_map_(lambda x, y: None, lhs, rhs) diff --git a/tests/test_treespec.py b/tests/test_treespec.py index e3efb84f..413b3c4a 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -# pylint: disable=missing-function-docstring,invalid-name +# pylint: disable=missing-function-docstring,invalid-name,wrong-import-order import itertools import pickle @@ -25,9 +25,8 @@ import pytest +import helpers import optree - -# pylint: disable-next=wrong-import-order from helpers import NAMESPACED_TREE, TREE_STRINGS, TREES, parametrize @@ -122,6 +121,46 @@ def test_treespec_string_representation(data): assert str(treespec) == expected_string assert repr(treespec) == expected_string + assert expected_string.startswith('PyTreeSpec(') + assert expected_string.endswith(')') + if none_is_leaf: + assert expected_string.endswith(', NoneIsLeaf)') + representation = expected_string[len('PyTreeSpec(') : -len(', NoneIsLeaf)')] + else: + representation = expected_string[len('PyTreeSpec(') : -len(')')] + + if ( + 'CustomTreeNode' not in representation + and 'sys.float_info' not in representation + and 'time.struct_time' not in representation + ): + representation = re.sub( + r"", + lambda match: match.group(1), + representation, + ) + counter = itertools.count() + representation = re.sub(r'\*', lambda _: str(next(counter)), representation) + new_tree = optree.tree_unflatten(treespec, range(treespec.num_leaves)) + reconstructed_tree = eval(representation, helpers.__dict__.copy()) + assert new_tree == reconstructed_tree + + +def test_treespec_with_empty_tuple_string_representation(): + assert str(optree.tree_structure(())) == r'PyTreeSpec(())' + + +def test_treespec_with_single_element_tuple_string_representation(): + assert str(optree.tree_structure((1,))) == r'PyTreeSpec((*,))' + + +def test_treespec_with_empty_list_string_representation(): + assert str(optree.tree_structure([])) == r'PyTreeSpec([])' + + +def test_treespec_with_empty_dict_string_representation(): + assert str(optree.tree_structure({})) == r'PyTreeSpec({})' + def test_treespec_self_referential(): class Holder: @@ -181,7 +220,7 @@ def __repr__(self): assert treespec != other -def test_with_namespace(): +def test_treespec_with_namespace(): tree = NAMESPACED_TREE for namespace in ('', 'undefined'): @@ -326,22 +365,6 @@ def test_treespec_type(tree, none_is_leaf, namespace): assert type(tree) is treespec.type -def test_treespec_with_empty_tuple_string_representation(): - assert str(optree.tree_structure(())) == r'PyTreeSpec(())' - - -def test_treespec_with_single_element_tuple_string_representation(): - assert str(optree.tree_structure((1,))) == r'PyTreeSpec((*,))' - - -def test_treespec_with_empty_list_string_representation(): - assert str(optree.tree_structure([])) == r'PyTreeSpec([])' - - -def test_treespec_with_empty_dict_string_representation(): - assert str(optree.tree_structure({})) == r'PyTreeSpec({})' - - @parametrize( tree=TREES, inner_tree=[