diff --git a/optree/ops.py b/optree/ops.py index 69c35536..3b2a4166 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -701,7 +701,7 @@ def tree_transpose( leaves, treespec = tree_flatten( tree, - is_leaf, + is_leaf=is_leaf, none_is_leaf=outer_treespec.none_is_leaf, namespace=outer_treespec.namespace or inner_treespec.namespace, ) @@ -773,7 +773,7 @@ def tree_broadcast_prefix( def broadcast_leaves(x: T, subtree: PyTree[S]) -> PyTree[T]: subtreespec = tree_structure( subtree, - is_leaf, # type: ignore[arg-type] + is_leaf=is_leaf, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -781,7 +781,13 @@ def broadcast_leaves(x: T, subtree: PyTree[S]) -> PyTree[T]: # If prefix_tree is not a tree prefix of full_tree, this code can raise a ValueError; # use prefix_errors to find disagreements and raise more precise error messages. - # prefix_errors = prefix_errors(prefix_tree, full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + # errors = prefix_errors( + # prefix_tree, + # full_tree, + # is_leaf=is_leaf, + # none_is_leaf=none_is_leaf, + # namespace=namespace, + # ) return tree_map( broadcast_leaves, # type: ignore[arg-type] prefix_tree, @@ -848,7 +854,7 @@ def broadcast_prefix( def add_leaves(x: T, subtree: PyTree[S]) -> None: subtreespec = tree_structure( subtree, - is_leaf, # type: ignore[arg-type] + is_leaf=is_leaf, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -856,7 +862,13 @@ def add_leaves(x: T, subtree: PyTree[S]) -> None: # If prefix_tree is not a tree prefix of full_tree, this code can raise a ValueError; # use prefix_errors to find disagreements and raise more precise error messages. - # prefix_errors = prefix_errors(prefix_tree, full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + # errors = prefix_errors( + # prefix_tree, + # full_tree, + # is_leaf=is_leaf, + # none_is_leaf=none_is_leaf, + # namespace=namespace, + # ) tree_map_( add_leaves, prefix_tree, @@ -944,7 +956,7 @@ def tree_reduce( Returns: The result of reducing the leaves of the pytree using ``func``. """ # pylint: disable=line-too-long - leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) if initial is __MISSING: return functools.reduce(func, leaves) return functools.reduce(func, leaves, initial) @@ -992,7 +1004,7 @@ def tree_sum( Returns: The total sum of ``start`` and leaf values in ``tree``. """ - leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) # sum() rejects string values for `start` parameter if isinstance(start, str): return ''.join([start, *leaves]) # type: ignore[list-item,return-value] @@ -1078,7 +1090,7 @@ def tree_max(tree, *, default=__MISSING, key=None, is_leaf=None, none_is_leaf=Fa Returns: The maximum leaf value in ``tree``. """ - leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) if default is __MISSING: if key is None: # special handling for Python 3.7 return max(leaves) @@ -1165,7 +1177,7 @@ def tree_min(tree, *, default=__MISSING, key=None, is_leaf=None, none_is_leaf=Fa Returns: The minimum leaf value in ``tree``. """ - leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) if default is __MISSING: if key is None: # special handling for Python 3.7 return min(leaves) @@ -1216,7 +1228,14 @@ def tree_all( :data:`True` if all leaves in ``tree`` are true, or if ``tree`` is empty. Otherwise, :data:`False`. """ - return all(tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)) # type: ignore[arg-type] + return all( + tree_leaves( + tree, # type: ignore[arg-type] + is_leaf=is_leaf, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ), + ) def tree_any( @@ -1260,7 +1279,14 @@ def tree_any( :data:`True` if any leaves in ``tree`` are true, otherwise, :data:`False`. If ``tree`` is empty, return :data:`False`. """ - return any(tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)) # type: ignore[arg-type] + return any( + tree_leaves( + tree, # type: ignore[arg-type] + is_leaf=is_leaf, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ), + ) def treespec_is_prefix( @@ -1571,7 +1597,7 @@ def prefix_errors( KeyPath(), prefix_tree, full_tree, - is_leaf, + is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, ), @@ -1717,8 +1743,18 @@ def _prefix_error( # If the root types and numbers of children agree, there must be an error in a subtree, # so recurse: - keys = _child_keys(prefix_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) - keys_ = _child_keys(full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) # type: ignore[arg-type] + keys = _child_keys( + prefix_tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + keys_ = _child_keys( + full_tree, + is_leaf=is_leaf, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) assert keys == keys_ or ( # Special handling for directory types already done in the keys check above both_standard_dict @@ -1729,7 +1765,7 @@ def _prefix_error( key_path + k, cast(PyTree[T], t1), cast(PyTree[S], t2), - is_leaf, + is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -1742,7 +1778,7 @@ def _child_keys( none_is_leaf: bool = False, namespace: str = '', ) -> list[KeyPathEntry]: - treespec = tree_structure(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + treespec = tree_structure(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) assert not treespec_is_strict_leaf(treespec), 'treespec must be a non-leaf node' handler = register_keypaths.get(type(tree)) # type: ignore[attr-defined] diff --git a/tests/test_ops.py b/tests/test_ops.py index 56ffdbf2..43fc8cf3 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -389,7 +389,7 @@ def test_paths_with_is_leaf(tree, is_leaf, none_is_leaf, namespace): def test_round_trip_is_leaf(tree, is_leaf, none_is_leaf, namespace): subtrees, treespec = optree.tree_flatten( tree, - is_leaf, + is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -430,8 +430,18 @@ def test_all_leaves_with_leaves(leaf, none_is_leaf, namespace): namespace=['', 'undefined', 'namespace'], ) def test_all_leaves_with_is_leaf(tree, is_leaf, none_is_leaf, namespace): - leaves = optree.tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) - assert optree.all_leaves(leaves, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) + leaves = optree.tree_leaves( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + assert optree.all_leaves( + leaves, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) def test_tree_map():