Skip to content

Commit

Permalink
feat(ops): add functions tree_broadcast_map and `tree_broadcast_map…
Browse files Browse the repository at this point in the history
…_with_path`
  • Loading branch information
XuehaiPan committed Oct 6, 2023
1 parent d28e9f9 commit a23aa76
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Tree Manipulation Functions
broadcast_prefix
tree_broadcast_common
broadcast_common
tree_broadcast_map
tree_broadcast_map_with_path
prefix_errors

.. autofunction:: tree_flatten
Expand All @@ -61,6 +63,8 @@ Tree Manipulation Functions
.. autofunction:: broadcast_prefix
.. autofunction:: tree_broadcast_common
.. autofunction:: broadcast_common
.. autofunction:: tree_broadcast_map
.. autofunction:: tree_broadcast_map_with_path
.. autofunction:: prefix_errors

------
Expand Down
4 changes: 4 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
tree_all,
tree_any,
tree_broadcast_common,
tree_broadcast_map,
tree_broadcast_map_with_path,
tree_broadcast_prefix,
tree_flatten,
tree_flatten_with_path,
Expand Down Expand Up @@ -106,6 +108,8 @@
'broadcast_prefix',
'tree_broadcast_common',
'broadcast_common',
'tree_broadcast_map',
'tree_broadcast_map_with_path',
'tree_reduce',
'tree_sum',
'tree_max',
Expand Down
182 changes: 180 additions & 2 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
'broadcast_prefix',
'tree_broadcast_common',
'broadcast_common',
'tree_broadcast_map',
'tree_broadcast_map_with_path',
'tree_reduce',
'tree_sum',
'tree_max',
Expand Down Expand Up @@ -468,7 +470,8 @@ def tree_map(
) -> PyTree[U]:
"""Map a multi-input function over pytree args to produce a new pytree.
See also :func:`tree_map_`, :func:`tree_map_with_path`, and :func:`tree_map_with_path_`.
See also :func:`tree_map_`, :func:`tree_map_with_path`, :func:`tree_map_with_path_`,
and :func:`tree_broadcast_map`.
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
{'x': 8, 'y': (43, 65)}
Expand Down Expand Up @@ -566,7 +569,8 @@ def tree_map_with_path(
) -> PyTree[U]:
"""Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.
See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path_`.
See also :func:`tree_map`, :func:`tree_map_`, :func:`tree_map_with_path_`,
and :func:`tree_broadcast_map_with_path`.
>>> tree_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
{'x': (1, 7), 'y': ((2, 42), (2, 64))}
Expand Down Expand Up @@ -1095,6 +1099,180 @@ def add_leaves(x: T, y: T) -> None:
return broadcasted_leaves, other_broadcasted_leaves


# pylint: disable-next=too-many-locals
def tree_broadcast_map(
func: Callable[..., U],
tree: PyTree[T],
*rests: PyTree[T],
is_leaf: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = '',
) -> PyTree[U]:
"""Map a multi-input function over pytree args to produce a new pytree.
See also :func:`tree_broadcast_map_with_path`, :func:`tree_map`, :func:`tree_map_`,
and :func:`tree_map_with_path`.
If only one input is provided, this function is the same as :func:`tree_map`:
>>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
{'x': 8, 'y': (43, 65)}
>>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (43, 65), 'z': None}
>>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
{'x': False, 'y': (False, False), 'z': None}
>>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=True)
{'x': False, 'y': (False, False), 'z': True}
If multiple inputs are given, all input trees will be broadcasted to the common suffix structure
of all inputs:
>>> tree_broadcast_map(lambda x, y: x * y, [5, 6, (3, 4)], [{'a': 7, 'b': 9}, [1, 2], 8])
[{'a': 35, 'b': 45}, [6, 12], (24, 32)]
Args:
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
argument to function ``func``.
rests (tuple of pytree): A tuple of pytrees, they should have a common suffix structure with
each other and with ``tree``.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list and :data:`None` will be remain in the result
pytree. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
Returns:
A new pytree with the structure as the common suffix structure of ``tree`` and ``rests`` but
with the value at each leaf given by ``func(x, *xs)`` where ``x`` is the value at the
corresponding leaf (may be broadcasted) in ``tree`` and ``xs`` is the tuple of values at
corresponding leaves (may be broadcasted) in ``rests``.
"""
if not rests:
return tree_map(
func,
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)

broadcasted_tree = tree
broadcasted_rests = list(rests)
for _ in range(2):
for i, rest in enumerate(rests):
broadcasted_tree, broadcasted_rests[i] = tree_broadcast_common(
broadcasted_tree,
rest,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)

return tree_map(
func,
broadcasted_tree,
*broadcasted_rests,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)


# pylint: disable-next=too-many-locals
def tree_broadcast_map_with_path(
func: Callable[..., U],
tree: PyTree[T],
*rests: PyTree[T],
is_leaf: Callable[[T], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = '',
) -> PyTree[U]:
"""Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.
See also :func:`tree_broadcast_map`, :func:`tree_map`, :func:`tree_map_`,
and :func:`tree_map_with_path`.
If only one input is provided, this function is the same as :func:`tree_map`:
>>> tree_broadcast_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
{'x': (1, 7), 'y': ((2, 42), (2, 64))}
>>> tree_broadcast_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (44, 66), 'z': None}
>>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}})
{'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}}
>>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True)
{'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}}
If multiple inputs are given, all input trees will be broadcasted to the common suffix structure
of all inputs:
>>> tree_broadcast_map_with_path(lambda p, x, y: (p, x * y), [5, 6, (3, 4)], [{'a': 7, 'b': 9}, [1, 2], 8])
[{'a': ((0, 'a'), 35), 'b': ((0, 'b'), 45)},
[((1, 0), 6), ((1, 1), 12)],
(((2, 0), 24), ((2, 1), 32))]
Args:
func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
corresponding leaves of the pytrees with extra paths.
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
argument to function ``func``.
rests (tuple of pytree): A tuple of pytrees, they should have a common suffix structure with
each other and with ``tree``.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
flattening should traverse the current object.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list and :data:`None` will be remain in the result
pytree. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
Returns:
A new pytree with the structure as the common suffix structure of ``tree`` and ``rests`` but
with the value at each leaf given by ``func(p, x, *xs)`` where ``(p, x)`` are the path and
value at the corresponding leaf (may be broadcasted) in and ``xs`` is the tuple of values at
corresponding leaves (may be broadcasted) in ``rests``.
"""
if not rests:
return tree_map_with_path(
func,
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)

broadcasted_tree = tree
broadcasted_rests = list(rests)
for _ in range(2):
for i, rest in enumerate(rests):
broadcasted_tree, broadcasted_rests[i] = tree_broadcast_common(
broadcasted_tree,
rest,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)

return tree_map_with_path(
func,
broadcasted_tree,
*broadcasted_rests,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)


# pylint: disable-next=missing-class-docstring,too-few-public-methods
class MissingSentinel: # pragma: no cover
def __repr__(self) -> str:
Expand Down

0 comments on commit a23aa76

Please sign in to comment.