Skip to content

Commit

Permalink
feat(functools): add submodule optree.functools (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Apr 6, 2024
1 parent 516f99e commit 9f4f568
Show file tree
Hide file tree
Showing 26 changed files with 601 additions and 342 deletions.
6 changes: 3 additions & 3 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ ignore =
# E203: whitespace before ':'
# W503: line break before binary operator
# W504: line break after binary operator
# E704: multiple statements on one line (def)
# format by black
E203,W503,W504,
E203,W503,W504,E704
# E501: line too long
# W505: doc line too long
# too long docstring due to long example blocks
Expand All @@ -25,9 +26,8 @@ per-file-ignores =
# E302: expected 2 blank lines
# E305: expected 2 blank lines after class or function definition
# E701: multiple statements on one line (colon)
# E704: multiple statements on one line (def)
# format by black
*.pyi: E301,E302,E305,E701,E704
*.pyi: E301,E302,E305,E701
exclude =
.git,
.vscode,
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,13 @@ jobs:
- name: Test with pytest
run: |
make pytest
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
if: ${{ matrix.os == 'ubuntu-latest' }}
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./tests/coverage.xml
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Add submodule `optree.functools` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134).

### Changed

- Update minimal version of `typing-extensions` to 4.5.0 for `typing_extensions.deprecated` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134).
- Update string representation for `OrderedDict` by [@XuehaiPan](https://github.com/XuehaiPan) in [#133](https://github.com/metaopt/optree/pull/133).

### Fixed
Expand All @@ -25,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

-
- Deprecate `optree.Partial` and replace with `optree.functools.partial` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134).

------

Expand Down
2 changes: 1 addition & 1 deletion conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies:
- pip

# Dependency
- typing-extensions >= 4.0.0
- typing-extensions >= 4.5.0

# Build toolchain
- cmake >= 3.11
Expand Down
2 changes: 1 addition & 1 deletion docs/conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies:
- pip

# Dependency
- typing-extensions >= 4.0.0
- typing-extensions >= 4.5.0

# Build toolchain
- cmake >= 3.11
Expand Down
12 changes: 12 additions & 0 deletions docs/source/functools.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Integration with :mod:`functools`
=================================

.. currentmodule:: optree.functools

.. autosummary::

partial
reduce

.. autoclass:: partial
.. autofunction:: reduce
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ OpTree: Optimized PyTree Utilities
:maxdepth: 1

registry.rst
functools.rst
typing.rst
api.rst

Expand Down
2 changes: 0 additions & 2 deletions docs/source/registry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ PyTree Node Registry
register_pytree_node
register_pytree_node_class
unregister_pytree_node
Partial
register_keypaths
AttributeKeyPathEntry
GetitemKeyPathEntry

.. autofunction:: register_pytree_node
.. autofunction:: register_pytree_node_class
.. autofunction:: unregister_pytree_node
.. autofunction:: Partial
.. autofunction:: register_keypaths
.. autofunction:: AttributeKeyPathEntry
.. autofunction:: GetitemKeyPathEntry
2 changes: 2 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ jax
numpy
torch
dtype
cuda
getattr
setattr
delattr
typecheck
subclassed
5 changes: 2 additions & 3 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# ==============================================================================
"""OpTree: Optimized PyTree Utilities."""

from optree import integration, typing
from optree import functools, integration, typing
from optree.functools import Partial
from optree.ops import (
MAX_RECURSION_DEPTH,
NONE_IS_LEAF,
Expand Down Expand Up @@ -74,7 +75,6 @@
from optree.registry import (
AttributeKeyPathEntry,
GetitemKeyPathEntry,
Partial,
register_keypaths,
register_pytree_node,
register_pytree_node_class,
Expand Down Expand Up @@ -161,7 +161,6 @@
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'Partial',
'register_keypaths',
'AttributeKeyPathEntry',
'GetitemKeyPathEntry',
Expand Down
171 changes: 171 additions & 0 deletions optree/functools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PyTree integration with :mod:`functools.partial`."""

from __future__ import annotations

import functools
from typing import Any, Callable, ClassVar
from typing_extensions import Self # Python 3.11+
from typing_extensions import deprecated # Python 3.13+

from optree import registry
from optree.ops import tree_reduce as reduce
from optree.typing import CustomTreeNode, T


__all__ = [
'partial',
'reduce',
]


class _HashablePartialShim:
"""Object that delegates :meth:`__call__`, :meth:`__eq__`, and :meth:`__hash__` to another object."""

__slots__: ClassVar[tuple[str, ...]] = ('partial_func', 'func', 'args', 'keywords')

func: Callable[..., Any]
args: tuple[Any, ...]
keywords: dict[str, Any]

def __init__(self, partial_func: functools.partial) -> None:
self.partial_func: functools.partial = partial_func

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.partial_func(*args, **kwargs)

def __eq__(self, other: object) -> bool:
if isinstance(other, _HashablePartialShim):
return self.partial_func == other.partial_func
return self.partial_func == other

def __hash__(self) -> int:
return hash(self.partial_func)

def __repr__(self) -> str:
return repr(self.partial_func)


# pylint: disable-next=protected-access
@registry.register_pytree_node_class(namespace=registry.__GLOBAL_NAMESPACE)
class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-methods
functools.partial,
CustomTreeNode[T],
):
"""A version of :func:`functools.partial` that works in pytrees.
Use it for partial function evaluation in a way that is compatible with transformations,
e.g., ``partial(func, *args, **kwargs)``.
(You need to explicitly opt-in to this behavior because we did not want to give
:func:`functools.partial` different semantics than normal function closures.)
For example, here is a basic usage of :class:`partial` in a manner similar to
:func:`functools.partial`:
>>> import operator
>>> import torch
>>> add_one = partial(operator.add, torch.ones(()))
>>> add_one(torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
[4., 5.]])
Pytree compatibility means that the resulting partial function can be passed as an argument
within tree-map functions, which is not possible with a standard :func:`functools.partial`
function:
>>> def call_func_on_cuda(f, *args, **kwargs):
... f, args, kwargs = tree_map(lambda t: t.cuda(), (f, args, kwargs))
... return f(*args, **kwargs)
...
>>> # doctest: +SKIP
>>> tree_map(lambda t: t.cuda(), add_one)
optree.functools.partial(<built-in function add>, tensor(1., device='cuda:0'))
>>> call_func_on_cuda(add_one, torch.tensor([[1, 2], [3, 4]]))
tensor([[2., 3.],
[4., 5.]], device='cuda:0')
Passing zero arguments to :class:`partial` effectively wraps the original function, making it a
valid argument in tree-map functions:
>>> # doctest: +SKIP
>>> call_func_on_cuda(partial(torch.add), torch.tensor(1), torch.tensor(2))
tensor(3, device='cuda:0')
Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in
a :class:`TypeError` or :class:`AttributeError`.
"""

__slots__: ClassVar[tuple[()]] = ()

func: Callable[..., Any]
args: tuple[T, ...]
keywords: dict[str, T]

def __new__(cls, func: Callable[..., Any], *args: T, **keywords: T) -> Self:
"""Create a new :class:`partial` instance."""
# In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__
# would merge the arguments of this partial instance with the arguments of the func. We box
# func in a class that does not (yet) have a `func` attribute to defeat this optimization,
# since we care exactly which arguments are considered part of the pytree.
if isinstance(func, functools.partial):
original_func = func
func = _HashablePartialShim(original_func)
assert not hasattr(func, 'func'), 'shimmed function should not have a `func` attribute'
out = super().__new__(cls, func, *args, **keywords)
func.func = original_func.func
func.args = original_func.args
func.keywords = original_func.keywords
return out

return super().__new__(cls, func, *args, **keywords)

def __repr__(self) -> str:
"""Return a string representation of the :class:`partial` instance."""
args = [repr(self.func)]
args.extend(repr(x) for x in self.args)
args.extend(f'{k}={v!r}' for (k, v) in self.keywords.items())
return f'{self.__class__.__module__}.{self.__class__.__qualname__}({", ".join(args)})'

def tree_flatten(self) -> tuple[ # type: ignore[override]
tuple[tuple[T, ...], dict[str, T]],
Callable[..., Any],
tuple[str, str],
]:
"""Flatten the :class:`partial` instance to children and auxiliary data."""
return (self.args, self.keywords), self.func, ('args', 'keywords')

@classmethod
def tree_unflatten( # type: ignore[override]
cls,
metadata: Callable[..., Any],
children: tuple[tuple[T, ...], dict[str, T]],
) -> Self:
"""Unflatten the children and auxiliary data into a :class:`partial` instance."""
args, keywords = children
return cls(metadata, *args, **keywords)


# pylint: disable-next=protected-access
@registry.register_pytree_node_class(namespace=registry.__GLOBAL_NAMESPACE)
@deprecated(
'The class `optree.Partial` is deprecated and will be removed in a future version. '
'Please use `optree.functools.partial` instead.',
)
class Partial(partial):
"""Deprecated alias for :class:`partial`."""

__slots__: ClassVar[tuple[()]] = ()
35 changes: 17 additions & 18 deletions optree/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,30 @@
# ==============================================================================
"""Integration with third-party libraries."""

import sys
from typing import Any
from __future__ import annotations

from typing import TYPE_CHECKING

current_module = sys.modules[__name__]

if TYPE_CHECKING:
from types import ModuleType

SUBMODULES = frozenset({'jax', 'numpy', 'torch'})

SUBMODULES: frozenset[str] = frozenset({'jax', 'numpy', 'torch'})

# pylint: disable-next=too-few-public-methods
class _LazyModule(type(current_module)): # type: ignore[misc]
def __getattribute__(self, name: str) -> Any: # noqa: N804
try:
return super().__getattribute__(name)
except AttributeError:
if name in SUBMODULES:
import importlib # pylint: disable=import-outside-toplevel

submodule = importlib.import_module(f'{__name__}.{name}')
setattr(self, name, submodule)
return submodule
raise
def __getattr__(name: str) -> ModuleType:
if name in SUBMODULES:
import importlib # pylint: disable=import-outside-toplevel
import sys # pylint: disable=import-outside-toplevel

module = sys.modules[__name__]

current_module.__class__ = _LazyModule
submodule = importlib.import_module(f'{__name__}.{name}')
setattr(module, name, submodule)
return submodule

del sys, Any, current_module, _LazyModule
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')


del TYPE_CHECKING
6 changes: 3 additions & 3 deletions optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def __eq__(self, other: object) -> bool:
return (
type(other) is HashablePartial # pylint: disable=unidiomatic-typecheck
and self.func.__code__ == other.func.__code__ # type: ignore[attr-defined]
and self.args == other.args
and self.kwargs == other.kwargs
and (self.args, self.kwargs) == (other.args, other.kwargs)
)

def __hash__(self) -> int:
Expand All @@ -98,7 +97,8 @@ def __hash__(self) -> int:
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*self.args, *args, **self.kwargs, **kwargs)
kwargs = {**self.kwargs, **kwargs}
return self.func(*self.args, *args, **kwargs)


try: # noqa: SIM105 # pragma: no cover
Expand Down
Loading

0 comments on commit 9f4f568

Please sign in to comment.