Skip to content

Commit

Permalink
feat(registry): raise a warning when registering subclasses of namedt…
Browse files Browse the repository at this point in the history
…uple
  • Loading branch information
XuehaiPan committed Jan 30, 2023
1 parent caea59f commit 157a12d
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/source/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Typing Support
PyTreeTypeVar
CustomTreeNode
is_namedtuple
is_namedtuple_class

.. autoclass:: PyTreeSpec
:members:
Expand All @@ -32,3 +33,5 @@ Typing Support
:show-inheritance:

.. autofunction:: is_namedtuple

.. autofunction:: is_namedtuple_class
6 changes: 6 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ inline bool IsNamedTuple(const py::handle& object) {
return PyTuple_Check(object.ptr()) && PyObject_HasAttrString(object.ptr(), "_fields") == 1;
}

inline bool IsNamedTupleClass(const py::handle& type) {
// We can only identify namedtuples heuristically, here by the presence of a _fields attribute.
return PyObject_IsSubclass(type.ptr(), reinterpret_cast<PyObject*>(&PyTuple_Type)) == 1 &&
PyObject_HasAttrString(type.ptr(), "_fields") == 1;
}

inline void AssertExactNamedTuple(const py::handle& object) {
if (!IsNamedTuple(object)) [[unlikely]] {
throw std::invalid_argument(
Expand Down
6 changes: 6 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@
)
from optree.typing import (
CustomTreeNode,
FlattenFunc,
PyTree,
PyTreeDef,
PyTreeSpec,
PyTreeTypeVar,
UnflattenFunc,
is_namedtuple,
is_namedtuple_class,
)
from optree.version import __version__

Expand Down Expand Up @@ -105,7 +108,10 @@
'PyTree',
'PyTreeTypeVar',
'CustomTreeNode',
'FlattenFunc',
'UnflattenFunc',
'is_namedtuple',
'is_namedtuple_class',
]

MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH
Expand Down
12 changes: 6 additions & 6 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
)

import optree._C as _C
from optree.typing import KT, VT, Children, CustomTreeNode, DefaultDict, MetaData
from optree.typing import KT, VT, CustomTreeNode, DefaultDict, FlattenFunc
from optree.typing import OrderedDict as GenericOrderedDict
from optree.typing import PyTree, T
from optree.typing import PyTree, T, UnflattenFunc
from optree.utils import safe_zip, unzip2


Expand All @@ -51,8 +51,8 @@


class PyTreeNodeRegistryEntry(NamedTuple):
to_iterable: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]]
from_iterable: Callable[[MetaData, Children[T]], CustomTreeNode[T]]
to_iterable: FlattenFunc
from_iterable: UnflattenFunc


__GLOBAL_NAMESPACE: str = object() # type: ignore[assignment]
Expand All @@ -61,8 +61,8 @@ class PyTreeNodeRegistryEntry(NamedTuple):

def register_pytree_node(
cls: Type[CustomTreeNode[T]],
flatten_func: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]],
unflatten_func: Callable[[MetaData, Children[T]], CustomTreeNode[T]],
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
namespace: str,
) -> Type[CustomTreeNode[T]]:
"""Extend the set of types that are considered internal nodes in pytrees.
Expand Down
15 changes: 14 additions & 1 deletion optree/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@

from typing import (
Any,
Callable,
DefaultDict,
Deque,
Dict,
Expand Down Expand Up @@ -55,7 +56,10 @@
'CustomTreeNode',
'Children',
'MetaData',
'FlattenFunc',
'UnflattenFunc',
'is_namedtuple',
'is_namedtuple_class',
'T',
'S',
'U',
Expand Down Expand Up @@ -239,6 +243,15 @@ def __deepcopy__(self, memo):
return self


FlattenFunc = Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]]
UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]]


def is_namedtuple(obj: object) -> bool:
"""Return whether the object is a namedtuple."""
return isinstance(obj, tuple) and hasattr(obj, '_fields')


def is_namedtuple_class(cls: Type) -> bool:
"""Return whether the class is a subclass of namedtuple."""
return issubclass(cls, tuple) and hasattr(cls, '_fields')
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,6 @@ convention = "google"

[tool.doc8]
max-line-length = 500

[tool.pytest.ini_options]
filterwarnings = ["error"]
21 changes: 21 additions & 0 deletions src/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ template <bool NoneIsLeaf>
throw std::invalid_argument(absl::StrFormat(
"PyTree type %s is already registered in the global namespace.", py::repr(cls)));
}
if (IsNamedTupleClass(cls)) [[unlikely]] {
PyErr_WarnEx(
PyExc_UserWarning,
absl::StrFormat("PyTree type %s is a subclass of `collections.namedtuple`, "
"which is already registered in the global namespace. "
"Override it with custom flatten/unflatten functions.",
py::repr(cls))
.c_str(),
/*stack_level=*/2);
}
} else [[likely]] {
if (registry->m_registrations.find(cls) != registry->m_registrations.end()) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
Expand All @@ -87,6 +97,17 @@ template <bool NoneIsLeaf>
py::repr(cls),
py::repr(py::str(registry_namespace))));
}
if (IsNamedTupleClass(cls)) [[unlikely]] {
PyErr_WarnEx(PyExc_UserWarning,
absl::StrFormat(
"PyTree type %s is a subclass of `collections.namedtuple`, "
"which is already registered in the global namespace. "
"Override it with custom flatten/unflatten functions in namespace %s.",
py::repr(cls),
py::repr(py::str(registry_namespace)))
.c_str(),
/*stack_level=*/2);
}
}
}

Expand Down
77 changes: 74 additions & 3 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,8 @@

# pylint: disable=missing-function-docstring,invalid-name

from collections import UserDict, UserList
import re
from collections import UserDict, UserList, namedtuple

import pytest

Expand Down Expand Up @@ -174,7 +175,6 @@ def test_register_pytree_node_duplicate_builtin_namespace():
lambda _, l: l,
namespace=optree.registry.__GLOBAL_NAMESPACE,
)

with pytest.raises(
ValueError,
match=r"PyTree type <class 'list'> is already registered in the global namespace.",
Expand All @@ -187,6 +187,77 @@ def test_register_pytree_node_duplicate_builtin_namespace():
)


def test_register_pytree_node_namedtuple():
mytuple1 = namedtuple('mytuple1', ['a', 'b', 'c'])
with pytest.warns(
UserWarning,
match=re.escape(
r"PyTree type <class 'test_registry.mytuple1'> is a subclass of `collections.namedtuple`, "
r'which is already registered in the global namespace. '
r'Override it with custom flatten/unflatten functions.'
),
):
optree.register_pytree_node(
mytuple1,
lambda t: (reversed(t), None, None),
lambda _, t: mytuple1(*reversed(t)),
namespace=optree.registry.__GLOBAL_NAMESPACE,
)
with pytest.raises(
ValueError,
match=re.escape(
r"PyTree type <class 'test_registry.mytuple1'> is already registered in the global namespace."
),
):
optree.register_pytree_node(
mytuple1,
lambda t: (reversed(t), None, None),
lambda _, t: mytuple1(*reversed(t)),
namespace='mytuple',
)

tree1 = mytuple1(1, 2, 3)
leaves1, treespec1 = optree.tree_flatten(tree1)
assert leaves1 == [3, 2, 1]
assert str(treespec1) == 'PyTreeSpec(CustomTreeNode(mytuple1[None], [*, *, *]))'
assert tree1 == optree.tree_unflatten(treespec1, leaves1)

mytuple2 = namedtuple('mytuple2', ['a', 'b', 'c'])
with pytest.warns(
UserWarning,
match=re.escape(
r"PyTree type <class 'test_registry.mytuple2'> is a subclass of `collections.namedtuple`, "
r'which is already registered in the global namespace. '
r"Override it with custom flatten/unflatten functions in namespace 'mytuple'."
),
):
optree.register_pytree_node(
mytuple2,
lambda t: (reversed(t), None, None),
lambda _, t: mytuple2(*reversed(t)),
namespace='mytuple',
)

tree2 = mytuple2(1, 2, 3)
leaves2, treespec2 = optree.tree_flatten(tree2)
assert leaves2 == [1, 2, 3]
assert str(treespec2) == 'PyTreeSpec(mytuple2(a=*, b=*, c=*))'
assert tree2 == optree.tree_unflatten(treespec2, leaves2)

leaves2, treespec2 = optree.tree_flatten(tree2, namespace='undefined')
assert leaves2 == [1, 2, 3]
assert str(treespec2) == 'PyTreeSpec(mytuple2(a=*, b=*, c=*))'
assert tree2 == optree.tree_unflatten(treespec2, leaves2)

leaves2, treespec2 = optree.tree_flatten(tree2, namespace='mytuple')
assert leaves2 == [3, 2, 1]
assert (
str(treespec2)
== "PyTreeSpec(CustomTreeNode(mytuple2[None], [*, *, *]), namespace='mytuple')"
)
assert tree2 == optree.tree_unflatten(treespec2, leaves2)


def test_pytree_node_registry_get():
handler = optree.register_pytree_node.get(list)
assert handler is not None
Expand Down

0 comments on commit 157a12d

Please sign in to comment.