Skip to content

Commit

Permalink
style: miscellaneous style housekeeping
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Oct 5, 2023
1 parent c1da9d0 commit cdad2f4
Show file tree
Hide file tree
Showing 12 changed files with 91 additions and 83 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ optree.register_pytree_node(
torch.Tensor,
# (tensor) -> (children, metadata)
flatten_func=lambda tensor: (
(tensor.cpu().numpy(),),
dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
(tensor.cpu().detach().numpy(),),
{'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
),
# (metadata, children) -> tensor
unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
Expand Down
5 changes: 2 additions & 3 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ Tree Manipulation Functions
tree_map_
tree_map_with_path
tree_map_with_path_
tree_replace_nones
tree_transpose
tree_broadcast_prefix
broadcast_prefix
tree_replace_nones
prefix_errors

.. autofunction:: tree_flatten
Expand All @@ -51,10 +51,10 @@ Tree Manipulation Functions
.. autofunction:: tree_map_
.. autofunction:: tree_map_with_path
.. autofunction:: tree_map_with_path_
.. autofunction:: tree_replace_nones
.. autofunction:: tree_transpose
.. autofunction:: tree_broadcast_prefix
.. autofunction:: broadcast_prefix
.. autofunction:: tree_replace_nones
.. autofunction:: prefix_errors

------
Expand All @@ -64,7 +64,6 @@ Tree Reduce Functions

.. autosummary::

tree_replace_nones
tree_reduce
tree_sum
tree_max
Expand Down
2 changes: 0 additions & 2 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,6 @@ class PyTreeSpec {
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

[[nodiscard]] py::list FlattenUpToImpl(const py::handle &full_tree) const;

template <bool NoneIsLeaf>
static bool AllLeavesImpl(const py::iterable &iterable, const std::string &registry_namespace);

Expand Down
2 changes: 1 addition & 1 deletion optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@
'tree_map_',
'tree_map_with_path',
'tree_map_with_path_',
'tree_replace_nones',
'tree_transpose',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_replace_nones',
'tree_reduce',
'tree_sum',
'tree_max',
Expand Down
104 changes: 55 additions & 49 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@
'tree_map_',
'tree_map_with_path',
'tree_map_with_path_',
'tree_replace_nones',
'tree_transpose',
'tree_broadcast_prefix',
'broadcast_prefix',
'tree_replace_nones',
'tree_reduce',
'tree_sum',
'tree_max',
Expand Down Expand Up @@ -525,7 +525,7 @@ def tree_map_with_path(
none_is_leaf: bool = False,
namespace: str = '',
) -> PyTree[U]:
"""Map a multi-input function over pytree args to produce a new pytree.
"""Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.
See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path_`.
Expand Down Expand Up @@ -612,6 +612,35 @@ def tree_map_with_path_(
return tree


def tree_replace_nones(sentinel: Any, tree: PyTree[T] | None, namespace: str = '') -> PyTree[T]:
"""Replace :data:`None` in ``tree`` with ``sentinel``.
See also :func:`tree_flatten` and :func:`tree_map`.
>>> tree_replace_nones(0, {'a': 1, 'b': None, 'c': (2, None)})
{'a': 1, 'b': 0, 'c': (2, 0)}
>>> tree_replace_nones(0, None)
0
Args:
sentinel (object): The value to replace :data:`None` with.
tree (pytree): A pytree to be transformed.
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
Returns:
A new pytree with the same structure as ``tree`` but with :data:`None` replaced.
"""
if tree is None:
return sentinel
return tree_map(
lambda x: x if x is not None else sentinel,
tree,
none_is_leaf=True,
namespace=namespace,
)


def tree_transpose(
outer_treespec: PyTreeSpec,
inner_treespec: PyTreeSpec,
Expand Down Expand Up @@ -708,24 +737,24 @@ def tree_broadcast_prefix(
from ``prefix_tree``. The number of replicas is determined by the corresponding subtree in
``full_tree``.
>>> tree_broadcast_prefix(1, [1, 2, 3])
>>> tree_broadcast_prefix(1, [2, 3, 4])
[1, 1, 1]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, 3])
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6])
[1, 2, 3]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
Traceback (most recent call last):
...
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
[1, 2, (3, 3)]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
[1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}]
>>> tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True)
>>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
[1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]
Args:
prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
prefix_tree (pytree): A pytree with the prefix structure of ``full_tree``.
full_tree (pytree): A pytree with the suffix structure of ``prefix_tree``.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
Expand Down Expand Up @@ -782,24 +811,24 @@ def broadcast_prefix(
replicated from ``prefix_tree``. The number of replicas is determined by the corresponding
subtree in ``full_tree``.
>>> broadcast_prefix(1, [1, 2, 3])
>>> broadcast_prefix(1, [2, 3, 4])
[1, 1, 1]
>>> broadcast_prefix([1, 2, 3], [1, 2, 3])
>>> broadcast_prefix([1, 2, 3], [4, 5, 6])
[1, 2, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
>>> broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
Traceback (most recent call last):
...
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
>>> broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
[1, 2, 3, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
>>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
[1, 2, 3, 3, 3]
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True)
>>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
[1, 2, 3, 3, 3, 3]
Args:
prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
prefix_tree (pytree): A pytree with the prefix structure of ``full_tree``.
full_tree (pytree): A pytree with the suffix structure of ``prefix_tree``.
is_leaf (callable, optional): An optionally specified function that will be called at each
flattening step. It should return a boolean, with :data:`True` stopping the traversal
and the whole subtree being treated as a leaf, and :data:`False` indicating the
Expand Down Expand Up @@ -839,36 +868,13 @@ def add_leaves(x: T, subtree: PyTree[S]) -> None:
return result


def tree_replace_nones(sentinel: Any, tree: PyTree[T] | None, namespace: str = '') -> PyTree[T]:
"""Replace :data:`None` in ``tree`` with ``sentinel``.
See also :func:`tree_flatten` and :func:`tree_map`.
>>> tree_replace_nones(0, {'a': 1, 'b': None, 'c': (2, None)})
{'a': 1, 'b': 0, 'c': (2, 0)}
>>> tree_replace_nones(0, None)
0
Args:
sentinel (object): The value to replace :data:`None` with.
tree (pytree): A pytree to be transformed.
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
Returns:
A new pytree with the same structure as ``tree`` but with :data:`None` replaced.
"""
if tree is None:
return sentinel
return tree_map(
lambda x: x if x is not None else sentinel,
tree,
none_is_leaf=True,
namespace=namespace,
)
class MissingSentinel: # pylint: disable=missing-class-docstring,too-few-public-methods
def __repr__(self) -> str:
return '<MISSING>'


__MISSING: T = object() # type: ignore[valid-type]
__MISSING: T = MissingSentinel() # type: ignore[valid-type]
del MissingSentinel


@overload
Expand Down Expand Up @@ -1523,7 +1529,7 @@ def flatten_one_level(
node_type = type(tree)
handler = register_pytree_node.get(node_type, namespace=namespace) # type: ignore[attr-defined]
if handler:
flattened = handler.to_iterable(tree)
flattened = tuple(handler.to_iterable(tree))
if len(flattened) == 2:
flattened = (*flattened, None)
elif len(flattened) != 3:
Expand Down
12 changes: 9 additions & 3 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,14 @@ class PyTreeNodeRegistryEntry(NamedTuple):
from_iterable: UnflattenFunc


__GLOBAL_NAMESPACE: str = object() # type: ignore[assignment]
class GlobalNamespace: # pylint: disable=missing-class-docstring,too-few-public-methods
def __repr__(self) -> str:
return '<GLOBAL NAMESPACE>'


__GLOBAL_NAMESPACE: str = GlobalNamespace() # type: ignore[assignment]
__REGISTRY_LOCK: Lock = Lock()
del GlobalNamespace


def register_pytree_node(
Expand Down Expand Up @@ -99,8 +105,8 @@ def register_pytree_node(
>>> register_pytree_node(
... torch.Tensor,
... flatten_func=lambda tensor: (
... (tensor.cpu().numpy(),),
... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
... (tensor.cpu().detach().numpy(),),
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
... ),
... unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
... namespace='torch2numpy',
Expand Down
8 changes: 7 additions & 1 deletion optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,13 @@ def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias:
return self


FlattenFunc = Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]]
FlattenFunc = Callable[
[CustomTreeNode[T]],
Union[
Tuple[Children[T], MetaData],
Tuple[Children[T], MetaData, Optional[Iterable[Any]]],
],
]
UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]]


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ test-command = """make -C "{project}" test PYTHON=python"""
safe = true
line-length = 100
skip-string-normalization = true
target-version = ["py37", "py38", "py39", "py310", "py311", "py312"]
target-version = ["py37"]

[tool.isort]
atomic = true
Expand Down
6 changes: 1 addition & 5 deletions src/treespec/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ PyTreeSpec::FlattenWithPath(const py::handle& tree,
}

// NOLINTNEXTLINE[readability-function-cognitive-complexity]
py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
py::list PyTreeSpec::FlattenUpTo(const py::handle& full_tree) const {
const ssize_t num_leaves = GetNumLeaves();

auto agenda = reserved_vector<py::object>(4);
Expand Down Expand Up @@ -629,10 +629,6 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
return leaves;
}

py::list PyTreeSpec::FlattenUpTo(const py::handle& full_tree) const {
return FlattenUpToImpl(full_tree);
}

template <bool NoneIsLeaf>
/*static*/ bool PyTreeSpec::AllLeavesImpl(const py::iterable& iterable,
const std::string& registry_namespace) {
Expand Down
2 changes: 1 addition & 1 deletion src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Child(ssize_t index) const {
EXPECT_EQ(py::ssize_t_cast(num_children), node.arity, "Node arity did not match.");
switch (node.kind) {
case PyTreeKind::Leaf:
INTERNAL_ERROR("MakeNode not implemented for leaves.");
INTERNAL_ERROR("PyTreeSpec::MakeNode() not implemented for leaves.");

case PyTreeKind::None:
return py::none();
Expand Down
23 changes: 10 additions & 13 deletions src/treespec/unflatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,20 @@ py::object PyTreeSpec::UnflattenImpl(const Span& leaves) const {
py::ssize_t_cast(agenda.size()), node.arity, "Too few elements for PyTreeSpec node.");

switch (node.kind) {
case PyTreeKind::None:
case PyTreeKind::Leaf: {
if (node.kind == PyTreeKind::Leaf || m_none_is_leaf) [[likely]] {
if (it == leaves.end()) [[unlikely]] {
std::ostringstream oss{};
oss << "Too few leaves for PyTreeSpec; expected: " << GetNumLeaves()
<< ", got: " << leaf_count << ".";
throw py::value_error(oss.str());
}
agenda.emplace_back(py::reinterpret_borrow<py::object>(*it));
++it;
++leaf_count;
break;
if (it == leaves.end()) [[unlikely]] {
std::ostringstream oss{};
oss << "Too few leaves for PyTreeSpec; expected: " << GetNumLeaves()
<< ", got: " << leaf_count << ".";
throw py::value_error(oss.str());
}
[[fallthrough]];
agenda.emplace_back(py::reinterpret_borrow<py::object>(*it));
++it;
++leaf_count;
break;
}

case PyTreeKind::None:
case PyTreeKind::Tuple:
case PyTreeKind::List:
case PyTreeKind::Dict:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,12 @@ def test_pytree_node_registry_get():
handler = optree.register_pytree_node.get(list)
assert handler is not None
lst = [1, 2, 3]
assert handler.to_iterable(lst)[:2] == (lst, None)
assert tuple(handler.to_iterable(lst))[:2] == (lst, None)

handler = optree.register_pytree_node.get(list, namespace='any')
assert handler is not None
lst = [1, 2, 3]
assert handler.to_iterable(lst)[:2] == (lst, None)
assert tuple(handler.to_iterable(lst))[:2] == (lst, None)

handler = optree.register_pytree_node.get(set)
assert handler is None
Expand Down

0 comments on commit cdad2f4

Please sign in to comment.