Skip to content

Commit

Permalink
style: update keyword argument passing
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Oct 5, 2023
1 parent 4ba0855 commit a1e4e29
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 19 deletions.
68 changes: 52 additions & 16 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def tree_transpose(

leaves, treespec = tree_flatten(
tree,
is_leaf,
is_leaf=is_leaf,
none_is_leaf=outer_treespec.none_is_leaf,
namespace=outer_treespec.namespace or inner_treespec.namespace,
)
Expand Down Expand Up @@ -773,15 +773,21 @@ def tree_broadcast_prefix(
def broadcast_leaves(x: T, subtree: PyTree[S]) -> PyTree[T]:
subtreespec = tree_structure(
subtree,
is_leaf, # type: ignore[arg-type]
is_leaf=is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
)
return subtreespec.unflatten([x] * subtreespec.num_leaves)

# If prefix_tree is not a tree prefix of full_tree, this code can raise a ValueError;
# use prefix_errors to find disagreements and raise more precise error messages.
# prefix_errors = prefix_errors(prefix_tree, full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
# errors = prefix_errors(
# prefix_tree,
# full_tree,
# is_leaf=is_leaf,
# none_is_leaf=none_is_leaf,
# namespace=namespace,
# )
return tree_map(
broadcast_leaves, # type: ignore[arg-type]
prefix_tree,
Expand Down Expand Up @@ -848,15 +854,21 @@ def broadcast_prefix(
def add_leaves(x: T, subtree: PyTree[S]) -> None:
subtreespec = tree_structure(
subtree,
is_leaf, # type: ignore[arg-type]
is_leaf=is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
)
result.extend([x] * subtreespec.num_leaves)

# If prefix_tree is not a tree prefix of full_tree, this code can raise a ValueError;
# use prefix_errors to find disagreements and raise more precise error messages.
# prefix_errors = prefix_errors(prefix_tree, full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
# errors = prefix_errors(
# prefix_tree,
# full_tree,
# is_leaf=is_leaf,
# none_is_leaf=none_is_leaf,
# namespace=namespace,
# )
tree_map_(
add_leaves,
prefix_tree,
Expand Down Expand Up @@ -944,7 +956,7 @@ def tree_reduce(
Returns:
The result of reducing the leaves of the pytree using ``func``.
""" # pylint: disable=line-too-long
leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
if initial is __MISSING:
return functools.reduce(func, leaves)
return functools.reduce(func, leaves, initial)
Expand Down Expand Up @@ -992,7 +1004,7 @@ def tree_sum(
Returns:
The total sum of ``start`` and leaf values in ``tree``.
"""
leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
# sum() rejects string values for `start` parameter
if isinstance(start, str):
return ''.join([start, *leaves]) # type: ignore[list-item,return-value]
Expand Down Expand Up @@ -1078,7 +1090,7 @@ def tree_max(tree, *, default=__MISSING, key=None, is_leaf=None, none_is_leaf=Fa
Returns:
The maximum leaf value in ``tree``.
"""
leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
if default is __MISSING:
if key is None: # special handling for Python 3.7
return max(leaves)
Expand Down Expand Up @@ -1165,7 +1177,7 @@ def tree_min(tree, *, default=__MISSING, key=None, is_leaf=None, none_is_leaf=Fa
Returns:
The minimum leaf value in ``tree``.
"""
leaves = tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
if default is __MISSING:
if key is None: # special handling for Python 3.7
return min(leaves)
Expand Down Expand Up @@ -1216,7 +1228,14 @@ def tree_all(
:data:`True` if all leaves in ``tree`` are true, or if ``tree`` is empty.
Otherwise, :data:`False`.
"""
return all(tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)) # type: ignore[arg-type]
return all(
tree_leaves(
tree, # type: ignore[arg-type]
is_leaf=is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
),
)


def tree_any(
Expand Down Expand Up @@ -1260,7 +1279,14 @@ def tree_any(
:data:`True` if any leaves in ``tree`` are true, otherwise, :data:`False`. If ``tree`` is
empty, return :data:`False`.
"""
return any(tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)) # type: ignore[arg-type]
return any(
tree_leaves(
tree, # type: ignore[arg-type]
is_leaf=is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
),
)


def treespec_is_prefix(
Expand Down Expand Up @@ -1571,7 +1597,7 @@ def prefix_errors(
KeyPath(),
prefix_tree,
full_tree,
is_leaf,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
),
Expand Down Expand Up @@ -1717,8 +1743,18 @@ def _prefix_error(

# If the root types and numbers of children agree, there must be an error in a subtree,
# so recurse:
keys = _child_keys(prefix_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
keys_ = _child_keys(full_tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace) # type: ignore[arg-type]
keys = _child_keys(
prefix_tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
keys_ = _child_keys(
full_tree,
is_leaf=is_leaf, # type: ignore[arg-type]
none_is_leaf=none_is_leaf,
namespace=namespace,
)
assert keys == keys_ or (
# Special handling for directory types already done in the keys check above
both_standard_dict
Expand All @@ -1729,7 +1765,7 @@ def _prefix_error(
key_path + k,
cast(PyTree[T], t1),
cast(PyTree[S], t2),
is_leaf,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
Expand All @@ -1742,7 +1778,7 @@ def _child_keys(
none_is_leaf: bool = False,
namespace: str = '',
) -> list[KeyPathEntry]:
treespec = tree_structure(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
treespec = tree_structure(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
assert not treespec_is_strict_leaf(treespec), 'treespec must be a non-leaf node'

handler = register_keypaths.get(type(tree)) # type: ignore[attr-defined]
Expand Down
16 changes: 13 additions & 3 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def test_paths_with_is_leaf(tree, is_leaf, none_is_leaf, namespace):
def test_round_trip_is_leaf(tree, is_leaf, none_is_leaf, namespace):
subtrees, treespec = optree.tree_flatten(
tree,
is_leaf,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
Expand Down Expand Up @@ -430,8 +430,18 @@ def test_all_leaves_with_leaves(leaf, none_is_leaf, namespace):
namespace=['', 'undefined', 'namespace'],
)
def test_all_leaves_with_is_leaf(tree, is_leaf, none_is_leaf, namespace):
leaves = optree.tree_leaves(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
assert optree.all_leaves(leaves, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
leaves = optree.tree_leaves(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
assert optree.all_leaves(
leaves,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)


def test_tree_map():
Expand Down

0 comments on commit a1e4e29

Please sign in to comment.