diff --git a/docs/source/ops.rst b/docs/source/ops.rst index d85eacb2..b3c5dc2d 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -88,6 +88,7 @@ PyTreeSpec Functions treespec_is_prefix treespec_is_suffix + treespec_paths treespec_children treespec_is_leaf treespec_is_strict_leaf @@ -97,6 +98,7 @@ PyTreeSpec Functions .. autofunction:: treespec_is_prefix .. autofunction:: treespec_is_suffix +.. autofunction:: treespec_paths .. autofunction:: treespec_children .. autofunction:: treespec_is_leaf .. autofunction:: treespec_is_strict_leaf diff --git a/include/treespec.h b/include/treespec.h index eb9fcd4d..6eeb1a6b 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -52,8 +52,8 @@ class PyTreeSpec { public: PyTreeSpec() = default; - // Flattens a PyTree into a list of leaves and a PyTreeSpec. - // Returns references to the flattened objects, which might be temporary objects in the case of + // Flatten a PyTree into a list of leaves and a PyTreeSpec. + // Return references to the flattened objects, which might be temporary objects in the case of // custom PyType handlers. static std::pair, std::unique_ptr> Flatten( const py::handle &tree, @@ -68,8 +68,8 @@ class PyTreeSpec { const bool &none_is_leaf, const std::string ®istry_namespace); - // Flattens a PyTree into a list of leaves with a list of paths and a PyTreeSpec. - // Returns references to the flattened objects, which might be temporary objects in the case of + // Flatten a PyTree into a list of leaves with a list of paths and a PyTreeSpec. + // Return references to the flattened objects, which might be temporary objects in the case of // custom PyType handlers. static std::tuple, std::vector, std::unique_ptr> FlattenWithPath(const py::handle &tree, @@ -85,7 +85,7 @@ class PyTreeSpec { const bool &none_is_leaf, const std::string ®istry_namespace); - // Flattens a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure + // Flatten a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure // of 'x'. For example, if we flatten a value [(1, (2, 3)), {"foo": 4}] with a PyTreeSpec [(*, // *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}]. [[nodiscard]] py::list FlattenUpTo(const py::handle &full_tree) const; @@ -95,38 +95,43 @@ class PyTreeSpec { const bool &none_is_leaf = false, const std::string ®istry_namespace = ""); - // Returns an unflattened PyTree given an iterable of leaves and a PyTreeSpec. + // Return an unflattened PyTree given an iterable of leaves and a PyTreeSpec. [[nodiscard]] py::object Unflatten(const py::iterable &leaves) const; - // Composes two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`. + // Compose two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`. [[nodiscard]] std::unique_ptr Compose(const PyTreeSpec &inner_treespec) const; - // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // Map a function over a PyTree structure, applying f_leaf to each leaf, and // f_node(children, node_data) to each container node. [[nodiscard]] py::object Walk(const py::function &f_node, const py::handle &f_leaf, const py::iterable &leaves) const; - // Returns true if this PyTreeSpec is a prefix of `other`. + // Return true if this PyTreeSpec is a prefix of `other`. [[nodiscard]] bool IsPrefix(const PyTreeSpec &other, const bool &strict = false) const; - // Returns true if this PyTreeSpec is a suffix of `other`. + // Return true if this PyTreeSpec is a suffix of `other`. [[nodiscard]] inline bool IsSuffix(const PyTreeSpec &other, const bool &strict = false) const { return other.IsPrefix(*this, strict); } + // Return paths to all leaves in the PyTreeSpec. + [[nodiscard]] std::vector Paths() const; + + // Return the children of the PyTreeSpec. [[nodiscard]] std::vector> Children() const; + // Test whether this PyTreeSpec represents a leaf. [[nodiscard]] bool IsLeaf(const bool &strict = true) const; - // Makes a Tuple PyTreeSpec out of a vector of PyTreeSpecs. + // Make a Tuple PyTreeSpec out of a vector of PyTreeSpecs. static std::unique_ptr Tuple(const std::vector &treespecs, const bool &none_is_leaf); - // Makes a PyTreeSpec representing a leaf node. + // Make a PyTreeSpec representing a leaf node. static std::unique_ptr Leaf(const bool &none_is_leaf); - // Makes a PyTreeSpec representing a `None` node. + // Make a PyTreeSpec representing a `None` node. static std::unique_ptr None(const bool &none_is_leaf); [[nodiscard]] ssize_t GetNumLeaves() const; @@ -206,11 +211,11 @@ class PyTreeSpec { [[nodiscard]] std::string ToString() const; - // Transforms the PyTreeSpec into a picklable object. + // Transform the PyTreeSpec into a picklable object. // Used to implement `PyTreeSpec.__getstate__`. [[nodiscard]] py::object ToPicklable() const; - // Transforms the object returned by `ToPicklable()` back to PyTreeSpec. + // Transform the object returned by `ToPicklable()` back to PyTreeSpec. // Used to implement `PyTreeSpec.__setstate__`. static PyTreeSpec FromPicklable(const py::object &picklable); @@ -260,7 +265,7 @@ class PyTreeSpec { // Helper that manufactures an instance of a node given its children. static py::object MakeNode(const Node &node, const absl::Span &children); - // Computes the node kind of a given Python object. + // Compute the node kind of a given Python object. template static PyTreeKind GetKind(const py::handle &handle, PyTreeTypeRegistry::Registration const **custom, @@ -290,6 +295,12 @@ class PyTreeSpec { template py::object UnflattenImpl(const Span &leaves) const; + template + [[nodiscard]] ssize_t PathsImpl(Span &paths, // NOLINT[runtime/references] + Stack &stack, // NOLINT[runtime/references] + const ssize_t &pos, + const ssize_t &depth) const; + static PyTreeSpec FromPicklableImpl(const py::object &picklable); }; diff --git a/optree/_C.pyi b/optree/_C.pyi index 55f06a9c..abb4c546 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -88,6 +88,7 @@ class PyTreeSpec: other: PyTreeSpec, strict: bool = ..., # False ) -> bool: ... + def paths(self) -> list[builtins.tuple[Any, ...]]: ... def children(self) -> list[PyTreeSpec]: ... def is_leaf( self, diff --git a/optree/__init__.py b/optree/__init__.py index f5b2a978..cf614029 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -48,6 +48,7 @@ treespec_is_suffix, treespec_leaf, treespec_none, + treespec_paths, treespec_tuple, ) from optree.registry import ( @@ -105,6 +106,7 @@ 'prefix_errors', 'treespec_is_prefix', 'treespec_is_suffix', + 'treespec_paths', 'treespec_children', 'treespec_is_leaf', 'treespec_is_strict_leaf', diff --git a/optree/ops.py b/optree/ops.py index bc0f244b..28c206fc 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -77,6 +77,7 @@ 'tree_any', 'treespec_is_prefix', 'treespec_is_suffix', + 'treespec_paths', 'treespec_children', 'treespec_is_leaf', 'treespec_is_strict_leaf', @@ -159,7 +160,7 @@ def tree_flatten_with_path( ) -> tuple[list[tuple[Any, ...]], list[T], PyTreeSpec]: """Flatten a pytree and additionally record the paths. - See also :func:`tree_flatten` and :func:`tree_paths`. + See also :func:`tree_flatten`, :func:`tree_paths`, and :func:`treespec_paths`. The flattening order (i.e., the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal. @@ -328,7 +329,7 @@ def tree_paths( ) -> list[tuple[Any, ...]]: """Get the path entries to the leaves of a pytree. - See also :func:`tree_flatten` and :func:`tree_flatten_with_path`. + See also :func:`tree_flatten`, :func:`tree_flatten_with_path`, and :func:`treespec_paths`. >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_paths(tree) @@ -1286,6 +1287,14 @@ def treespec_is_suffix( return treespec.is_suffix(other_treespec, strict=strict) +def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]: + """Return a list of paths to the leaves of a treespec. + + See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`. + """ + return treespec.paths() + + def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]: """Return a list of treespecs for the children of a treespec.""" return treespec.children() diff --git a/src/optree.cpp b/src/optree.cpp index 8769e075..f269ae27 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -158,6 +158,7 @@ void BuildModule(py::module& mod) { // NOLINT[runtime/references] "Test whether this treespec is a suffix of the given treespec.", py::arg("other"), py::arg("strict") = true) + .def("paths", &PyTreeSpec::Paths, "Return a list of paths to the leaves of the treespec.") .def("children", &PyTreeSpec::Children, "Return a list of treespecs for the children.") .def("is_leaf", &PyTreeSpec::IsLeaf, diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index e143e4b7..9ce49174 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -42,7 +42,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, node.kind = GetKind(handle, &node.custom, registry_namespace); // NOLINTNEXTLINE[misc-no-recursion] auto recurse = [this, &found_custom, &leaf_predicate, ®istry_namespace, &leaves, &depth]( - py::handle child) { + const py::handle& child) { found_custom |= FlattenIntoImpl( child, leaves, depth + 1, leaf_predicate, registry_namespace); }; @@ -220,7 +220,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, &leaves, &paths, &stack, - &depth](py::handle child, py::handle entry) { + &depth](const py::handle& child, const py::handle& entry) { stack.emplace_back(entry); found_custom |= FlattenIntoWithPathImpl( child, leaves, paths, stack, depth + 1, leaf_predicate, registry_namespace); @@ -422,7 +422,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { case PyTreeKind::Leaf: EXPECT_GE(leaf, 0, "Leaf count mismatch."); - leaves[leaf] = py::reinterpret_borrow(object); + SET_ITEM(leaves, leaf, object); --leaf; break; diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index 82bb7e04..5a164de9 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -193,12 +193,102 @@ std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec return treespec; } -std::vector> PyTreeSpec::Children() const { - auto children = std::vector>{}; - if (m_traversal.empty()) [[likely]] { - return children; +template +ssize_t PyTreeSpec::PathsImpl(Span& paths, + Stack& stack, + const ssize_t& pos, + const ssize_t& depth) const { + const Node& root = m_traversal.at(pos); + EXPECT_GE(pos + 1, root.num_nodes, "PyTreeSpec::Paths() walked off start of array."); + + ssize_t cur = pos - 1; + auto recurse = [this, &paths, &stack, &depth](const ssize_t& cur, const py::handle& entry) { + stack.emplace_back(entry); + const ssize_t num_nodes = PathsImpl(paths, stack, cur, depth + 1); + stack.pop_back(); + return num_nodes; + }; + + switch (root.kind) { + case PyTreeKind::None: + break; + case PyTreeKind::Leaf: { + py::tuple path{depth}; + for (ssize_t d = 0; d < depth; ++d) { + SET_ITEM(path, d, stack[d]); + } + paths.emplace_back(std::move(path)); + break; + } + + case PyTreeKind::Tuple: + case PyTreeKind::List: + case PyTreeKind::NamedTuple: + case PyTreeKind::Deque: + case PyTreeKind::StructSequence: { + for (ssize_t i = root.arity - 1; i >= 0; --i) { + cur -= recurse(cur, py::int_(i)); + } + break; + } + + case PyTreeKind::Dict: + case PyTreeKind::OrderedDict: + case PyTreeKind::DefaultDict: { + py::list keys; + if (root.kind != PyTreeKind::DefaultDict) [[likely]] { + keys = py::list(root.node_data); + } else [[unlikely]] { + keys = GET_ITEM_BORROW(root.node_data, 1); + } + for (ssize_t i = root.arity - 1; i >= 0; --i) { + cur -= recurse(cur, GET_ITEM_HANDLE(keys, i)); + } + break; + } + + case PyTreeKind::Custom: { + if (root.node_entries) [[likely]] { + for (ssize_t i = root.arity - 1; i >= 0; --i) { + cur -= recurse(cur, GET_ITEM_HANDLE(root.node_entries, i)); + } + } else [[unlikely]] { + for (ssize_t i = root.arity - 1; i >= 0; --i) { + cur -= recurse(cur, py::int_(i)); + } + } + break; + } + + default: + INTERNAL_ERROR(); } + return root.num_nodes; +} + +std::vector PyTreeSpec::Paths() const { + auto paths = std::vector{}; + const ssize_t num_leaves = GetNumLeaves(); + if (num_leaves == 0) [[unlikely]] { + return paths; + } + const ssize_t num_nodes = GetNumNodes(); + if (num_nodes == 1 && num_leaves == 1) [[likely]] { + paths.emplace_back(); + return paths; + } + auto stack = std::vector{}; + const ssize_t num_nodes_walked = PathsImpl(paths, stack, num_nodes - 1, 0); + std::reverse(paths.begin(), paths.end()); + EXPECT_EQ(num_nodes_walked, num_nodes, "`pos != 0` at end of PyTreeSpec::Paths()."); + EXPECT_EQ(py::ssize_t_cast(paths.size()), num_leaves, "PyTreeSpec::Paths() mismatched leaves."); + return paths; +} + +std::vector> PyTreeSpec::Children() const { + EXPECT_FALSE(m_traversal.empty(), "The tree node traversal is empty."); const Node& root = m_traversal.back(); + auto children = std::vector>{}; children.resize(root.arity); ssize_t pos = py::ssize_t_cast(m_traversal.size()) - 1; for (ssize_t i = root.arity - 1; i >= 0; --i) { diff --git a/tests/test_ops.py b/tests/test_ops.py index d9c47378..85ff22d9 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -295,10 +295,13 @@ def test_paths(data): tree, expected_paths, none_is_leaf = data expected_leaves, expected_treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) paths, leaves, treespec = optree.tree_flatten_with_path(tree, none_is_leaf=none_is_leaf) + treespec_paths = optree.treespec_paths(treespec) assert len(paths) == len(leaves) assert leaves == expected_leaves assert treespec == expected_treespec assert paths == expected_paths + assert len(treespec_paths) == len(leaves) + assert treespec_paths == expected_paths paths = optree.tree_paths(tree, none_is_leaf=none_is_leaf) assert paths == expected_paths diff --git a/tests/test_registry.py b/tests/test_registry.py index ac0056f9..787ca887 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -388,6 +388,7 @@ class MyAnotherDict(MyDict): paths, leaves, treespec = optree.tree_flatten_with_path(tree, namespace='mydict') assert paths == [('c', 'f'), ('c', 'd'), ('b',), ('a', 0), ('a', 1)] assert leaves == [6, 5, 4, 2, 3] + assert paths == treespec.paths() assert ( str(treespec) == "PyTreeSpec(CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyAnotherDict[['f', 'd']], [*, *]), *, (*, *)]), namespace='mydict')" diff --git a/tests/test_treespec.py b/tests/test_treespec.py index cd39085e..846425ca 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -125,6 +125,7 @@ def test_with_namespace(): ) assert paths == [()] assert leaves == [tree] + assert paths == treespec.paths() assert str(treespec) == ('PyTreeSpec(*)') for namespace in ('', 'undefined'): leaves, treespec = optree.tree_flatten(tree, none_is_leaf=True, namespace=namespace) @@ -135,6 +136,7 @@ def test_with_namespace(): ) assert paths == [()] assert leaves == [tree] + assert paths == treespec.paths() assert str(treespec) == ('PyTreeSpec(*, NoneIsLeaf)') expected_string = "PyTreeSpec(CustomTreeNode(MyAnotherDict[['foo', 'baz']], [CustomTreeNode(MyAnotherDict[['c', 'b', 'a']], [None, *, *]), *]), namespace='namespace')" @@ -146,6 +148,7 @@ def test_with_namespace(): ) assert paths == [('foo', 'b'), ('foo', 'a'), ('baz',)] assert leaves == [2, 1, 101] + assert paths == treespec.paths() assert str(treespec) == expected_string expected_string = "PyTreeSpec(CustomTreeNode(MyAnotherDict[['foo', 'baz']], [CustomTreeNode(MyAnotherDict[['c', 'b', 'a']], [*, *, *]), *]), NoneIsLeaf, namespace='namespace')" @@ -157,6 +160,7 @@ def test_with_namespace(): ) assert paths == [('foo', 'c'), ('foo', 'b'), ('foo', 'a'), ('baz',)] assert leaves == [None, 2, 1, 101] + assert paths == treespec.paths() assert str(treespec) == expected_string @@ -328,6 +332,7 @@ def test_treespec_num_leaves(tree, none_is_leaf): leaves, treespec = optree.tree_flatten(tree, none_is_leaf=none_is_leaf) assert treespec.num_leaves == len(leaves) assert treespec.num_leaves == len(treespec) + assert treespec.num_leaves == len(treespec.paths()) @parametrize(