Skip to content

Commit

Permalink
feat: add method PyTreeSpec.paths
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Mar 12, 2023
1 parent fc2b0a1 commit 0676540
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 25 deletions.
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ PyTreeSpec Functions

treespec_is_prefix
treespec_is_suffix
treespec_paths
treespec_children
treespec_is_leaf
treespec_is_strict_leaf
Expand All @@ -97,6 +98,7 @@ PyTreeSpec Functions

.. autofunction:: treespec_is_prefix
.. autofunction:: treespec_is_suffix
.. autofunction:: treespec_paths
.. autofunction:: treespec_children
.. autofunction:: treespec_is_leaf
.. autofunction:: treespec_is_strict_leaf
Expand Down
43 changes: 27 additions & 16 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class PyTreeSpec {
public:
PyTreeSpec() = default;

// Flattens a PyTree into a list of leaves and a PyTreeSpec.
// Returns references to the flattened objects, which might be temporary objects in the case of
// Flatten a PyTree into a list of leaves and a PyTreeSpec.
// Return references to the flattened objects, which might be temporary objects in the case of
// custom PyType handlers.
static std::pair<std::vector<py::object>, std::unique_ptr<PyTreeSpec>> Flatten(
const py::handle &tree,
Expand All @@ -68,8 +68,8 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

// Flattens a PyTree into a list of leaves with a list of paths and a PyTreeSpec.
// Returns references to the flattened objects, which might be temporary objects in the case of
// Flatten a PyTree into a list of leaves with a list of paths and a PyTreeSpec.
// Return references to the flattened objects, which might be temporary objects in the case of
// custom PyType handlers.
static std::tuple<std::vector<py::object>, std::vector<py::object>, std::unique_ptr<PyTreeSpec>>
FlattenWithPath(const py::handle &tree,
Expand All @@ -85,7 +85,7 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

// Flattens a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure
// Flatten a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure
// of 'full_tree'. For example, if we flatten a value [(1, (2, 3)), {"foo": 4}] with a PyTreeSpec [(*,
// *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}].
[[nodiscard]] py::list FlattenUpTo(const py::handle &full_tree) const;
Expand All @@ -95,38 +95,43 @@ class PyTreeSpec {
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Returns an unflattened PyTree given an iterable of leaves and a PyTreeSpec.
// Return an unflattened PyTree given an iterable of leaves and a PyTreeSpec.
[[nodiscard]] py::object Unflatten(const py::iterable &leaves) const;

// Composes two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`.
// Compose two PyTreeSpecs, replacing the leaves of this tree with copies of `inner_treespec`.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Compose(const PyTreeSpec &inner_treespec) const;

// Maps a function over a PyTree structure, applying f_leaf to each leaf, and
// Map a function over a PyTree structure, applying f_leaf to each leaf, and
// f_node(children, node_data) to each container node.
[[nodiscard]] py::object Walk(const py::function &f_node,
const py::handle &f_leaf,
const py::iterable &leaves) const;

// Returns true if this PyTreeSpec is a prefix of `other`.
// Return true if this PyTreeSpec is a prefix of `other`.
[[nodiscard]] bool IsPrefix(const PyTreeSpec &other, const bool &strict = false) const;

// Returns true if this PyTreeSpec is a suffix of `other`.
// Return true if this PyTreeSpec is a suffix of `other`.
[[nodiscard]] inline bool IsSuffix(const PyTreeSpec &other, const bool &strict = false) const {
    // "This is a suffix of `other`" is evaluated as "`other` is a prefix of this": forward to
    // IsPrefix with the receiver and the argument swapped, passing `strict` through unchanged.
    return other.IsPrefix(*this, strict);
}

// Return paths to all leaves in the PyTreeSpec.
[[nodiscard]] std::vector<py::tuple> Paths() const;

// Return the children of the PyTreeSpec.
[[nodiscard]] std::vector<std::unique_ptr<PyTreeSpec>> Children() const;

// Test whether this PyTreeSpec represents a leaf.
[[nodiscard]] bool IsLeaf(const bool &strict = true) const;

// Makes a Tuple PyTreeSpec out of a vector of PyTreeSpecs.
// Make a Tuple PyTreeSpec out of a vector of PyTreeSpecs.
static std::unique_ptr<PyTreeSpec> Tuple(const std::vector<PyTreeSpec> &treespecs,
const bool &none_is_leaf);

// Makes a PyTreeSpec representing a leaf node.
// Make a PyTreeSpec representing a leaf node.
static std::unique_ptr<PyTreeSpec> Leaf(const bool &none_is_leaf);

// Makes a PyTreeSpec representing a `None` node.
// Make a PyTreeSpec representing a `None` node.
static std::unique_ptr<PyTreeSpec> None(const bool &none_is_leaf);

[[nodiscard]] ssize_t GetNumLeaves() const;
Expand Down Expand Up @@ -206,11 +211,11 @@ class PyTreeSpec {

[[nodiscard]] std::string ToString() const;

// Transforms the PyTreeSpec into a picklable object.
// Transform the PyTreeSpec into a picklable object.
// Used to implement `PyTreeSpec.__getstate__`.
[[nodiscard]] py::object ToPicklable() const;

// Transforms the object returned by `ToPicklable()` back to PyTreeSpec.
// Transform the object returned by `ToPicklable()` back to PyTreeSpec.
// Used to implement `PyTreeSpec.__setstate__`.
static PyTreeSpec FromPicklable(const py::object &picklable);

Expand Down Expand Up @@ -260,7 +265,7 @@ class PyTreeSpec {
// Helper that manufactures an instance of a node given its children.
static py::object MakeNode(const Node &node, const absl::Span<py::object> &children);

// Computes the node kind of a given Python object.
// Compute the node kind of a given Python object.
template <bool NoneIsLeaf>
static PyTreeKind GetKind(const py::handle &handle,
PyTreeTypeRegistry::Registration const **custom,
Expand Down Expand Up @@ -290,6 +295,12 @@ class PyTreeSpec {
template <typename Span>
py::object UnflattenImpl(const Span &leaves) const;

template <typename Span, typename Stack>
[[nodiscard]] ssize_t PathsImpl(Span &paths, // NOLINT[runtime/references]
Stack &stack, // NOLINT[runtime/references]
const ssize_t &pos,
const ssize_t &depth) const;

static PyTreeSpec FromPicklableImpl(const py::object &picklable);
};

Expand Down
1 change: 1 addition & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class PyTreeSpec:
other: PyTreeSpec,
strict: bool = ..., # False
) -> bool: ...
def paths(self) -> list[builtins.tuple[Any, ...]]: ...
def children(self) -> list[PyTreeSpec]: ...
def is_leaf(
self,
Expand Down
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
treespec_is_suffix,
treespec_leaf,
treespec_none,
treespec_paths,
treespec_tuple,
)
from optree.registry import (
Expand Down Expand Up @@ -105,6 +106,7 @@
'prefix_errors',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_paths',
'treespec_children',
'treespec_is_leaf',
'treespec_is_strict_leaf',
Expand Down
13 changes: 11 additions & 2 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
'tree_any',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_paths',
'treespec_children',
'treespec_is_leaf',
'treespec_is_strict_leaf',
Expand Down Expand Up @@ -159,7 +160,7 @@ def tree_flatten_with_path(
) -> tuple[list[tuple[Any, ...]], list[T], PyTreeSpec]:
"""Flatten a pytree and additionally record the paths.
See also :func:`tree_flatten` and :func:`tree_paths`.
See also :func:`tree_flatten`, :func:`tree_paths`, and :func:`treespec_paths`.
The flattening order (i.e., the order of elements in the output list) is deterministic,
corresponding to a left-to-right depth-first tree traversal.
Expand Down Expand Up @@ -328,7 +329,7 @@ def tree_paths(
) -> list[tuple[Any, ...]]:
"""Get the path entries to the leaves of a pytree.
See also :func:`tree_flatten` and :func:`tree_flatten_with_path`.
See also :func:`tree_flatten`, :func:`tree_flatten_with_path`, and :func:`treespec_paths`.
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_paths(tree)
Expand Down Expand Up @@ -1286,6 +1287,14 @@ def treespec_is_suffix(
return treespec.is_suffix(other_treespec, strict=strict)


def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]:
    """Return a list of paths to the leaves of a treespec.

    This is a thin functional wrapper: it simply calls ``treespec.paths()`` and returns the
    result unchanged.

    See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`.

    Args:
        treespec (PyTreeSpec): The treespec whose leaf paths are requested.

    Returns:
        A list of paths, where each path is a tuple of path entries.
    """
    return treespec.paths()


def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]:
    """Return a list of treespecs for the children of a treespec.

    This is a thin functional wrapper: it simply calls ``treespec.children()`` and returns the
    result unchanged.

    Args:
        treespec (PyTreeSpec): The treespec whose children are requested.

    Returns:
        A list of :class:`PyTreeSpec` objects, one per child.
    """
    return treespec.children()
Expand Down
1 change: 1 addition & 0 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ void BuildModule(py::module& mod) { // NOLINT[runtime/references]
"Test whether this treespec is a suffix of the given treespec.",
py::arg("other"),
py::arg("strict") = true)
.def("paths", &PyTreeSpec::Paths, "Return a list of paths to the leaves of the treespec.")
.def("children", &PyTreeSpec::Children, "Return a list of treespecs for the children.")
.def("is_leaf",
&PyTreeSpec::IsLeaf,
Expand Down
6 changes: 3 additions & 3 deletions src/treespec/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
node.kind = GetKind<NoneIsLeaf>(handle, &node.custom, registry_namespace);
// NOLINTNEXTLINE[misc-no-recursion]
auto recurse = [this, &found_custom, &leaf_predicate, &registry_namespace, &leaves, &depth](
py::handle child) {
const py::handle& child) {
found_custom |= FlattenIntoImpl<NoneIsLeaf>(
child, leaves, depth + 1, leaf_predicate, registry_namespace);
};
Expand Down Expand Up @@ -220,7 +220,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle,
&leaves,
&paths,
&stack,
&depth](py::handle child, py::handle entry) {
&depth](const py::handle& child, const py::handle& entry) {
stack.emplace_back(entry);
found_custom |= FlattenIntoWithPathImpl<NoneIsLeaf>(
child, leaves, paths, stack, depth + 1, leaf_predicate, registry_namespace);
Expand Down Expand Up @@ -422,7 +422,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {

case PyTreeKind::Leaf:
EXPECT_GE(leaf, 0, "Leaf count mismatch.");
leaves[leaf] = py::reinterpret_borrow<py::object>(object);
SET_ITEM<py::list>(leaves, leaf, object);
--leaf;
break;

Expand Down
98 changes: 94 additions & 4 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,102 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Compose(const PyTreeSpec& inner_treespec
return treespec;
}

std::vector<std::unique_ptr<PyTreeSpec>> PyTreeSpec::Children() const {
auto children = std::vector<std::unique_ptr<PyTreeSpec>>{};
if (m_traversal.empty()) [[likely]] {
return children;
// Recursive worker behind `PyTreeSpec::Paths()`.
//
// Visits the subtree whose root node is stored at index `pos` of `m_traversal` and appends one
// path tuple to `paths` for every `Leaf` node found in it. `stack` holds the path entries
// leading to the current node and `depth` is the number of entries that belong to that path.
//
// Children are visited from the last one to the first one, walking `m_traversal` backwards, so
// the paths are appended in reverse leaf order.
//
// Returns `num_nodes` of the visited root, i.e. how many traversal slots the subtree spans; the
// caller subtracts it from its cursor to step over the subtree.
template <typename Span, typename Stack>
ssize_t PyTreeSpec::PathsImpl(Span& paths,
                              Stack& stack,
                              const ssize_t& pos,
                              const ssize_t& depth) const {
    const Node& node = m_traversal.at(pos);
    EXPECT_GE(pos + 1, node.num_nodes, "PyTreeSpec::Paths() walked off start of array.");

    // Traversal index of the child subtree to visit next; the last child's root sits right
    // before its parent.
    ssize_t child_pos = pos - 1;

    // Push `entry`, collect the paths of the subtree rooted at `subtree_pos`, pop the entry
    // again, and report how many nodes that subtree spans.
    const auto visit_child = [this, &paths, &stack, &depth](const ssize_t& subtree_pos,
                                                            const py::handle& entry) {
        stack.emplace_back(entry);
        const ssize_t subtree_size = PathsImpl(paths, stack, subtree_pos, depth + 1);
        stack.pop_back();
        return subtree_size;
    };

    // Visit every child using its position among the siblings as the path entry.
    const auto visit_children_by_index = [&node, &child_pos, &visit_child]() {
        for (ssize_t index = node.arity - 1; index >= 0; --index) {
            child_pos -= visit_child(child_pos, py::int_(index));
        }
    };

    switch (node.kind) {
        case PyTreeKind::None:
            // Contributes no path.
            break;

        case PyTreeKind::Leaf: {
            // Snapshot the first `depth` stack entries into a fresh tuple.
            py::tuple path{depth};
            for (ssize_t level = 0; level < depth; ++level) {
                SET_ITEM<py::tuple>(path, level, stack[level]);
            }
            paths.emplace_back(std::move(path));
            break;
        }

        case PyTreeKind::Tuple:
        case PyTreeKind::List:
        case PyTreeKind::NamedTuple:
        case PyTreeKind::Deque:
        case PyTreeKind::StructSequence: {
            visit_children_by_index();
            break;
        }

        case PyTreeKind::Dict:
        case PyTreeKind::OrderedDict:
        case PyTreeKind::DefaultDict: {
            // The key list is `node_data` itself, except for DefaultDict where it is item 1 of
            // the `node_data` tuple.
            py::list keys;
            if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
                keys = GET_ITEM_BORROW<py::tuple>(node.node_data, 1);
            } else [[likely]] {
                keys = py::list(node.node_data);
            }
            for (ssize_t index = node.arity - 1; index >= 0; --index) {
                child_pos -= visit_child(child_pos, GET_ITEM_HANDLE<py::list>(keys, index));
            }
            break;
        }

        case PyTreeKind::Custom: {
            if (!node.node_entries) [[unlikely]] {
                // No explicit entries recorded for this node: fall back to child indices.
                visit_children_by_index();
            } else [[likely]] {
                for (ssize_t index = node.arity - 1; index >= 0; --index) {
                    child_pos -= visit_child(
                        child_pos, GET_ITEM_HANDLE<py::tuple>(node.node_entries, index));
                }
            }
            break;
        }

        default:
            INTERNAL_ERROR();
    }
    return node.num_nodes;
}

std::vector<py::tuple> PyTreeSpec::Paths() const {
auto paths = std::vector<py::tuple>{};
const ssize_t num_leaves = GetNumLeaves();
if (num_leaves == 0) [[unlikely]] {
return paths;
}
const ssize_t num_nodes = GetNumNodes();
if (num_nodes == 1 && num_leaves == 1) [[likely]] {
paths.emplace_back();
return paths;
}
auto stack = std::vector<py::handle>{};
const ssize_t num_nodes_walked = PathsImpl(paths, stack, num_nodes - 1, 0);
std::reverse(paths.begin(), paths.end());
EXPECT_EQ(num_nodes_walked, num_nodes, "`pos != 0` at end of PyTreeSpec::Paths().");
EXPECT_EQ(py::ssize_t_cast(paths.size()), num_leaves, "PyTreeSpec::Paths() mismatched leaves.");
return paths;
}

std::vector<std::unique_ptr<PyTreeSpec>> PyTreeSpec::Children() const {
EXPECT_FALSE(m_traversal.empty(), "The tree node traversal is empty.");
const Node& root = m_traversal.back();
auto children = std::vector<std::unique_ptr<PyTreeSpec>>{};
children.resize(root.arity);
ssize_t pos = py::ssize_t_cast(m_traversal.size()) - 1;
for (ssize_t i = root.arity - 1; i >= 0; --i) {
Expand Down
3 changes: 3 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,13 @@ def test_paths(data):
tree, expected_paths, none_is_leaf = data
expected_leaves, expected_treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf)
paths, leaves, treespec = optree.tree_flatten_with_path(tree, none_is_leaf=none_is_leaf)
treespec_paths = optree.treespec_paths(treespec)
assert len(paths) == len(leaves)
assert leaves == expected_leaves
assert treespec == expected_treespec
assert paths == expected_paths
assert len(treespec_paths) == len(leaves)
assert treespec_paths == expected_paths
paths = optree.tree_paths(tree, none_is_leaf=none_is_leaf)
assert paths == expected_paths

Expand Down
1 change: 1 addition & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ class MyAnotherDict(MyDict):
paths, leaves, treespec = optree.tree_flatten_with_path(tree, namespace='mydict')
assert paths == [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)]
assert leaves == [6, 5, 4, 2, 3]
assert paths == treespec.paths()
assert (
str(treespec)
== "PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]), namespace='mydict')"
Expand Down
Loading

0 comments on commit 0676540

Please sign in to comment.