diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1f50558c..51e34355 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -30,7 +30,7 @@ repos:
     hooks:
       - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.3
+    rev: v0.8.4
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -39,7 +39,7 @@ repos:
     hooks:
       - id: black
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
         args: [--py38-plus]
diff --git a/optree/integration/__init__.py b/optree/integration/__init__.py
index afc4753e..f844214d 100644
--- a/optree/integration/__init__.py
+++ b/optree/integration/__init__.py
@@ -32,7 +32,7 @@ def __dir__() -> list[str]:
     return [*sorted(SUBMODULES), 'SUBMODULES']
 
 
-def __getattr__(name: str) -> ModuleType:
+def __getattr__(name: str, /) -> ModuleType:
     if name in SUBMODULES:
         import importlib  # pylint: disable=import-outside-toplevel
         import sys  # pylint: disable=import-outside-toplevel
diff --git a/optree/integration/jax.py b/optree/integration/jax.py
index c57ebd7d..abe8336b 100644
--- a/optree/integration/jax.py
+++ b/optree/integration/jax.py
@@ -35,8 +35,8 @@
 
 import contextlib
 import itertools
-import operator
 import warnings
+from operator import itemgetter
 from types import FunctionType
 from typing import Any, Callable
 from typing_extensions import TypeAlias  # Python 3.10+
@@ -66,7 +66,7 @@ class HashablePartial:  # pragma: no cover
     args: tuple[Any, ...]
     kwargs: dict[str, Any]
 
-    def __init__(self, func: FunctionType | HashablePartial, *args: Any, **kwargs: Any) -> None:
+    def __init__(self, func: FunctionType | HashablePartial, /, *args: Any, **kwargs: Any) -> None:
         """Construct a :class:`HashablePartial` instance."""
         if not callable(func):
             raise TypeError(f'Expected a callable, got {func!r}.')
@@ -82,23 +82,23 @@ def __init__(self, func: FunctionType | HashablePartial, *args: Any, **kwargs: A
         else:
             raise TypeError(f'Expected a function, got {func!r}.')
 
-    def __eq__(self, other: object) -> bool:
+    def __eq__(self, other: object, /) -> bool:
         return (
             type(other) is HashablePartial  # pylint: disable=unidiomatic-typecheck
-            and self.func.__code__ == other.func.__code__  # type: ignore[attr-defined]
+            and self.func.__code__ == other.func.__code__
             and (self.args, self.kwargs) == (other.args, other.kwargs)
         )
 
-    def __hash__(self) -> int:
+    def __hash__(self, /) -> int:
         return hash(
             (
-                self.func.__code__,  # type: ignore[attr-defined]
+                self.func.__code__,
                 self.args,
-                tuple(total_order_sorted(self.kwargs.items(), key=operator.itemgetter(0))),
+                tuple(total_order_sorted(self.kwargs.items(), key=itemgetter(0))),
             ),
         )
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+    def __call__(self, /, *args: Any, **kwargs: Any) -> Any:
         kwargs = {**self.kwargs, **kwargs}
         return self.func(*self.args, *args, **kwargs)
 
@@ -110,6 +110,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
 
 def tree_ravel(
     tree: ArrayLikeTree,
+    /,
     is_leaf: Callable[[Any], bool] | None = None,
     *,
     none_is_leaf: bool = False,
@@ -193,13 +194,15 @@ def _tree_unravel(
     treespec: PyTreeSpec,
     unravel_flat: Callable[[Array], list[ArrayLike]],
     flat: Array,
+    /,
 ) -> ArrayTree:
     return tree_unflatten(treespec, unravel_flat(flat))
 
 
-def _ravel_leaves(
-    leaves: list[ArrayLike],
-) -> tuple[Array, Callable[[Array], list[ArrayLike]]]:
+def _ravel_leaves(leaves: list[ArrayLike], /) -> tuple[
+    Array,
+    Callable[[Array], list[ArrayLike]],
+]:
     if not leaves:
         return (jnp.zeros(0), _unravel_empty)
 
@@ -229,7 +232,7 @@ def _ravel_leaves(
     )
 
 
-def _unravel_empty(flat: Array) -> list[ArrayLike]:
+def _unravel_empty(flat: Array, /) -> list[ArrayLike]:
     if jnp.shape(flat) != (0,):
         raise ValueError(
             f'The unravel function expected an array of shape {(0,)}, '
@@ -243,6 +246,7 @@ def _unravel_leaves_single_dtype(
     indices: tuple[int, ...],
     shapes: tuple[tuple[int, ...], ...],
     flat: Array,
+    /,
 ) -> list[Array]:
     if jnp.shape(flat) != (indices[-1],):
         raise ValueError(
@@ -260,6 +264,7 @@ def _unravel_leaves(
     from_dtypes: tuple[jnp.dtype, ...],
     to_dtype: jnp.dtype,
     flat: Array,
+    /,
 ) -> list[Array]:
     if jnp.shape(flat) != (indices[-1],):
         raise ValueError(
diff --git a/optree/integration/numpy.py b/optree/integration/numpy.py
index b9609969..f2b29191 100644
--- a/optree/integration/numpy.py
+++ b/optree/integration/numpy.py
@@ -39,6 +39,7 @@
 
 def tree_ravel(
     tree: ArrayLikeTree,
+    /,
     is_leaf: Callable[[Any], bool] | None = None,
     *,
     none_is_leaf: bool = False,
@@ -122,13 +123,15 @@ def _tree_unravel(
     treespec: PyTreeSpec,
     unravel_flat: Callable[[np.ndarray], list[np.ndarray]],
     flat: np.ndarray,
+    /,
 ) -> ArrayTree:
     return tree_unflatten(treespec, unravel_flat(flat))
 
 
-def _ravel_leaves(
-    leaves: list[np.ndarray],
-) -> tuple[np.ndarray, Callable[[np.ndarray], list[np.ndarray]]]:
+def _ravel_leaves(leaves: list[np.ndarray], /) -> tuple[
+    np.ndarray,
+    Callable[[np.ndarray], list[np.ndarray]],
+]:
     if not leaves:
         return (np.zeros(0), _unravel_empty)
 
@@ -155,7 +158,7 @@ def _ravel_leaves(
     )
 
 
-def _unravel_empty(flat: np.ndarray) -> list[np.ndarray]:
+def _unravel_empty(flat: np.ndarray, /) -> list[np.ndarray]:
     if np.shape(flat) != (0,):
         raise ValueError(
             f'The unravel function expected an array of shape {(0,)}, '
@@ -168,6 +171,7 @@ def _unravel_leaves_single_dtype(
     indices: tuple[int, ...],
     shapes: tuple[tuple[int, ...], ...],
     flat: np.ndarray,
+    /,
 ) -> list[np.ndarray]:
     if np.shape(flat) != (indices[-1],):
         raise ValueError(
@@ -185,6 +189,7 @@ def _unravel_leaves(
     from_dtypes: tuple[np.dtype, ...],
     to_dtype: np.dtype,
     flat: np.ndarray,
+    /,
 ) -> list[np.ndarray]:
     if np.shape(flat) != (indices[-1],):
         raise ValueError(
diff --git a/optree/integration/torch.py b/optree/integration/torch.py
index 6af98c22..cd4c078a 100644
--- a/optree/integration/torch.py
+++ b/optree/integration/torch.py
@@ -36,6 +36,7 @@
 
 def tree_ravel(
     tree: TensorTree,
+    /,
     is_leaf: Callable[[Any], bool] | None = None,
     *,
     none_is_leaf: bool = False,
@@ -119,13 +120,15 @@ def _tree_unravel(
     treespec: PyTreeSpec,
     unravel_flat: Callable[[torch.Tensor], list[torch.Tensor]],
     flat: torch.Tensor,
+    /,
 ) -> TensorTree:
     return tree_unflatten(treespec, unravel_flat(flat))
 
 
-def _ravel_leaves(
-    leaves: list[torch.Tensor],
-) -> tuple[torch.Tensor, Callable[[torch.Tensor], list[torch.Tensor]]]:
+def _ravel_leaves(leaves: list[torch.Tensor], /) -> tuple[
+    torch.Tensor,
+    Callable[[torch.Tensor], list[torch.Tensor]],
+]:
     if not leaves:
         return (torch.zeros(0), _unravel_empty)
     if not all(torch.is_tensor(leaf) for leaf in leaves):
@@ -155,7 +158,7 @@ def _ravel_leaves(
     )
 
 
-def _unravel_empty(flat: torch.Tensor) -> list[torch.Tensor]:
+def _unravel_empty(flat: torch.Tensor, /) -> list[torch.Tensor]:
     if not torch.is_tensor(flat):
         raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
     if flat.shape != (0,):
@@ -169,6 +172,7 @@ def _unravel_leaves_single_dtype(
     sizes: tuple[int, ...],
     shapes: tuple[tuple[int, ...], ...],
     flat: torch.Tensor,
+    /,
 ) -> list[torch.Tensor]:
     if not torch.is_tensor(flat):
         raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
@@ -188,6 +192,7 @@ def _unravel_leaves(
     from_dtypes: tuple[torch.dtype, ...],
     to_dtype: torch.dtype,
     flat: torch.Tensor,
+    /,
 ) -> list[torch.Tensor]:
     if not torch.is_tensor(flat):
         raise ValueError(f'Expected a tensor to unravel, got {type(flat)!r}.')
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 62cc2a61..eb72b344 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -15,7 +15,7 @@
 
 # pylint: disable=missing-function-docstring,invalid-name
 
-import operator
+from operator import itemgetter
 
 import pytest
 
@@ -32,7 +32,7 @@ def test_total_order_sorted():
     assert total_order_sorted([1, 5, 4.5, '20', '3']) == [4.5, 1, 5, '20', '3']
     assert total_order_sorted(
         {1: 1, 5: 2, 4.5: 3, '20': 4, '3': 5}.items(),
-        key=operator.itemgetter(0),
+        key=itemgetter(0),
     ) == [(4.5, 3), (1, 1), (5, 2), ('20', 4), ('3', 5)]
 
     class NonSortable:
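
Not part of the diff: the recurring change above inserts a bare `/` into the signatures, which makes every parameter to its left positional-only (Python 3.8+). A minimal, self-contained sketch of the effect, using a hypothetical `ravel` stand-in rather than the real optree functions:

# Illustration only: `tree` before the `/` is positional-only,
# `none_is_leaf` after the `*` is keyword-only.
def ravel(tree, /, is_leaf=None, *, none_is_leaf=False):
    return tree, is_leaf, none_is_leaf


ravel([1, 2, 3])                      # OK: `tree` passed positionally
ravel([1, 2, 3], none_is_leaf=True)   # OK: keyword-only flag after `*`
try:
    ravel(tree=[1, 2, 3])             # TypeError: `tree` is positional-only
except TypeError as exc:
    print(exc)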