diff --git a/.github/workflows/set_cibw_build.py b/.github/workflows/set_cibw_build.py index bcef9f77..ae0c35a6 100755 --- a/.github/workflows/set_cibw_build.py +++ b/.github/workflows/set_cibw_build.py @@ -3,11 +3,11 @@ # pylint: disable=missing-module-docstring import os -import sys +import platform -# pylint: disable-next=consider-using-f-string -CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*' % sys.version_info[:2] +MAJOR, MINOR, *_ = platform.python_version_tuple() +CIBW_BUILD = f'CIBW_BUILD=*{platform.python_implementation().lower()[0]}p{MAJOR}{MINOR}-*' print(CIBW_BUILD) with open(os.getenv('GITHUB_ENV'), mode='a', encoding='utf-8') as file: diff --git a/include/utils.h b/include/utils.h index d8d00d13..cac337e2 100644 --- a/include/utils.h +++ b/include/utils.h @@ -23,6 +23,7 @@ limitations under the License. #include // PyMemberDef #endif +#include // pybind11::exec #include #include // std::rethrow_exception, std::current_exception @@ -127,6 +128,7 @@ constexpr bool NONE_IS_NODE = false; #define Py_Get_ID(name) (Py_ID_##name()) Py_Declare_ID(optree); +Py_Declare_ID(__main__); // __main__ Py_Declare_ID(__module__); // type.__module__ Py_Declare_ID(__qualname__); // type.__qualname__ Py_Declare_ID(__name__); // type.__name__ @@ -421,7 +423,6 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) { // n_fields, n_sequence_fields, n_unnamed_fields attributes. auto* type_object = reinterpret_cast(type.ptr()); if (PyType_FastSubclass(type_object, Py_TPFLAGS_TUPLE_SUBCLASS) && - !static_cast(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE)) && type_object->tp_bases != nullptr && static_cast(PyTuple_CheckExact(type_object->tp_bases)) && PyTuple_GET_SIZE(type_object->tp_bases) == 1 && @@ -441,7 +442,16 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) { return false; } } - return true; +#ifdef PYPY_VERSION + try { + py::exec("class _(cls): pass", py::dict(py::arg("cls") = type)); + } catch (py::error_already_set& ex) { + return (ex.matches(PyExc_AssertionError) || ex.matches(PyExc_TypeError)); + } + return false; +#else + return (!static_cast(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE))); +#endif } return false; } @@ -480,6 +490,22 @@ inline void AssertExactStructSequence(const py::handle& object) { } } inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) { +#ifdef PYPY_VERSION + py::list fields{}; + py::exec( + R"py( + from _structseq import structseqfield + + indices_by_name = { + name: member.index + for name, member in vars(cls).items() + if isinstance(member, structseqfield) + } + fields.extend(sorted(indices_by_name, key=indices_by_name.get)[:cls.n_sequence_fields]) + )py", + py::dict(py::arg("cls") = type, py::arg("fields") = fields)); + return py::tuple{fields}; +#else const auto n_sequence_fields = py::cast(getattr(type, Py_Get_ID(n_sequence_fields))); auto* members = reinterpret_cast(type.ptr())->tp_members; py::tuple fields{n_sequence_fields}; @@ -488,6 +514,7 @@ inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) { SET_ITEM(fields, i, py::str(members[i].name)); } return fields; +#endif } inline py::tuple StructSequenceGetFields(const py::handle& object) { py::handle type; diff --git a/optree/typing.py b/optree/typing.py index f00027cf..4ece56b0 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -16,6 +16,7 @@ from __future__ import annotations +import platform import types from collections.abc import Hashable from typing import ( @@ -399,7 +400,7 @@ def is_structseq_instance(obj: object) -> bool: def is_structseq_class(cls: type) -> bool: """Return whether the class is a class of PyStructSequence.""" - return ( + if ( isinstance(cls, type) # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)` and cls.__bases__ == (tuple,) @@ -407,9 +408,19 @@ def is_structseq_class(cls: type) -> bool: and isinstance(getattr(cls, 'n_fields', None), int) and isinstance(getattr(cls, 'n_sequence_fields', None), int) and isinstance(getattr(cls, 'n_unnamed_fields', None), int) + ): # Check the type does not allow subclassing - and not (cls.__flags__ & Py_TPFLAGS_BASETYPE) - ) + if platform.python_implementation() == 'PyPy': + try: + # pylint: disable-next=too-few-public-methods + class _(cls): # noqa: N801 + pass + + except (AssertionError, TypeError): + return True + return False + return not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) + return False def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: @@ -423,14 +434,23 @@ def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: if not is_structseq_class(cls): raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.') - n_sequence_fields: int = cls.n_sequence_fields # type: ignore[attr-defined] - fields: list[str] = [] - for name, member in vars(cls).items(): - if len(fields) >= n_sequence_fields: - break - if isinstance(member, types.MemberDescriptorType): - fields.append(name) - return tuple(fields) + if platform.python_implementation() == 'PyPy': + # pylint: disable-next=import-error,import-outside-toplevel + from _structseq import structseqfield + + indices_by_name = { + name: member.index + for name, member in vars(cls).items() + if isinstance(member, structseqfield) + } + fields = sorted(indices_by_name, key=indices_by_name.get) # type: ignore[arg-type] + else: + fields = [ + name + for name, member in vars(cls).items() + if isinstance(member, types.MemberDescriptorType) + ] + return tuple(fields[: cls.n_sequence_fields]) # type: ignore[attr-defined] # Ensure that the behavior is consistent with C++ implementation diff --git a/pyproject.toml b/pyproject.toml index 0677cda6..b1c4a02b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index fd308cc5..cf0e0e86 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -1214,17 +1214,29 @@ std::string PyTreeSpec::ToStringImpl() const { case PyTreeKind::StructSequence: { py::object type = node.node_data; - auto* members = reinterpret_cast(type.ptr())->tp_members; - std::string kind = reinterpret_cast(type.ptr())->tp_name; - sstream << kind << "("; + auto fields = StructSequenceGetFields(type); + EXPECT_EQ(GET_SIZE(fields), + node.arity, + "Number of fields and entries does not match."); + py::object module_name = + py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__)); + if (!module_name.is_none()) [[likely]] { + std::string name = static_cast(py::str(module_name)); + if (!(name.empty() || name == "__main__" || name == "builtins" || + name == "__builtins__")) [[likely]] { + sstream << name << "."; + } + } + sstream << static_cast( + py::str(py::getattr(type, Py_Get_ID(__qualname__)))); + sstream << "("; bool first = true; auto child_iter = agenda.end() - node.arity; - for (ssize_t i = 0; i < node.arity; ++i) { + for (const py::handle& field : fields) { if (!first) [[likely]] { sstream << ", "; } - // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - sstream << members[i].name << "=" << *child_iter; + sstream << static_cast(py::str(field)) << "=" << *child_iter; ++child_iter; first = false; } diff --git a/tests/helpers.py b/tests/helpers.py index c5acbdaa..2612564b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -18,6 +18,7 @@ import dataclasses import gc import itertools +import platform import sys import time from collections import OrderedDict, UserDict, defaultdict, deque, namedtuple @@ -28,6 +29,13 @@ import optree +PYPY = platform.python_implementation() == 'PyPy' +skipif_pypy = pytest.mark.skipif( + PYPY, + reason='PyPy does not support weakref and refcount correctly', +) + + def gc_collect(): for _ in range(3): gc.collect() @@ -114,6 +122,10 @@ class EmptyTuple(NamedTuple): # time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*) TimeStructTimeType = time.struct_time +if PYPY: + SysFloatInfoType.__module__ = 'sys' + TimeStructTimeType.__module__ = 'time' + class Vector3D: def __init__(self, x, y, z): diff --git a/tests/test_ops.py b/tests/test_ops.py index f83a334a..9eaa277c 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -2916,7 +2916,7 @@ def test_tree_max(): assert optree.tree_max(None, default=0) == 0 assert optree.tree_max(None, none_is_leaf=True) is None assert optree.tree_max(None, default=0, key=operator.neg) == 0 - with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")): + with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"): assert optree.tree_max(None, default=0, key=operator.neg, none_is_leaf=True) is None @@ -2935,7 +2935,7 @@ def test_tree_min(): assert optree.tree_min(None, default=0) == 0 assert optree.tree_min(None, none_is_leaf=True) is None assert optree.tree_min(None, default=0, key=operator.neg) == 0 - with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")): + with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"): assert optree.tree_min(None, default=0, key=operator.neg, none_is_leaf=True) is None diff --git a/tests/test_registry.py b/tests/test_registry.py index 06fbd5df..dcd86ad5 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -22,7 +22,7 @@ import pytest import optree -from helpers import gc_collect +from helpers import gc_collect, skipif_pypy def test_register_pytree_node_class_with_no_namespace(): @@ -765,6 +765,7 @@ def test_unregister_pytree_node_namedtuple(): assert treespec3 != treespec4 +@skipif_pypy def test_unregister_pytree_node_memory_leak(): # noqa: C901 @optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 1c59b110..ed341aab 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -30,6 +30,7 @@ import optree from helpers import ( NAMESPACED_TREE, + PYPY, TREE_STRINGS, TREES, MyAnotherDict, @@ -225,13 +226,17 @@ def __repr__(self): hashes.add(hash(treespec)) assert hash(other) == hash(other) assert hash(treespec) == hash(other) - with pytest.raises(RecursionError): - assert treespec != other + + if not PYPY: + with pytest.raises(RecursionError): + assert treespec != other wr = weakref.ref(treespec) del treespec, key, other gc_collect() - assert wr() is None + + if not PYPY: + assert wr() is None def test_treeiter_self_referential(): @@ -260,7 +265,9 @@ def test_treeiter_self_referential(): del it, d gc_collect() - assert wr() is None + + if not PYPY: + assert wr() is None def test_treespec_with_namespace(): @@ -406,10 +413,7 @@ def test_treespec_pickle_round_trip(tree, none_is_leaf, namespace): try: pickle.loads(pickle.dumps(tree)) except pickle.PicklingError: - with pytest.raises( - pickle.PicklingError, - match="Can't pickle .*: it's not the same object as .*", - ): + with pytest.raises(pickle.PicklingError, match=r"Can't pickle .*:"): pickle.loads(pickle.dumps(expected)) else: actual = pickle.loads(pickle.dumps(expected)) diff --git a/tests/test_typing.py b/tests/test_typing.py index 7cc9289d..6907b415 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -24,7 +24,14 @@ import pytest import optree -from helpers import CustomNamedTupleSubclass, CustomTuple, Vector2D, gc_collect, getrefcount +from helpers import ( + CustomNamedTupleSubclass, + CustomTuple, + Vector2D, + gc_collect, + getrefcount, + skipif_pypy, +) class FakeNamedTuple(tuple): @@ -84,6 +91,7 @@ def test_is_namedtuple(): assert not optree.is_namedtuple_class(FakeStructSequence) +@skipif_pypy def test_is_namedtuple_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 @@ -141,6 +149,7 @@ class Foo(metaclass=FooMeta): assert wr() is None +@skipif_pypy def test_namedtuple_fields_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 @@ -241,6 +250,7 @@ class MyTuple(optree.typing.structseq): assert not optree.is_structseq_class(FakeStructSequence) +@skipif_pypy def test_is_structseq_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 @@ -452,6 +462,7 @@ def test_structseq_fields(): optree.structseq_fields(FakeStructSequence) +@skipif_pypy def test_structseq_fields_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024