diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fbf0afcf..5d598f0c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -90,7 +90,8 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: + ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"] include: - os: macos-13 python-version: "3.7" diff --git a/.github/workflows/set_cibw_build.py b/.github/workflows/set_cibw_build.py index bcef9f77..ae0c35a6 100755 --- a/.github/workflows/set_cibw_build.py +++ b/.github/workflows/set_cibw_build.py @@ -3,11 +3,11 @@ # pylint: disable=missing-module-docstring import os -import sys +import platform -# pylint: disable-next=consider-using-f-string -CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*' % sys.version_info[:2] +MAJOR, MINOR, *_ = platform.python_version_tuple() +CIBW_BUILD = f'CIBW_BUILD=*{platform.python_implementation().lower()[0]}p{MAJOR}{MINOR}-*' print(CIBW_BUILD) with open(os.getenv('GITHUB_ENV'), mode='a', encoding='utf-8') as file: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 417d9cf6..db9482b8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,8 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: + ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"] include: - os: macos-13 python-version: "3.7" diff --git a/CHANGELOG.md b/CHANGELOG.md index abb905be..2b66395e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add PyPy support by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/optree/pull/145). - Add 32-bit wheels for Linux and Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#141](https://github.com/metaopt/optree/pull/141). - Add Linux ppc64le and s390x wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/optree/pull/138). - Add accessor APIs `tree_flatten_with_accessor` and `PyTreeSpec.accessors` by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/optree/pull/108). diff --git a/include/treespec.h b/include/treespec.h index 9d53f0b5..939a702b 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -42,7 +42,11 @@ using ssize_t = py::ssize_t; #ifndef Py_C_RECURSION_LIMIT #define Py_C_RECURSION_LIMIT 1000 #endif +#ifndef PYPY_VERSION constexpr ssize_t MAX_RECURSION_DEPTH = std::min(1000, Py_C_RECURSION_LIMIT); +#else +constexpr ssize_t MAX_RECURSION_DEPTH = std::min(500, Py_C_RECURSION_LIMIT); +#endif // Test whether the given object is a leaf node. bool IsLeaf(const py::object &object, diff --git a/include/utils.h b/include/utils.h index d8d00d13..d4b68292 100644 --- a/include/utils.h +++ b/include/utils.h @@ -23,6 +23,7 @@ limitations under the License. #include // PyMemberDef #endif +#include // pybind11::exec #include #include // std::rethrow_exception, std::current_exception @@ -34,24 +35,22 @@ limitations under the License. #include // std::vector namespace py = pybind11; -using size_t = py::size_t; -using ssize_t = py::ssize_t; // The maximum size of the type cache. -constexpr ssize_t MAX_TYPE_CACHE_SIZE = 4096; +constexpr py::ssize_t MAX_TYPE_CACHE_SIZE = 4096; // boost::hash_combine template inline void HashCombine(py::size_t& seed, const T& v) { // NOLINT[runtime/references] std::hash hasher{}; // NOLINTNEXTLINE[cppcoreguidelines-avoid-magic-numbers] - seed ^= (hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2)); + seed ^= (hasher(v) + 0x9E3779B9 + (seed << 6) + (seed >> 2)); } template inline void HashCombine(py::ssize_t& seed, const T& v) { // NOLINT[runtime/references] std::hash hasher{}; // NOLINTNEXTLINE[cppcoreguidelines-avoid-magic-numbers] - seed ^= (hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2)); + seed ^= (hasher(v) + 0x9E3779B9 + (seed << 6) + (seed >> 2)); } class TypeHash { @@ -127,6 +126,7 @@ constexpr bool NONE_IS_NODE = false; #define Py_Get_ID(name) (Py_ID_##name()) Py_Declare_ID(optree); +Py_Declare_ID(__main__); // __main__ Py_Declare_ID(__module__); // type.__module__ Py_Declare_ID(__qualname__); // type.__qualname__ Py_Declare_ID(__name__); // type.__name__ @@ -180,39 +180,39 @@ inline std::vector reserved_vector(const py::size_t& size) { } template -inline ssize_t GetSize(const py::handle& sized) { +inline py::ssize_t GetSize(const py::handle& sized) { return py::ssize_t_cast(py::len(sized)); } template <> -inline ssize_t GetSize(const py::handle& sized) { +inline py::ssize_t GetSize(const py::handle& sized) { return PyTuple_Size(sized.ptr()); } template <> -inline ssize_t GetSize(const py::handle& sized) { +inline py::ssize_t GetSize(const py::handle& sized) { return PyList_Size(sized.ptr()); } template <> -inline ssize_t GetSize(const py::handle& sized) { +inline py::ssize_t GetSize(const py::handle& sized) { return PyDict_Size(sized.ptr()); } template -inline ssize_t GET_SIZE(const py::handle& sized) { +inline py::ssize_t GET_SIZE(const py::handle& sized) { return py::ssize_t_cast(py::len(sized)); } template <> -inline ssize_t GET_SIZE(const py::handle& sized) { +inline py::ssize_t GET_SIZE(const py::handle& sized) { return PyTuple_GET_SIZE(sized.ptr()); } template <> -inline ssize_t GET_SIZE(const py::handle& sized) { +inline py::ssize_t GET_SIZE(const py::handle& sized) { return PyList_GET_SIZE(sized.ptr()); } #ifndef PyDict_GET_SIZE -#define PyDict_GET_SIZE PyDict_GetSize +#define PyDict_GET_SIZE PyDict_Size #endif template <> -inline ssize_t GET_SIZE(const py::handle& sized) { +inline py::ssize_t GET_SIZE(const py::handle& sized) { return PyDict_GET_SIZE(sized.ptr()); } @@ -421,7 +421,6 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) { // n_fields, n_sequence_fields, n_unnamed_fields attributes. auto* type_object = reinterpret_cast(type.ptr()); if (PyType_FastSubclass(type_object, Py_TPFLAGS_TUPLE_SUBCLASS) && - !static_cast(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE)) && type_object->tp_bases != nullptr && static_cast(PyTuple_CheckExact(type_object->tp_bases)) && PyTuple_GET_SIZE(type_object->tp_bases) == 1 && @@ -441,7 +440,16 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) { return false; } } - return true; +#ifdef PYPY_VERSION + try { + py::exec("class _(cls): pass", py::dict(py::arg("cls") = type)); + } catch (py::error_already_set& ex) { + return (ex.matches(PyExc_AssertionError) || ex.matches(PyExc_TypeError)); + } + return false; +#else + return (!static_cast(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE))); +#endif } return false; } @@ -480,14 +488,32 @@ inline void AssertExactStructSequence(const py::handle& object) { } } inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) { - const auto n_sequence_fields = py::cast(getattr(type, Py_Get_ID(n_sequence_fields))); +#ifdef PYPY_VERSION + py::list fields{}; + py::exec( + R"py( + from _structseq import structseqfield + + indices_by_name = { + name: member.index + for name, member in vars(cls).items() + if isinstance(member, structseqfield) + } + fields.extend(sorted(indices_by_name, key=indices_by_name.get)[:cls.n_sequence_fields]) + )py", + py::dict(py::arg("cls") = type, py::arg("fields") = fields)); + return py::tuple{fields}; +#else + const auto n_sequence_fields = + py::cast(getattr(type, Py_Get_ID(n_sequence_fields))); auto* members = reinterpret_cast(type.ptr())->tp_members; py::tuple fields{n_sequence_fields}; - for (ssize_t i = 0; i < n_sequence_fields; ++i) { + for (py::ssize_t i = 0; i < n_sequence_fields; ++i) { // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] SET_ITEM(fields, i, py::str(members[i].name)); } return fields; +#endif } inline py::tuple StructSequenceGetFields(const py::handle& object) { py::handle type; @@ -573,12 +599,12 @@ inline py::list SortedDictKeys(const py::dict& dict) { } inline bool DictKeysEqual(const py::list& /*unique*/ keys, const py::dict& dict) { - ssize_t list_len = GET_SIZE(keys); - ssize_t dict_len = GET_SIZE(dict); + py::ssize_t list_len = GET_SIZE(keys); + py::ssize_t dict_len = GET_SIZE(dict); if (list_len != dict_len) [[likely]] { // assumes keys are unique return false; } - for (ssize_t i = 0; i < list_len; ++i) { + for (py::ssize_t i = 0; i < list_len; ++i) { py::object key = GET_ITEM_BORROW(keys, i); int result = PyDict_Contains(dict.ptr(), key.ptr()); if (result == -1) [[unlikely]] { diff --git a/optree/typing.py b/optree/typing.py index f00027cf..4ece56b0 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -16,6 +16,7 @@ from __future__ import annotations +import platform import types from collections.abc import Hashable from typing import ( @@ -399,7 +400,7 @@ def is_structseq_instance(obj: object) -> bool: def is_structseq_class(cls: type) -> bool: """Return whether the class is a class of PyStructSequence.""" - return ( + if ( isinstance(cls, type) # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)` and cls.__bases__ == (tuple,) @@ -407,9 +408,19 @@ def is_structseq_class(cls: type) -> bool: and isinstance(getattr(cls, 'n_fields', None), int) and isinstance(getattr(cls, 'n_sequence_fields', None), int) and isinstance(getattr(cls, 'n_unnamed_fields', None), int) + ): # Check the type does not allow subclassing - and not (cls.__flags__ & Py_TPFLAGS_BASETYPE) - ) + if platform.python_implementation() == 'PyPy': + try: + # pylint: disable-next=too-few-public-methods + class _(cls): # noqa: N801 + pass + + except (AssertionError, TypeError): + return True + return False + return not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) + return False def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: @@ -423,14 +434,23 @@ def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: if not is_structseq_class(cls): raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.') - n_sequence_fields: int = cls.n_sequence_fields # type: ignore[attr-defined] - fields: list[str] = [] - for name, member in vars(cls).items(): - if len(fields) >= n_sequence_fields: - break - if isinstance(member, types.MemberDescriptorType): - fields.append(name) - return tuple(fields) + if platform.python_implementation() == 'PyPy': + # pylint: disable-next=import-error,import-outside-toplevel + from _structseq import structseqfield + + indices_by_name = { + name: member.index + for name, member in vars(cls).items() + if isinstance(member, structseqfield) + } + fields = sorted(indices_by_name, key=indices_by_name.get) # type: ignore[arg-type] + else: + fields = [ + name + for name, member in vars(cls).items() + if isinstance(member, types.MemberDescriptorType) + ] + return tuple(fields[: cls.n_sequence_fields]) # type: ignore[attr-defined] # Ensure that the behavior is consistent with C++ implementation diff --git a/pyproject.toml b/pyproject.toml index 0677cda6..b1c4a02b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index fd308cc5..f10da9ab 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -1214,17 +1214,28 @@ std::string PyTreeSpec::ToStringImpl() const { case PyTreeKind::StructSequence: { py::object type = node.node_data; - auto* members = reinterpret_cast(type.ptr())->tp_members; - std::string kind = reinterpret_cast(type.ptr())->tp_name; - sstream << kind << "("; + auto fields = StructSequenceGetFields(type); + EXPECT_EQ(GET_SIZE(fields), + node.arity, + "Number of fields and entries does not match."); + py::object module_name = + py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__)); + if (!module_name.is_none()) [[likely]] { + std::string name = static_cast(py::str(module_name)); + if (!(name.empty() || name == "__main__" || name == "builtins" || + name == "__builtins__")) [[likely]] { + sstream << name << "."; + } + } + py::object qualname = py::getattr(type, Py_Get_ID(__qualname__)); + sstream << static_cast(py::str(qualname)) << "("; bool first = true; auto child_iter = agenda.end() - node.arity; - for (ssize_t i = 0; i < node.arity; ++i) { + for (const py::handle& field : fields) { if (!first) [[likely]] { sstream << ", "; } - // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - sstream << members[i].name << "=" << *child_iter; + sstream << static_cast(py::str(field)) << "=" << *child_iter; ++child_iter; first = false; } diff --git a/tests/helpers.py b/tests/helpers.py index c5acbdaa..2612564b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -18,6 +18,7 @@ import dataclasses import gc import itertools +import platform import sys import time from collections import OrderedDict, UserDict, defaultdict, deque, namedtuple @@ -28,6 +29,13 @@ import optree +PYPY = platform.python_implementation() == 'PyPy' +skipif_pypy = pytest.mark.skipif( + PYPY, + reason='PyPy does not support weakref and refcount correctly', +) + + def gc_collect(): for _ in range(3): gc.collect() @@ -114,6 +122,10 @@ class EmptyTuple(NamedTuple): # time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*) TimeStructTimeType = time.struct_time +if PYPY: + SysFloatInfoType.__module__ = 'sys' + TimeStructTimeType.__module__ = 'time' + class Vector3D: def __init__(self, x, y, z): diff --git a/tests/test_ops.py b/tests/test_ops.py index f83a334a..9eaa277c 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -2916,7 +2916,7 @@ def test_tree_max(): assert optree.tree_max(None, default=0) == 0 assert optree.tree_max(None, none_is_leaf=True) is None assert optree.tree_max(None, default=0, key=operator.neg) == 0 - with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")): + with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"): assert optree.tree_max(None, default=0, key=operator.neg, none_is_leaf=True) is None @@ -2935,7 +2935,7 @@ def test_tree_min(): assert optree.tree_min(None, default=0) == 0 assert optree.tree_min(None, none_is_leaf=True) is None assert optree.tree_min(None, default=0, key=operator.neg) == 0 - with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")): + with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"): assert optree.tree_min(None, default=0, key=operator.neg, none_is_leaf=True) is None diff --git a/tests/test_registry.py b/tests/test_registry.py index 06fbd5df..dcd86ad5 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -22,7 +22,7 @@ import pytest import optree -from helpers import gc_collect +from helpers import gc_collect, skipif_pypy def test_register_pytree_node_class_with_no_namespace(): @@ -765,6 +765,7 @@ def test_unregister_pytree_node_namedtuple(): assert treespec3 != treespec4 +@skipif_pypy def test_unregister_pytree_node_memory_leak(): # noqa: C901 @optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE) diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 1c59b110..ed341aab 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -30,6 +30,7 @@ import optree from helpers import ( NAMESPACED_TREE, + PYPY, TREE_STRINGS, TREES, MyAnotherDict, @@ -225,13 +226,17 @@ def __repr__(self): hashes.add(hash(treespec)) assert hash(other) == hash(other) assert hash(treespec) == hash(other) - with pytest.raises(RecursionError): - assert treespec != other + + if not PYPY: + with pytest.raises(RecursionError): + assert treespec != other wr = weakref.ref(treespec) del treespec, key, other gc_collect() - assert wr() is None + + if not PYPY: + assert wr() is None def test_treeiter_self_referential(): @@ -260,7 +265,9 @@ def test_treeiter_self_referential(): del it, d gc_collect() - assert wr() is None + + if not PYPY: + assert wr() is None def test_treespec_with_namespace(): @@ -406,10 +413,7 @@ def test_treespec_pickle_round_trip(tree, none_is_leaf, namespace): try: pickle.loads(pickle.dumps(tree)) except pickle.PicklingError: - with pytest.raises( - pickle.PicklingError, - match="Can't pickle .*: it's not the same object as .*", - ): + with pytest.raises(pickle.PicklingError, match=r"Can't pickle .*:"): pickle.loads(pickle.dumps(expected)) else: actual = pickle.loads(pickle.dumps(expected)) diff --git a/tests/test_typing.py b/tests/test_typing.py index 7cc9289d..6907b415 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -24,7 +24,14 @@ import pytest import optree -from helpers import CustomNamedTupleSubclass, CustomTuple, Vector2D, gc_collect, getrefcount +from helpers import ( + CustomNamedTupleSubclass, + CustomTuple, + Vector2D, + gc_collect, + getrefcount, + skipif_pypy, +) class FakeNamedTuple(tuple): @@ -84,6 +91,7 @@ def test_is_namedtuple(): assert not optree.is_namedtuple_class(FakeStructSequence) +@skipif_pypy def test_is_namedtuple_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 @@ -141,6 +149,7 @@ class Foo(metaclass=FooMeta): assert wr() is None +@skipif_pypy def test_namedtuple_fields_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 @@ -241,6 +250,7 @@ class MyTuple(optree.typing.structseq): assert not optree.is_structseq_class(FakeStructSequence) +@skipif_pypy def test_is_structseq_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024 @@ -452,6 +462,7 @@ def test_structseq_fields(): optree.structseq_fields(FakeStructSequence) +@skipif_pypy def test_structseq_fields_cache(): Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024