Skip to content

Commit

Permalink
chore(pre-commit): update pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 25, 2024
1 parent 47e8941 commit f93926f
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 25 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.8.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -39,7 +39,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand Down
2 changes: 1 addition & 1 deletion optree/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __dir__() -> list[str]:
return [*sorted(SUBMODULES), 'SUBMODULES']


def __getattr__(name: str) -> ModuleType:
def __getattr__(name: str, /) -> ModuleType:
if name in SUBMODULES:
import importlib # pylint: disable=import-outside-toplevel
import sys # pylint: disable=import-outside-toplevel
Expand Down
29 changes: 17 additions & 12 deletions optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@

import contextlib
import itertools
import operator
import warnings
from operator import itemgetter
from types import FunctionType
from typing import Any, Callable
from typing_extensions import TypeAlias # Python 3.10+
Expand Down Expand Up @@ -66,7 +66,7 @@ class HashablePartial: # pragma: no cover
args: tuple[Any, ...]
kwargs: dict[str, Any]

def __init__(self, func: FunctionType | HashablePartial, *args: Any, **kwargs: Any) -> None:
def __init__(self, func: FunctionType | HashablePartial, /, *args: Any, **kwargs: Any) -> None:
"""Construct a :class:`HashablePartial` instance."""
if not callable(func):
raise TypeError(f'Expected a callable, got {func!r}.')
Expand All @@ -82,23 +82,23 @@ def __init__(self, func: FunctionType | HashablePartial, *args: Any, **kwargs: A
else:
raise TypeError(f'Expected a function, got {func!r}.')

def __eq__(self, other: object) -> bool:
def __eq__(self, other: object, /) -> bool:
return (
type(other) is HashablePartial # pylint: disable=unidiomatic-typecheck
and self.func.__code__ == other.func.__code__ # type: ignore[attr-defined]
and self.func.__code__ == other.func.__code__
and (self.args, self.kwargs) == (other.args, other.kwargs)
)

def __hash__(self) -> int:
def __hash__(self, /) -> int:
return hash(
(
self.func.__code__, # type: ignore[attr-defined]
self.func.__code__,
self.args,
tuple(total_order_sorted(self.kwargs.items(), key=operator.itemgetter(0))),
tuple(total_order_sorted(self.kwargs.items(), key=itemgetter(0))),
),
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(self, /, *args: Any, **kwargs: Any) -> Any:
kwargs = {**self.kwargs, **kwargs}
return self.func(*self.args, *args, **kwargs)

Expand All @@ -110,6 +110,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:

def tree_ravel(
tree: ArrayLikeTree,
/,
is_leaf: Callable[[Any], bool] | None = None,
*,
none_is_leaf: bool = False,
Expand Down Expand Up @@ -193,13 +194,15 @@ def _tree_unravel(
treespec: PyTreeSpec,
unravel_flat: Callable[[Array], list[ArrayLike]],
flat: Array,
/,
) -> ArrayTree:
return tree_unflatten(treespec, unravel_flat(flat))


def _ravel_leaves(
leaves: list[ArrayLike],
) -> tuple[Array, Callable[[Array], list[ArrayLike]]]:
def _ravel_leaves(leaves: list[ArrayLike], /) -> tuple[
Array,
Callable[[Array], list[ArrayLike]],
]:
if not leaves:
return (jnp.zeros(0), _unravel_empty)

Expand Down Expand Up @@ -229,7 +232,7 @@ def _ravel_leaves(
)


def _unravel_empty(flat: Array) -> list[ArrayLike]:
def _unravel_empty(flat: Array, /) -> list[ArrayLike]:
if jnp.shape(flat) != (0,):
raise ValueError(
f'The unravel function expected an array of shape {(0,)}, '
Expand All @@ -243,6 +246,7 @@ def _unravel_leaves_single_dtype(
indices: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
flat: Array,
/,
) -> list[Array]:
if jnp.shape(flat) != (indices[-1],):
raise ValueError(
Expand All @@ -260,6 +264,7 @@ def _unravel_leaves(
from_dtypes: tuple[jnp.dtype, ...],
to_dtype: jnp.dtype,
flat: Array,
/,
) -> list[Array]:
if jnp.shape(flat) != (indices[-1],):
raise ValueError(
Expand Down
13 changes: 9 additions & 4 deletions optree/integration/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

def tree_ravel(
tree: ArrayLikeTree,
/,
is_leaf: Callable[[Any], bool] | None = None,
*,
none_is_leaf: bool = False,
Expand Down Expand Up @@ -122,13 +123,15 @@ def _tree_unravel(
treespec: PyTreeSpec,
unravel_flat: Callable[[np.ndarray], list[np.ndarray]],
flat: np.ndarray,
/,
) -> ArrayTree:
return tree_unflatten(treespec, unravel_flat(flat))


def _ravel_leaves(
leaves: list[np.ndarray],
) -> tuple[np.ndarray, Callable[[np.ndarray], list[np.ndarray]]]:
def _ravel_leaves(leaves: list[np.ndarray], /) -> tuple[
np.ndarray,
Callable[[np.ndarray], list[np.ndarray]],
]:
if not leaves:
return (np.zeros(0), _unravel_empty)

Expand All @@ -155,7 +158,7 @@ def _ravel_leaves(
)


def _unravel_empty(flat: np.ndarray) -> list[np.ndarray]:
def _unravel_empty(flat: np.ndarray, /) -> list[np.ndarray]:
if np.shape(flat) != (0,):
raise ValueError(
f'The unravel function expected an array of shape {(0,)}, '
Expand All @@ -168,6 +171,7 @@ def _unravel_leaves_single_dtype(
indices: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
flat: np.ndarray,
/,
) -> list[np.ndarray]:
if np.shape(flat) != (indices[-1],):
raise ValueError(
Expand All @@ -185,6 +189,7 @@ def _unravel_leaves(
from_dtypes: tuple[np.dtype, ...],
to_dtype: np.dtype,
flat: np.ndarray,
/,
) -> list[np.ndarray]:
if np.shape(flat) != (indices[-1],):
raise ValueError(
Expand Down
13 changes: 9 additions & 4 deletions optree/integration/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

def tree_ravel(
tree: TensorTree,
/,
is_leaf: Callable[[Any], bool] | None = None,
*,
none_is_leaf: bool = False,
Expand Down Expand Up @@ -119,13 +120,15 @@ def _tree_unravel(
treespec: PyTreeSpec,
unravel_flat: Callable[[torch.Tensor], list[torch.Tensor]],
flat: torch.Tensor,
/,
) -> TensorTree:
return tree_unflatten(treespec, unravel_flat(flat))


def _ravel_leaves(
leaves: list[torch.Tensor],
) -> tuple[torch.Tensor, Callable[[torch.Tensor], list[torch.Tensor]]]:
def _ravel_leaves(leaves: list[torch.Tensor], /) -> tuple[
torch.Tensor,
Callable[[torch.Tensor], list[torch.Tensor]],
]:
if not leaves:
return (torch.zeros(0), _unravel_empty)
if not all(torch.is_tensor(leaf) for leaf in leaves):
Expand Down Expand Up @@ -155,7 +158,7 @@ def _ravel_leaves(
)


def _unravel_empty(flat: torch.Tensor) -> list[torch.Tensor]:
def _unravel_empty(flat: torch.Tensor, /) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
if flat.shape != (0,):
Expand All @@ -169,6 +172,7 @@ def _unravel_leaves_single_dtype(
sizes: tuple[int, ...],
shapes: tuple[tuple[int, ...], ...],
flat: torch.Tensor,
/,
) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
Expand All @@ -188,6 +192,7 @@ def _unravel_leaves(
from_dtypes: tuple[torch.dtype, ...],
to_dtype: torch.dtype,
flat: torch.Tensor,
/,
) -> list[torch.Tensor]:
if not torch.is_tensor(flat):
raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# pylint: disable=missing-function-docstring,invalid-name

import operator
from operator import itemgetter

import pytest

Expand All @@ -32,7 +32,7 @@ def test_total_order_sorted():
assert total_order_sorted([1, 5, 4.5, '20', '3']) == [4.5, 1, 5, '20', '3']
assert total_order_sorted(
{1: 1, 5: 2, 4.5: 3, '20': 4, '3': 5}.items(),
key=operator.itemgetter(0),
key=itemgetter(0),
) == [(4.5, 3), (1, 1), (5, 2), ('20', 4), ('3', 5)]

class NonSortable:
Expand Down

0 comments on commit f93926f

Please sign in to comment.