diff --git a/README.md b/README.md index 155c9363..fe6350e0 100644 --- a/README.md +++ b/README.md @@ -156,8 +156,8 @@ optree.register_pytree_node( torch.Tensor, # (tensor) -> (children, metadata) flatten_func=lambda tensor: ( - (tensor.cpu().numpy(),), - dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad), + (tensor.cpu().detach().numpy(),), + {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad}, ), # (metadata, children) -> tensor unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata), diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 5cc396d2..00a7665b 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -34,10 +34,10 @@ Tree Manipulation Functions tree_map_ tree_map_with_path tree_map_with_path_ + tree_replace_nones tree_transpose tree_broadcast_prefix broadcast_prefix - tree_replace_nones prefix_errors .. autofunction:: tree_flatten @@ -51,10 +51,10 @@ Tree Manipulation Functions .. autofunction:: tree_map_ .. autofunction:: tree_map_with_path .. autofunction:: tree_map_with_path_ +.. autofunction:: tree_replace_nones .. autofunction:: tree_transpose .. autofunction:: tree_broadcast_prefix .. autofunction:: broadcast_prefix -.. autofunction:: tree_replace_nones .. autofunction:: prefix_errors ------ @@ -64,7 +64,6 @@ Tree Reduce Functions .. autosummary:: - tree_replace_nones tree_reduce tree_sum tree_max diff --git a/include/treespec.h b/include/treespec.h index 3a399f83..db11df74 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -249,8 +249,6 @@ class PyTreeSpec { const std::optional &leaf_predicate, const std::string ®istry_namespace); - [[nodiscard]] py::list FlattenUpToImpl(const py::handle &full_tree) const; - template static bool AllLeavesImpl(const py::iterable &iterable, const std::string ®istry_namespace); diff --git a/optree/__init__.py b/optree/__init__.py index 8144196e..443000b0 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -96,10 +96,10 @@ 'tree_map_', 'tree_map_with_path', 'tree_map_with_path_', + 'tree_replace_nones', 'tree_transpose', 'tree_broadcast_prefix', 'broadcast_prefix', - 'tree_replace_nones', 'tree_reduce', 'tree_sum', 'tree_max', diff --git a/optree/ops.py b/optree/ops.py index e88c5e75..c5755852 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -65,10 +65,10 @@ 'tree_map_', 'tree_map_with_path', 'tree_map_with_path_', + 'tree_replace_nones', 'tree_transpose', 'tree_broadcast_prefix', 'broadcast_prefix', - 'tree_replace_nones', 'tree_reduce', 'tree_sum', 'tree_max', @@ -525,7 +525,7 @@ def tree_map_with_path( none_is_leaf: bool = False, namespace: str = '', ) -> PyTree[U]: - """Map a multi-input function over pytree args to produce a new pytree. + """Map a multi-input function over pytree args as well as the tree paths to produce a new pytree. See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path_`. @@ -612,6 +612,35 @@ def tree_map_with_path_( return tree +def tree_replace_nones(sentinel: Any, tree: PyTree[T] | None, namespace: str = '') -> PyTree[T]: + """Replace :data:`None` in ``tree`` with ``sentinel``. + + See also :func:`tree_flatten` and :func:`tree_map`. + + >>> tree_replace_nones(0, {'a': 1, 'b': None, 'c': (2, None)}) + {'a': 1, 'b': 0, 'c': (2, 0)} + >>> tree_replace_nones(0, None) + 0 + + Args: + sentinel (object): The value to replace :data:`None` with. + tree (pytree): A pytree to be transformed. + namespace (str, optional): The registry namespace used for custom pytree node types. + (default: :const:`''`, i.e., the global namespace) + + Returns: + A new pytree with the same structure as ``tree`` but with :data:`None` replaced. + """ + if tree is None: + return sentinel + return tree_map( + lambda x: x if x is not None else sentinel, + tree, + none_is_leaf=True, + namespace=namespace, + ) + + def tree_transpose( outer_treespec: PyTreeSpec, inner_treespec: PyTreeSpec, @@ -708,24 +737,24 @@ def tree_broadcast_prefix( from ``prefix_tree``. The number of replicas is determined by the corresponding subtree in ``full_tree``. - >>> tree_broadcast_prefix(1, [1, 2, 3]) + >>> tree_broadcast_prefix(1, [2, 3, 4]) [1, 1, 1] - >>> tree_broadcast_prefix([1, 2, 3], [1, 2, 3]) + >>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6]) [1, 2, 3] - >>> tree_broadcast_prefix([1, 2, 3], [1, 2, 3, 4]) + >>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7]) Traceback (most recent call last): ... - ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4]. - >>> tree_broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) + ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7]. + >>> tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) [1, 2, (3, 3)] - >>> tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) + >>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}]) [1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}] - >>> tree_broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True) + >>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True) [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}] Args: - prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``. - full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``. + prefix_tree (pytree): A pytree with the prefix structure of ``full_tree``. + full_tree (pytree): A pytree with the suffix structure of ``prefix_tree``. is_leaf (callable, optional): An optionally specified function that will be called at each flattening step. It should return a boolean, with :data:`True` stopping the traversal and the whole subtree being treated as a leaf, and :data:`False` indicating the @@ -782,24 +811,24 @@ def broadcast_prefix( replicated from ``prefix_tree``. The number of replicas is determined by the corresponding subtree in ``full_tree``. - >>> broadcast_prefix(1, [1, 2, 3]) + >>> broadcast_prefix(1, [2, 3, 4]) [1, 1, 1] - >>> broadcast_prefix([1, 2, 3], [1, 2, 3]) + >>> broadcast_prefix([1, 2, 3], [4, 5, 6]) [1, 2, 3] - >>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4]) + >>> broadcast_prefix([1, 2, 3], [4, 5, 6, 7]) Traceback (most recent call last): ... - ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4]. - >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) + ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7]. + >>> broadcast_prefix([1, 2, 3], [4, 5, (6, 7)]) [1, 2, 3, 3] - >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}]) + >>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}]) [1, 2, 3, 3, 3] - >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=True) + >>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True) [1, 2, 3, 3, 3, 3] Args: - prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``. - full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``. + prefix_tree (pytree): A pytree with the prefix structure of ``full_tree``. + full_tree (pytree): A pytree with the suffix structure of ``prefix_tree``. is_leaf (callable, optional): An optionally specified function that will be called at each flattening step. It should return a boolean, with :data:`True` stopping the traversal and the whole subtree being treated as a leaf, and :data:`False` indicating the @@ -839,36 +868,13 @@ def add_leaves(x: T, subtree: PyTree[S]) -> None: return result -def tree_replace_nones(sentinel: Any, tree: PyTree[T] | None, namespace: str = '') -> PyTree[T]: - """Replace :data:`None` in ``tree`` with ``sentinel``. - - See also :func:`tree_flatten` and :func:`tree_map`. - - >>> tree_replace_nones(0, {'a': 1, 'b': None, 'c': (2, None)}) - {'a': 1, 'b': 0, 'c': (2, 0)} - >>> tree_replace_nones(0, None) - 0 - - Args: - sentinel (object): The value to replace :data:`None` with. - tree (pytree): A pytree to be transformed. - namespace (str, optional): The registry namespace used for custom pytree node types. - (default: :const:`''`, i.e., the global namespace) - - Returns: - A new pytree with the same structure as ``tree`` but with :data:`None` replaced. - """ - if tree is None: - return sentinel - return tree_map( - lambda x: x if x is not None else sentinel, - tree, - none_is_leaf=True, - namespace=namespace, - ) +class MissingSentinel: # pylint: disable=missing-class-docstring,too-few-public-methods + def __repr__(self) -> str: + return '' -__MISSING: T = object() # type: ignore[valid-type] +__MISSING: T = MissingSentinel() # type: ignore[valid-type] +del MissingSentinel @overload @@ -1523,7 +1529,7 @@ def flatten_one_level( node_type = type(tree) handler = register_pytree_node.get(node_type, namespace=namespace) # type: ignore[attr-defined] if handler: - flattened = handler.to_iterable(tree) + flattened = tuple(handler.to_iterable(tree)) if len(flattened) == 2: flattened = (*flattened, None) elif len(flattened) != 3: diff --git a/optree/registry.py b/optree/registry.py index 595e62eb..88578416 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -44,8 +44,14 @@ class PyTreeNodeRegistryEntry(NamedTuple): from_iterable: UnflattenFunc -__GLOBAL_NAMESPACE: str = object() # type: ignore[assignment] +class GlobalNamespace: # pylint: disable=missing-class-docstring,too-few-public-methods + def __repr__(self) -> str: + return '' + + +__GLOBAL_NAMESPACE: str = GlobalNamespace() # type: ignore[assignment] __REGISTRY_LOCK: Lock = Lock() +del GlobalNamespace def register_pytree_node( @@ -99,8 +105,8 @@ def register_pytree_node( >>> register_pytree_node( ... torch.Tensor, ... flatten_func=lambda tensor: ( - ... (tensor.cpu().numpy(),), - ... dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad), + ... (tensor.cpu().detach().numpy(),), + ... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad}, ... ), ... unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata), ... namespace='torch2numpy', diff --git a/optree/typing.py b/optree/typing.py index 9ef91a25..a7746324 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -259,7 +259,13 @@ def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias: return self -FlattenFunc = Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]] +FlattenFunc = Callable[ + [CustomTreeNode[T]], + Union[ + Tuple[Children[T], MetaData], + Tuple[Children[T], MetaData, Optional[Iterable[Any]]], + ], +] UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]] diff --git a/pyproject.toml b/pyproject.toml index a5b4d847..ed473194 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,7 @@ test-command = """make -C "{project}" test PYTHON=python""" safe = true line-length = 100 skip-string-normalization = true -target-version = ["py37", "py38", "py39", "py310", "py311", "py312"] +target-version = ["py37"] [tool.isort] atomic = true diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 07bd4b76..05121c9a 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -400,7 +400,7 @@ PyTreeSpec::FlattenWithPath(const py::handle& tree, } // NOLINTNEXTLINE[readability-function-cognitive-complexity] -py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { +py::list PyTreeSpec::FlattenUpTo(const py::handle& full_tree) const { const ssize_t num_leaves = GetNumLeaves(); auto agenda = reserved_vector(4); @@ -629,10 +629,6 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const { return leaves; } -py::list PyTreeSpec::FlattenUpTo(const py::handle& full_tree) const { - return FlattenUpToImpl(full_tree); -} - template /*static*/ bool PyTreeSpec::AllLeavesImpl(const py::iterable& iterable, const std::string& registry_namespace) { diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index a337dcc1..95c22373 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -480,7 +480,7 @@ std::unique_ptr PyTreeSpec::Child(ssize_t index) const { EXPECT_EQ(py::ssize_t_cast(num_children), node.arity, "Node arity did not match."); switch (node.kind) { case PyTreeKind::Leaf: - INTERNAL_ERROR("MakeNode not implemented for leaves."); + INTERNAL_ERROR("PyTreeSpec::MakeNode() not implemented for leaves."); case PyTreeKind::None: return py::none(); diff --git a/src/treespec/unflatten.cpp b/src/treespec/unflatten.cpp index e9813569..018fcaa5 100644 --- a/src/treespec/unflatten.cpp +++ b/src/treespec/unflatten.cpp @@ -29,23 +29,20 @@ py::object PyTreeSpec::UnflattenImpl(const Span& leaves) const { py::ssize_t_cast(agenda.size()), node.arity, "Too few elements for PyTreeSpec node."); switch (node.kind) { - case PyTreeKind::None: case PyTreeKind::Leaf: { - if (node.kind == PyTreeKind::Leaf || m_none_is_leaf) [[likely]] { - if (it == leaves.end()) [[unlikely]] { - std::ostringstream oss{}; - oss << "Too few leaves for PyTreeSpec; expected: " << GetNumLeaves() - << ", got: " << leaf_count << "."; - throw py::value_error(oss.str()); - } - agenda.emplace_back(py::reinterpret_borrow(*it)); - ++it; - ++leaf_count; - break; + if (it == leaves.end()) [[unlikely]] { + std::ostringstream oss{}; + oss << "Too few leaves for PyTreeSpec; expected: " << GetNumLeaves() + << ", got: " << leaf_count << "."; + throw py::value_error(oss.str()); } - [[fallthrough]]; + agenda.emplace_back(py::reinterpret_borrow(*it)); + ++it; + ++leaf_count; + break; } + case PyTreeKind::None: case PyTreeKind::Tuple: case PyTreeKind::List: case PyTreeKind::Dict: diff --git a/tests/test_registry.py b/tests/test_registry.py index 89b2dc23..3de6f632 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -340,12 +340,12 @@ def test_pytree_node_registry_get(): handler = optree.register_pytree_node.get(list) assert handler is not None lst = [1, 2, 3] - assert handler.to_iterable(lst)[:2] == (lst, None) + assert tuple(handler.to_iterable(lst))[:2] == (lst, None) handler = optree.register_pytree_node.get(list, namespace='any') assert handler is not None lst = [1, 2, 3] - assert handler.to_iterable(lst)[:2] == (lst, None) + assert tuple(handler.to_iterable(lst))[:2] == (lst, None) handler = optree.register_pytree_node.get(set) assert handler is None