Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(registry): raise a warning when registering subclasses of namedtuple #24

Merged
merged 1 commit into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Typing Support
PyTreeTypeVar
CustomTreeNode
is_namedtuple
is_namedtuple_class

.. autoclass:: PyTreeSpec
:members:
Expand All @@ -32,3 +33,5 @@ Typing Support
:show-inheritance:

.. autofunction:: is_namedtuple

.. autofunction:: is_namedtuple_class
6 changes: 6 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ inline bool IsNamedTuple(const py::handle& object) {
return PyTuple_Check(object.ptr()) && PyObject_HasAttrString(object.ptr(), "_fields") == 1;
}

// Return true when `type` is a class derived from `tuple` that exposes a `_fields` attribute.
//
// This is a heuristic, not an exact test: any tuple subclass that happens to define `_fields`
// is reported as a namedtuple class.
inline bool IsNamedTupleClass(const py::handle& type) {
    // Reject non-class objects up front. PyObject_IsSubclass() returns -1 and leaves a TypeError
    // pending when its first argument is not a class; comparing the result with 1 would hide the
    // failure but keep the error indicator set for whatever C-API call comes next.
    if (!PyType_Check(type.ptr())) {
        return false;
    }
    // We can only identify namedtuples heuristically, here by the presence of a _fields attribute.
    return PyObject_IsSubclass(type.ptr(), reinterpret_cast<PyObject*>(&PyTuple_Type)) == 1 &&
           PyObject_HasAttrString(type.ptr(), "_fields") == 1;
}

inline void AssertExactNamedTuple(const py::handle& object) {
if (!IsNamedTuple(object)) [[unlikely]] {
throw std::invalid_argument(
Expand Down
6 changes: 6 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@
)
from optree.typing import (
CustomTreeNode,
FlattenFunc,
PyTree,
PyTreeDef,
PyTreeSpec,
PyTreeTypeVar,
UnflattenFunc,
is_namedtuple,
is_namedtuple_class,
)
from optree.version import __version__

Expand Down Expand Up @@ -105,7 +108,10 @@
'PyTree',
'PyTreeTypeVar',
'CustomTreeNode',
'FlattenFunc',
'UnflattenFunc',
'is_namedtuple',
'is_namedtuple_class',
]

MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH
Expand Down
14 changes: 7 additions & 7 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
)

import optree._C as _C
from optree.typing import KT, VT, Children, CustomTreeNode, DefaultDict, MetaData
from optree.typing import KT, VT, CustomTreeNode, DefaultDict, FlattenFunc
from optree.typing import OrderedDict as GenericOrderedDict
from optree.typing import PyTree, T
from optree.typing import PyTree, T, UnflattenFunc
from optree.utils import safe_zip, unzip2


Expand All @@ -51,8 +51,8 @@


class PyTreeNodeRegistryEntry(NamedTuple):
to_iterable: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]]
from_iterable: Callable[[MetaData, Children[T]], CustomTreeNode[T]]
to_iterable: FlattenFunc
from_iterable: UnflattenFunc


__GLOBAL_NAMESPACE: str = object() # type: ignore[assignment]
Expand All @@ -61,8 +61,8 @@ class PyTreeNodeRegistryEntry(NamedTuple):

def register_pytree_node(
cls: Type[CustomTreeNode[T]],
flatten_func: Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]],
unflatten_func: Callable[[MetaData, Children[T]], CustomTreeNode[T]],
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
namespace: str,
) -> Type[CustomTreeNode[T]]:
"""Extend the set of types that are considered internal nodes in pytrees.
Expand Down Expand Up @@ -189,7 +189,7 @@ def register_pytree_node(
with __REGISTRY_LOCK:
_C.register_node(cls, flatten_func, unflatten_func, namespace)
CustomTreeNode.register(cls) # pylint: disable=no-member
_nodetype_registry[registration_key] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func) # type: ignore[arg-type]
_nodetype_registry[registration_key] = PyTreeNodeRegistryEntry(flatten_func, unflatten_func)
return cls


Expand Down
15 changes: 14 additions & 1 deletion optree/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@

from typing import (
Any,
Callable,
DefaultDict,
Deque,
Dict,
Expand Down Expand Up @@ -55,7 +56,10 @@
'CustomTreeNode',
'Children',
'MetaData',
'FlattenFunc',
'UnflattenFunc',
'is_namedtuple',
'is_namedtuple_class',
'T',
'S',
'U',
Expand Down Expand Up @@ -239,6 +243,15 @@ def __deepcopy__(self, memo):
return self


# Type of a flatten function: takes a custom tree node and returns a
# ``(children, metadata)`` pair.
FlattenFunc = Callable[[CustomTreeNode[T]], Tuple[Children[T], MetaData]]
# Type of an unflatten function: takes ``(metadata, children)`` and returns
# a custom tree node.
UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]]


def is_namedtuple(obj: object) -> bool:
    """Return whether the object is a namedtuple.

    The check is heuristic: an object qualifies when it is a ``tuple``
    instance that also exposes a ``_fields`` attribute.
    """
    if not isinstance(obj, tuple):
        return False
    return hasattr(obj, '_fields')


def is_namedtuple_class(cls: Type) -> bool:
    """Return whether the class is a subclass of namedtuple.

    The check is heuristic: a class qualifies when it derives from ``tuple``
    and exposes a ``_fields`` attribute.

    Args:
        cls: The object to test. Non-class objects (including namedtuple
            instances) are accepted and yield ``False``.

    Returns:
        ``True`` if ``cls`` is a tuple subclass with a ``_fields`` attribute,
        ``False`` otherwise.
    """
    # ``issubclass`` raises TypeError when its first argument is not a class,
    # so filter those out first to keep this a total predicate.
    return isinstance(cls, type) and issubclass(cls, tuple) and hasattr(cls, '_fields')
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,6 @@ convention = "google"

[tool.doc8]
max-line-length = 500

[tool.pytest.ini_options]
filterwarnings = ["error"]
21 changes: 21 additions & 0 deletions src/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ template <bool NoneIsLeaf>
throw std::invalid_argument(absl::StrFormat(
"PyTree type %s is already registered in the global namespace.", py::repr(cls)));
}
if (IsNamedTupleClass(cls)) [[unlikely]] {
PyErr_WarnEx(
PyExc_UserWarning,
absl::StrFormat("PyTree type %s is a subclass of `collections.namedtuple`, "
"which is already registered in the global namespace. "
"Override it with custom flatten/unflatten functions.",
py::repr(cls))
.c_str(),
/*stack_level=*/2);
}
} else [[likely]] {
if (registry->m_registrations.find(cls) != registry->m_registrations.end()) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
Expand All @@ -87,6 +97,17 @@ template <bool NoneIsLeaf>
py::repr(cls),
py::repr(py::str(registry_namespace))));
}
if (IsNamedTupleClass(cls)) [[unlikely]] {
PyErr_WarnEx(PyExc_UserWarning,
absl::StrFormat(
"PyTree type %s is a subclass of `collections.namedtuple`, "
"which is already registered in the global namespace. "
"Override it with custom flatten/unflatten functions in namespace %s.",
py::repr(cls),
py::repr(py::str(registry_namespace)))
.c_str(),
/*stack_level=*/2);
}
}
}

Expand Down
77 changes: 74 additions & 3 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,8 @@

# pylint: disable=missing-function-docstring,invalid-name

from collections import UserDict, UserList
import re
from collections import UserDict, UserList, namedtuple

import pytest

Expand Down Expand Up @@ -174,7 +175,6 @@ def test_register_pytree_node_duplicate_builtin_namespace():
lambda _, l: l,
namespace=optree.registry.__GLOBAL_NAMESPACE,
)

with pytest.raises(
ValueError,
match=r"PyTree type <class 'list'> is already registered in the global namespace.",
Expand All @@ -187,6 +187,77 @@ def test_register_pytree_node_duplicate_builtin_namespace():
)


def test_register_pytree_node_namedtuple():
    # Registering a namedtuple subclass must emit a UserWarning, and the custom
    # flatten/unflatten pair must only be used for the namespace it was registered in.

    def flatten_reversed(node):
        # Shared flatten function: yields the fields in reverse order.
        return (reversed(node), None, None)

    # --- Registration in the global namespace -------------------------------
    mytuple1 = namedtuple('mytuple1', ['a', 'b', 'c'])
    global_warning = re.escape(
        r"PyTree type <class 'test_registry.mytuple1'> is a subclass of `collections.namedtuple`, "
        r'which is already registered in the global namespace. '
        r'Override it with custom flatten/unflatten functions.'
    )
    with pytest.warns(UserWarning, match=global_warning):
        optree.register_pytree_node(
            mytuple1,
            flatten_reversed,
            lambda _, children: mytuple1(*reversed(children)),
            namespace=optree.registry.__GLOBAL_NAMESPACE,
        )

    # Registering the same class again under a named namespace is rejected.
    duplicate_error = re.escape(
        r"PyTree type <class 'test_registry.mytuple1'> is already registered in the global namespace."
    )
    with pytest.raises(ValueError, match=duplicate_error):
        optree.register_pytree_node(
            mytuple1,
            flatten_reversed,
            lambda _, children: mytuple1(*reversed(children)),
            namespace='mytuple',
        )

    instance1 = mytuple1(1, 2, 3)
    flat1, spec1 = optree.tree_flatten(instance1)
    assert flat1 == [3, 2, 1]
    assert str(spec1) == 'PyTreeSpec(CustomTreeNode(mytuple1[None], [*, *, *]))'
    assert instance1 == optree.tree_unflatten(spec1, flat1)

    # --- Registration in the 'mytuple' namespace ----------------------------
    mytuple2 = namedtuple('mytuple2', ['a', 'b', 'c'])
    namespaced_warning = re.escape(
        r"PyTree type <class 'test_registry.mytuple2'> is a subclass of `collections.namedtuple`, "
        r'which is already registered in the global namespace. '
        r"Override it with custom flatten/unflatten functions in namespace 'mytuple'."
    )
    with pytest.warns(UserWarning, match=namespaced_warning):
        optree.register_pytree_node(
            mytuple2,
            flatten_reversed,
            lambda _, children: mytuple2(*reversed(children)),
            namespace='mytuple',
        )

    instance2 = mytuple2(1, 2, 3)

    # Without a namespace, and with an unrelated one, the default namedtuple handling applies.
    for lookup_kwargs in ({}, {'namespace': 'undefined'}):
        flat2, spec2 = optree.tree_flatten(instance2, **lookup_kwargs)
        assert flat2 == [1, 2, 3]
        assert str(spec2) == 'PyTreeSpec(mytuple2(a=*, b=*, c=*))'
        assert instance2 == optree.tree_unflatten(spec2, flat2)

    # Inside the registered namespace the custom (reversing) handlers are used.
    flat2, spec2 = optree.tree_flatten(instance2, namespace='mytuple')
    assert flat2 == [3, 2, 1]
    expected_spec = "PyTreeSpec(CustomTreeNode(mytuple2[None], [*, *, *]), namespace='mytuple')"
    assert str(spec2) == expected_spec
    assert instance2 == optree.tree_unflatten(spec2, flat2)


def test_pytree_node_registry_get():
handler = optree.register_pytree_node.get(list)
assert handler is not None
Expand Down