diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f219fcb..e24435c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: default_stages: [commit, push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.2 + rev: v18.1.4 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.5 + rev: v0.4.1 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -39,7 +39,7 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 24.4.0 hooks: - id: black - repo: https://github.com/asottile/pyupgrade diff --git a/CMakeLists.txt b/CMakeLists.txt index c970a9a4..dd19a8aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,7 @@ if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "") set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}") endif() if(NOT PYBIND11_VERSION) - set(PYBIND11_VERSION v2.11.1) + set(PYBIND11_VERSION v2.12.0) endif() if(NOT CMAKE_BUILD_TYPE) @@ -51,7 +51,7 @@ if(MSVC) APPEND CMAKE_CXX_FLAGS " /Zc:preprocessor" " /experimental:external /external:anglebrackets /external:W0" - " /Wall /Wv:19.34" # Visual Studio 2022 version 17.4 + " /Wall /Wv:19.37" # Visual Studio 2022 version 17.7 # Suppress following warnings " /wd4365" # conversion from 'type_1' to 'type_2', signed/unsigned mismatch " /wd4514" # unreferenced inline function has been removed diff --git a/benchmark.py b/benchmark.py index 1243af75..f94cfce7 100755 --- a/benchmark.py +++ b/benchmark.py @@ -21,6 +21,7 @@ from __future__ import annotations import argparse +import operator import sys import textwrap import timeit @@ -447,7 +448,7 @@ def check(tree: Any) -> None: optree_flat = optree.tree_leaves(tree, none_is_leaf=False) jax_flat = jax.tree_util.tree_leaves(tree) - if len(optree_flat) == len(jax_flat) and all(map(lambda a, b: a is b, optree_flat, jax_flat)): + if len(optree_flat) == len(jax_flat) and all(map(operator.is_, optree_flat, jax_flat)): cprint( f'{CMARK} FLATTEN (OpTree vs. JAX XLA): ' f'optree.tree_leaves(tree, none_is_leaf=False)' @@ -462,9 +463,7 @@ def check(tree: Any) -> None: optree_flat = optree.tree_leaves(tree, none_is_leaf=True) torch_flat = torch_utils_pytree.tree_flatten(tree)[0] - if len(optree_flat) == len(torch_flat) and all( - map(lambda a, b: a is b, optree_flat, torch_flat), - ): + if len(optree_flat) == len(torch_flat) and all(map(operator.is_, optree_flat, torch_flat)): cprint( f'{CMARK} FLATTEN (OpTree vs. PyTorch): ' f'optree.tree_leaves(tree, none_is_leaf=True)' diff --git a/optree/integration/jax.py b/optree/integration/jax.py index 7ff34e57..d658c5f6 100644 --- a/optree/integration/jax.py +++ b/optree/integration/jax.py @@ -33,7 +33,9 @@ from __future__ import annotations +import contextlib import itertools +import operator import warnings from types import FunctionType from typing import Any, Callable @@ -92,7 +94,7 @@ def __hash__(self) -> int: ( self.func.__code__, # type: ignore[attr-defined] self.args, - tuple(total_order_sorted(self.kwargs.items(), key=lambda kv: kv[0])), + tuple(total_order_sorted(self.kwargs.items(), key=operator.itemgetter(0))), ), ) @@ -101,11 +103,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.func(*self.args, *args, **kwargs) -try: # noqa: SIM105 # pragma: no cover - # pylint: disable=ungrouped-imports +with contextlib.suppress(ImportError): # pragma: no cover + # pylint: disable-next=ungrouped-imports from jax._src.util import HashablePartial # type: ignore[assignment] # noqa: F811,RUF100 -except ImportError: # pragma: no cover - pass def tree_ravel( diff --git a/optree/registry.py b/optree/registry.py index 1fc5b929..e4129f40 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -21,7 +21,7 @@ import inspect import sys from collections import OrderedDict, defaultdict, deque, namedtuple -from operator import methodcaller +from operator import itemgetter, methodcaller from threading import Lock from typing import ( TYPE_CHECKING, @@ -413,7 +413,7 @@ def unregister_pytree_node( def _sorted_items(items: Iterable[tuple[KT, VT]]) -> list[tuple[KT, VT]]: - return total_order_sorted(items, key=lambda kv: kv[0]) + return total_order_sorted(items, key=itemgetter(0)) def _sorted_keys(dct: dict[KT, VT]) -> list[KT]: diff --git a/tests/helpers.py b/tests/helpers.py index e2b23a00..ecc6ca47 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -228,6 +228,9 @@ def __int__(self): def __eq__(self, other): return isinstance(other, Counter) and self.count == other.count + def __hash__(self): + return hash(self.count) + def __repr__(self): return f'Counter({self.count})' diff --git a/tests/test_ops.py b/tests/test_ops.py index 4afbcc93..ac1e4610 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -1655,8 +1655,8 @@ def test_broadcast_common(): def test_tree_reduce(): - assert optree.tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)}) == 6 - assert optree.tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3}) == 6 + assert optree.tree_reduce(operator.add, {'x': 1, 'y': (2, 3)}) == 6 + assert optree.tree_reduce(operator.add, {'x': 1, 'y': (2, None), 'z': 3}) == 6 assert optree.tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}) == 3 assert ( optree.tree_reduce( @@ -1699,7 +1699,7 @@ def test_tree_max(): optree.tree_max({}) assert optree.tree_max({}, default=0) == 0 assert optree.tree_max({'x': 0, 'y': (2, 1)}) == 2 - assert optree.tree_max({'x': 0, 'y': (2, 1)}, key=lambda x: -x) == 0 + assert optree.tree_max({'x': 0, 'y': (2, 1)}, key=operator.neg) == 0 with pytest.raises(ValueError, match='empty'): optree.tree_max({'a': None}) assert optree.tree_max({'a': None}, default=0) == 0 @@ -1708,9 +1708,9 @@ def test_tree_max(): optree.tree_max(None) assert optree.tree_max(None, default=0) == 0 assert optree.tree_max(None, none_is_leaf=True) is None - assert optree.tree_max(None, default=0, key=lambda x: -x) == 0 + assert optree.tree_max(None, default=0, key=operator.neg) == 0 with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")): - assert optree.tree_max(None, default=0, key=lambda x: -x, none_is_leaf=True) is None + assert optree.tree_max(None, default=0, key=operator.neg, none_is_leaf=True) is None def test_tree_min(): @@ -1718,7 +1718,7 @@ def test_tree_min(): optree.tree_min({}) assert optree.tree_min({}, default=0) == 0 assert optree.tree_min({'x': 0, 'y': (2, 1)}) == 0 - assert optree.tree_min({'x': 0, 'y': (2, 1)}, key=lambda x: -x) == 2 + assert optree.tree_min({'x': 0, 'y': (2, 1)}, key=operator.neg) == 2 with pytest.raises(ValueError, match='empty'): optree.tree_min({'a': None}) assert optree.tree_min({'a': None}, default=0) == 0 @@ -1727,9 +1727,9 @@ def test_tree_min(): optree.tree_min(None) assert optree.tree_min(None, default=0) == 0 assert optree.tree_min(None, none_is_leaf=True) is None - assert optree.tree_min(None, default=0, key=lambda x: -x) == 0 + assert optree.tree_min(None, default=0, key=operator.neg) == 0 with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")): - assert optree.tree_min(None, default=0, key=lambda x: -x, none_is_leaf=True) is None + assert optree.tree_min(None, default=0, key=operator.neg, none_is_leaf=True) is None def test_tree_all(): @@ -1785,7 +1785,7 @@ def test_tree_flatten_one_level(tree, none_is_leaf, namespace): # noqa: C901 namespace=namespace, ) assert children == expected_children - if node_type in (type(None), tuple, list): + if node_type in {type(None), tuple, list}: assert metadata is None if node_type is tuple: assert one_level_treespec.kind == optree.PyTreeKind.TUPLE diff --git a/tests/test_prefix_errors.py b/tests/test_prefix_errors.py index df2a74d2..8bdf73b3 100644 --- a/tests/test_prefix_errors.py +++ b/tests/test_prefix_errors.py @@ -441,7 +441,7 @@ def shuffle_dictionary(x): shuffled_tree = optree.tree_map( shuffle_dictionary, tree, - is_leaf=lambda x: type(x) in (dict, OrderedDict, defaultdict), + is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -453,7 +453,7 @@ def shuffle_dictionary(x): shuffled_suffix_tree = optree.tree_map( shuffle_dictionary, suffix_tree, - is_leaf=lambda x: type(x) in (dict, OrderedDict, defaultdict), + is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, none_is_leaf=none_is_leaf, namespace=namespace, ) diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 413b3c4a..6adec221 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -859,7 +859,7 @@ def test_treespec_constructor(tree, none_is_leaf, namespace): # noqa: C901 == expected_treespec ) - if node_type in (type(None), tuple, list): + if node_type in {type(None), tuple, list}: if node_type is tuple: assert ( optree.treespec_tuple( diff --git a/tests/test_utils.py b/tests/test_utils.py index fe35335c..62cc2a61 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,6 +15,8 @@ # pylint: disable=missing-function-docstring,invalid-name +import operator + import pytest from optree.utils import safe_zip, total_order_sorted, unzip2 @@ -30,7 +32,7 @@ def test_total_order_sorted(): assert total_order_sorted([1, 5, 4.5, '20', '3']) == [4.5, 1, 5, '20', '3'] assert total_order_sorted( {1: 1, 5: 2, 4.5: 3, '20': 4, '3': 5}.items(), - key=lambda kv: kv[0], + key=operator.itemgetter(0), ) == [(4.5, 3), (1, 1), (5, 2), ('20', 4), ('3', 5)] class NonSortable: