Skip to content

Commit

Permalink
feat(ops): namespacing support for custom node type registry
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 15, 2022
1 parent f2b31a4 commit 4e5116f
Show file tree
Hide file tree
Showing 15 changed files with 1,041 additions and 163 deletions.
3 changes: 3 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Unflattened
unflatten
unflattens
unflattened
unflattening
args
kwargs
functools
Expand All @@ -31,3 +32,5 @@ arity
PyTreeTypeVar
subclassing
inplace
namespace
namespaces
27 changes: 25 additions & 2 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ limitations under the License.
#include <pybind11/pybind11.h>

#include <memory>
#include <string>
#include <utility>

#include "include/utils.h"

Expand Down Expand Up @@ -62,11 +64,12 @@ class PyTreeTypeRegistry {
// PyTrees.
static void Register(const py::object &type,
const py::function &to_iterable,
const py::function &from_iterable);
const py::function &from_iterable,
const std::string &regnamespace = "");

// Finds the custom type registration for `type`. Returns nullptr if none exists.
template <bool NoneIsLeaf>
static const Registration *Lookup(const py::handle &type);
static const Registration *Lookup(const py::handle &type, const std::string &regnamespace);

private:
template <bool NoneIsLeaf>
Expand All @@ -90,7 +93,27 @@ class PyTreeTypeRegistry {
bool operator()(const py::object &a, const py::handle &b) const;
};

class NamedTypeHash {
public:
using is_transparent = void;
size_t operator()(const std::pair<std::string, py::object> &p) const;
size_t operator()(const std::pair<std::string, py::handle> &p) const;
};
class NamedTypeEq {
public:
using is_transparent = void;
bool operator()(const std::pair<std::string, py::object> &a,
const std::pair<std::string, py::object> &b) const;
bool operator()(const std::pair<std::string, py::object> &a,
const std::pair<std::string, py::handle> &b) const;
};

absl::flat_hash_map<py::object, std::unique_ptr<Registration>, TypeHash, TypeEq> registrations;
absl::flat_hash_map<std::pair<std::string, py::object>,
std::unique_ptr<Registration>,
NamedTypeHash,
NamedTypeEq>
namespaced_registrations;
};

} // namespace optree
62 changes: 41 additions & 21 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ limitations under the License.

namespace optree {

// The maximum depth of a pytree.
constexpr ssize_t MAX_RECURSION_DEPTH = 10000;

// A PyTreeSpec describes the tree structure of a PyTree. A PyTree is a tree of Python values, where
// the interior nodes are tuples, lists, dictionaries, or user-defined containers, and the leaves
// are other objects.
Expand All @@ -53,45 +56,53 @@ class PyTreeSpec {
static std::pair<std::vector<py::object>, std::unique_ptr<PyTreeSpec>> Flatten(
const py::handle &tree,
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
const bool &none_is_leaf = false,
const std::string &regnamespace = "");

// Recursive helper used to implement Flatten().
void FlattenInto(const py::handle &handle,
bool FlattenInto(const py::handle &handle,
std::vector<py::object> &leaves, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
void FlattenInto(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf,
const std::string &regnamespace);
bool FlattenInto(const py::handle &handle,
absl::InlinedVector<py::object, 2> &leaves, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf,
const std::string &regnamespace);

// Flattens a PyTree into a list of leaves with a list of paths and a PyTreeSpec.
// Returns references to the flattened objects, which might be temporary objects in the case of
// custom PyType handlers.
static std::tuple<std::vector<py::object>, std::vector<py::object>, std::unique_ptr<PyTreeSpec>>
FlattenWithPath(const py::handle &tree,
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
const bool &none_is_leaf = false,
const std::string &regnamespace = "");

// Recursive helper used to implement FlattenWithPath().
void FlattenIntoWithPath(const py::handle &handle,
bool FlattenIntoWithPath(const py::handle &handle,
std::vector<py::object> &leaves, // NOLINT
std::vector<py::object> &paths, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
void FlattenIntoWithPath(const py::handle &handle,
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf,
const std::string &regnamespace);
bool FlattenIntoWithPath(const py::handle &handle,
absl::InlinedVector<py::object, 2> &leaves, // NOLINT
absl::InlinedVector<py::object, 2> &paths, // NOLINT
const std::optional<py::function> &leaf_predicate = std::nullopt,
const bool &none_is_leaf = false);
const std::optional<py::function> &leaf_predicate,
const bool &none_is_leaf,
const std::string &regnamespace);

// Flattens a PyTree up to this PyTreeSpec. 'this' must be a tree prefix of the tree-structure
// of 'x'. For example, if we flatten a value [(1, (2, 3)), {"foo": 4}] with a PyTreeSpec [(*,
// *), *], the result is the list of leaves [1, (2, 3), {"foo": 4}].
py::list FlattenUpTo(const py::handle &full_tree) const;

// Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const py::iterable &iterable, const bool &none_is_leaf = false);
static bool AllLeaves(const py::iterable &iterable,
const bool &none_is_leaf = false,
const std::string &regnamespace = "");

// Returns an unflattened PyTree given an iterable of leaves and a PyTreeSpec.
py::object Unflatten(const py::iterable &leaves) const;
Expand Down Expand Up @@ -124,6 +135,8 @@ class PyTreeSpec {

bool get_none_is_leaf() const;

std::string get_regnamespace() const;

bool operator==(const PyTreeSpec &other) const;
bool operator!=(const PyTreeSpec &other) const;

Expand Down Expand Up @@ -191,31 +204,38 @@ class PyTreeSpec {
// Whether to treat `None` as a leaf. If false, `None` is a non-leaf node with arity 0.
bool none_is_leaf;

// The registry namespace used to resolve the custom pytree node types
std::string regnamespace;

// Helper that manufactures an instance of a node given its children.
static py::object MakeNode(const Node &node, const absl::Span<py::object> &children);

// Computes the node kind of a given Python object.
template <bool NoneIsLeaf>
static PyTreeKind GetKind(const py::handle &handle,
PyTreeTypeRegistry::Registration const **custom);
PyTreeTypeRegistry::Registration const **custom,
const std::string &regnamespace);

template <bool NoneIsLeaf, typename Span>
void FlattenIntoImpl(const py::handle &handle,
bool FlattenIntoImpl(const py::handle &handle,
Span &leaves, // NOLINT
const std::optional<py::function> &leaf_predicate);
const ssize_t &depth,
const std::optional<py::function> &leaf_predicate,
const std::string &regnamespace);

template <bool NoneIsLeaf, typename Span, typename Stack>
void FlattenIntoWithPathImpl(const py::handle &handle,
bool FlattenIntoWithPathImpl(const py::handle &handle,
Span &leaves, // NOLINT
Span &paths, // NOLINT
Stack &stack, // NOLINT
const ssize_t &depth,
const std::optional<py::function> &leaf_predicate);
const std::optional<py::function> &leaf_predicate,
const std::string &regnamespace);

py::list FlattenUpToImpl(const py::handle &full_tree) const;

template <bool NoneIsLeaf>
static bool AllLeavesImpl(const py::iterable &iterable);
static bool AllLeavesImpl(const py::iterable &iterable, const std::string &regnamespace);

template <typename Span>
py::object UnflattenImpl(const Span &leaves) const;
Expand Down
12 changes: 10 additions & 2 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,25 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Seque
if TYPE_CHECKING:
from optree.typing import Children, CustomTreeNode, MetaData, PyTree, T, U

version: int
MAX_RECURSION_DEPTH: int

def flatten(
tree: PyTree[T],
leaf_predicate: Optional[Callable[[T], bool]] = None,
node_is_leaf: bool = False,
namespace: str = '',
) -> Tuple[List[T], 'PyTreeSpec']: ...
def flatten_with_path(
tree: PyTree[T],
leaf_predicate: Optional[Callable[[T], bool]] = None,
node_is_leaf: bool = False,
namespace: str = '',
) -> Tuple[List[Tuple[Any, ...]], List[T], 'PyTreeSpec']: ...
def all_leaves(iterable: Iterable[T], node_is_leaf: bool = False) -> bool: ...
def all_leaves(
iterable: Iterable[T],
node_is_leaf: bool = False,
namespace: str = '',
) -> bool: ...
def leaf(node_is_leaf: bool = False) -> 'PyTreeSpec': ...
def none(node_is_leaf: bool = False) -> 'PyTreeSpec': ...
def tuple(treespecs: Sequence['PyTreeSpec'], node_is_leaf: bool = False) -> 'PyTreeSpec': ...
Expand All @@ -42,6 +48,7 @@ class PyTreeSpec:
num_nodes: int
num_leaves: int
none_is_leaf: bool
namespace: str
def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ...
def flatten_up_to(self, full_tree: PyTree[T]) -> List[PyTree[T]]: ...
def compose(self, inner_treespec: 'PyTreeSpec') -> 'PyTreeSpec': ...
Expand All @@ -61,4 +68,5 @@ def register_node(
cls: Type[CustomTreeNode[T]],
to_iterable: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]],
from_iterable: Callable[[MetaData, Children[T]], CustomTreeNode[T]],
namespace: str,
) -> None: ...
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""OpTree: Optimized PyTree Utilities."""

from optree.ops import (
MAX_RECURSION_DEPTH,
NONE_IS_LEAF,
NONE_IS_NODE,
all_leaves,
Expand Down Expand Up @@ -61,6 +62,7 @@
'PyTreeTypeVar',
'CustomTreeNode',
# Tree operations
'MAX_RECURSION_DEPTH',
'NONE_IS_NODE',
'NONE_IS_LEAF',
'tree_flatten',
Expand Down
Loading

0 comments on commit 4e5116f

Please sign in to comment.