Skip to content

Commit

Permalink
chore(pre-commit): update pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Apr 21, 2024
1 parent 47f20b7 commit 055b386
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 30 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ci:
default_stages: [commit, push, manual]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
Expand All @@ -26,11 +26,11 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.2
rev: v18.1.4
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.5
rev: v0.4.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -39,7 +39,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 24.3.0
rev: 24.4.0
hooks:
- id: black
- repo: https://github.com/asottile/pyupgrade
Expand Down
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "")
set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}")
endif()
if(NOT PYBIND11_VERSION)
set(PYBIND11_VERSION v2.11.1)
set(PYBIND11_VERSION v2.12.0)
endif()

if(NOT CMAKE_BUILD_TYPE)
Expand Down Expand Up @@ -51,7 +51,7 @@ if(MSVC)
APPEND CMAKE_CXX_FLAGS
" /Zc:preprocessor"
" /experimental:external /external:anglebrackets /external:W0"
" /Wall /Wv:19.34" # Visual Studio 2022 version 17.4
" /Wall /Wv:19.37" # Visual Studio 2022 version 17.7
# Suppress following warnings
" /wd4365" # conversion from 'type_1' to 'type_2', signed/unsigned mismatch
" /wd4514" # unreferenced inline function has been removed
Expand Down
7 changes: 3 additions & 4 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import annotations

import argparse
import operator
import sys
import textwrap
import timeit
Expand Down Expand Up @@ -447,7 +448,7 @@ def check(tree: Any) -> None:

optree_flat = optree.tree_leaves(tree, none_is_leaf=False)
jax_flat = jax.tree_util.tree_leaves(tree)
if len(optree_flat) == len(jax_flat) and all(map(lambda a, b: a is b, optree_flat, jax_flat)):
if len(optree_flat) == len(jax_flat) and all(map(operator.is_, optree_flat, jax_flat)):
cprint(
f'{CMARK} FLATTEN (OpTree vs. JAX XLA): '
f'optree.tree_leaves(tree, none_is_leaf=False)'
Expand All @@ -462,9 +463,7 @@ def check(tree: Any) -> None:

optree_flat = optree.tree_leaves(tree, none_is_leaf=True)
torch_flat = torch_utils_pytree.tree_flatten(tree)[0]
if len(optree_flat) == len(torch_flat) and all(
map(lambda a, b: a is b, optree_flat, torch_flat),
):
if len(optree_flat) == len(torch_flat) and all(map(operator.is_, optree_flat, torch_flat)):
cprint(
f'{CMARK} FLATTEN (OpTree vs. PyTorch): '
f'optree.tree_leaves(tree, none_is_leaf=True)'
Expand Down
10 changes: 5 additions & 5 deletions optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

from __future__ import annotations

import contextlib
import itertools
import operator
import warnings
from types import FunctionType
from typing import Any, Callable
Expand Down Expand Up @@ -92,7 +94,7 @@ def __hash__(self) -> int:
(
self.func.__code__, # type: ignore[attr-defined]
self.args,
tuple(total_order_sorted(self.kwargs.items(), key=lambda kv: kv[0])),
tuple(total_order_sorted(self.kwargs.items(), key=operator.itemgetter(0))),
),
)

Expand All @@ -101,11 +103,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*self.args, *args, **kwargs)


try: # noqa: SIM105 # pragma: no cover
# pylint: disable=ungrouped-imports
with contextlib.suppress(ImportError): # pragma: no cover
# pylint: disable-next=ungrouped-imports
from jax._src.util import HashablePartial # type: ignore[assignment] # noqa: F811,RUF100
except ImportError: # pragma: no cover
pass


def tree_ravel(
Expand Down
4 changes: 2 additions & 2 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import inspect
import sys
from collections import OrderedDict, defaultdict, deque, namedtuple
from operator import methodcaller
from operator import itemgetter, methodcaller
from threading import Lock
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -413,7 +413,7 @@ def unregister_pytree_node(


def _sorted_items(items: Iterable[tuple[KT, VT]]) -> list[tuple[KT, VT]]:
return total_order_sorted(items, key=lambda kv: kv[0])
return total_order_sorted(items, key=itemgetter(0))


def _sorted_keys(dct: dict[KT, VT]) -> list[KT]:
Expand Down
3 changes: 3 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def __int__(self):
def __eq__(self, other):
return isinstance(other, Counter) and self.count == other.count

def __hash__(self):
return hash(self.count)

def __repr__(self):
return f'Counter({self.count})'

Expand Down
18 changes: 9 additions & 9 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,8 +1655,8 @@ def test_broadcast_common():


def test_tree_reduce():
assert optree.tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)}) == 6
assert optree.tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3}) == 6
assert optree.tree_reduce(operator.add, {'x': 1, 'y': (2, 3)}) == 6
assert optree.tree_reduce(operator.add, {'x': 1, 'y': (2, None), 'z': 3}) == 6
assert optree.tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}) == 3
assert (
optree.tree_reduce(
Expand Down Expand Up @@ -1699,7 +1699,7 @@ def test_tree_max():
optree.tree_max({})
assert optree.tree_max({}, default=0) == 0
assert optree.tree_max({'x': 0, 'y': (2, 1)}) == 2
assert optree.tree_max({'x': 0, 'y': (2, 1)}, key=lambda x: -x) == 0
assert optree.tree_max({'x': 0, 'y': (2, 1)}, key=operator.neg) == 0
with pytest.raises(ValueError, match='empty'):
optree.tree_max({'a': None})
assert optree.tree_max({'a': None}, default=0) == 0
Expand All @@ -1708,17 +1708,17 @@ def test_tree_max():
optree.tree_max(None)
assert optree.tree_max(None, default=0) == 0
assert optree.tree_max(None, none_is_leaf=True) is None
assert optree.tree_max(None, default=0, key=lambda x: -x) == 0
assert optree.tree_max(None, default=0, key=operator.neg) == 0
with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")):
assert optree.tree_max(None, default=0, key=lambda x: -x, none_is_leaf=True) is None
assert optree.tree_max(None, default=0, key=operator.neg, none_is_leaf=True) is None


def test_tree_min():
with pytest.raises(ValueError, match='empty'):
optree.tree_min({})
assert optree.tree_min({}, default=0) == 0
assert optree.tree_min({'x': 0, 'y': (2, 1)}) == 0
assert optree.tree_min({'x': 0, 'y': (2, 1)}, key=lambda x: -x) == 2
assert optree.tree_min({'x': 0, 'y': (2, 1)}, key=operator.neg) == 2
with pytest.raises(ValueError, match='empty'):
optree.tree_min({'a': None})
assert optree.tree_min({'a': None}, default=0) == 0
Expand All @@ -1727,9 +1727,9 @@ def test_tree_min():
optree.tree_min(None)
assert optree.tree_min(None, default=0) == 0
assert optree.tree_min(None, none_is_leaf=True) is None
assert optree.tree_min(None, default=0, key=lambda x: -x) == 0
assert optree.tree_min(None, default=0, key=operator.neg) == 0
with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")):
assert optree.tree_min(None, default=0, key=lambda x: -x, none_is_leaf=True) is None
assert optree.tree_min(None, default=0, key=operator.neg, none_is_leaf=True) is None


def test_tree_all():
Expand Down Expand Up @@ -1785,7 +1785,7 @@ def test_tree_flatten_one_level(tree, none_is_leaf, namespace): # noqa: C901
namespace=namespace,
)
assert children == expected_children
if node_type in (type(None), tuple, list):
if node_type in {type(None), tuple, list}:
assert metadata is None
if node_type is tuple:
assert one_level_treespec.kind == optree.PyTreeKind.TUPLE
Expand Down
4 changes: 2 additions & 2 deletions tests/test_prefix_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def shuffle_dictionary(x):
shuffled_tree = optree.tree_map(
shuffle_dictionary,
tree,
is_leaf=lambda x: type(x) in (dict, OrderedDict, defaultdict),
is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict},
none_is_leaf=none_is_leaf,
namespace=namespace,
)
Expand All @@ -453,7 +453,7 @@ def shuffle_dictionary(x):
shuffled_suffix_tree = optree.tree_map(
shuffle_dictionary,
suffix_tree,
is_leaf=lambda x: type(x) in (dict, OrderedDict, defaultdict),
is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict},
none_is_leaf=none_is_leaf,
namespace=namespace,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def test_treespec_constructor(tree, none_is_leaf, namespace): # noqa: C901
== expected_treespec
)

if node_type in (type(None), tuple, list):
if node_type in {type(None), tuple, list}:
if node_type is tuple:
assert (
optree.treespec_tuple(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# pylint: disable=missing-function-docstring,invalid-name

import operator

import pytest

from optree.utils import safe_zip, total_order_sorted, unzip2
Expand All @@ -30,7 +32,7 @@ def test_total_order_sorted():
assert total_order_sorted([1, 5, 4.5, '20', '3']) == [4.5, 1, 5, '20', '3']
assert total_order_sorted(
{1: 1, 5: 2, 4.5: 3, '20': 4, '3': 5}.items(),
key=lambda kv: kv[0],
key=operator.itemgetter(0),
) == [(4.5, 3), (1, 1), (5, 2), ('20', 4), ('3', 5)]

class NonSortable:
Expand Down

0 comments on commit 055b386

Please sign in to comment.