Skip to content

Commit

Permalink
feat: support PyPy
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Jun 23, 2024
1 parent 8145df0 commit 5c6105e
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 34 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/set_cibw_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# pylint: disable=missing-module-docstring

import os
import sys
import platform


# pylint: disable-next=consider-using-f-string
CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*' % sys.version_info[:2]
MAJOR, MINOR, *_ = platform.python_version_tuple()
CIBW_BUILD = f'CIBW_BUILD=*{platform.python_implementation().lower()[0]}p{MAJOR}{MINOR}-*'

print(CIBW_BUILD)
with open(os.getenv('GITHUB_ENV'), mode='a', encoding='utf-8') as file:
Expand Down
31 changes: 29 additions & 2 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <structmember.h> // PyMemberDef
#endif

#include <pybind11/eval.h> // pybind11::exec
#include <pybind11/pybind11.h>

#include <exception> // std::rethrow_exception, std::current_exception
Expand Down Expand Up @@ -127,6 +128,7 @@ constexpr bool NONE_IS_NODE = false;
#define Py_Get_ID(name) (Py_ID_##name())

Py_Declare_ID(optree);
Py_Declare_ID(__main__); // __main__
Py_Declare_ID(__module__); // type.__module__
Py_Declare_ID(__qualname__); // type.__qualname__
Py_Declare_ID(__name__); // type.__name__
Expand Down Expand Up @@ -421,7 +423,6 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) {
// n_fields, n_sequence_fields, n_unnamed_fields attributes.
auto* type_object = reinterpret_cast<PyTypeObject*>(type.ptr());
if (PyType_FastSubclass(type_object, Py_TPFLAGS_TUPLE_SUBCLASS) &&
!static_cast<bool>(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE)) &&
type_object->tp_bases != nullptr &&
static_cast<bool>(PyTuple_CheckExact(type_object->tp_bases)) &&
PyTuple_GET_SIZE(type_object->tp_bases) == 1 &&
Expand All @@ -441,7 +442,16 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) {
return false;
}
}
return true;
#ifdef PYPY_VERSION
try {
py::exec("class _(cls): pass", py::dict(py::arg("cls") = type));
} catch (py::error_already_set& ex) {
return (ex.matches(PyExc_AssertionError) || ex.matches(PyExc_TypeError));
}
return false;
#else
return (!static_cast<bool>(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE)));
#endif
}
return false;
}
Expand Down Expand Up @@ -480,6 +490,22 @@ inline void AssertExactStructSequence(const py::handle& object) {
}
}
inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) {
#ifdef PYPY_VERSION
py::list fields{};
py::exec(
R"py(
from _structseq import structseqfield
indices_by_name = {
name: member.index
for name, member in vars(cls).items()
if isinstance(member, structseqfield)
}
fields.extend(sorted(indices_by_name, key=indices_by_name.get)[:cls.n_sequence_fields])
)py",
py::dict(py::arg("cls") = type, py::arg("fields") = fields));
return py::tuple{fields};
#else
const auto n_sequence_fields = py::cast<ssize_t>(getattr(type, Py_Get_ID(n_sequence_fields)));
auto* members = reinterpret_cast<PyTypeObject*>(type.ptr())->tp_members;
py::tuple fields{n_sequence_fields};
Expand All @@ -488,6 +514,7 @@ inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) {
SET_ITEM<py::tuple>(fields, i, py::str(members[i].name));
}
return fields;
#endif
}
inline py::tuple StructSequenceGetFields(const py::handle& object) {
py::handle type;
Expand Down
42 changes: 31 additions & 11 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import platform
import types
from collections.abc import Hashable
from typing import (
Expand Down Expand Up @@ -399,17 +400,27 @@ def is_structseq_instance(obj: object) -> bool:

def is_structseq_class(cls: type) -> bool:
"""Return whether the class is a class of PyStructSequence."""
return (
if (
isinstance(cls, type)
# Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
and cls.__bases__ == (tuple,)
# Check PyStructSequence members
and isinstance(getattr(cls, 'n_fields', None), int)
and isinstance(getattr(cls, 'n_sequence_fields', None), int)
and isinstance(getattr(cls, 'n_unnamed_fields', None), int)
):
# Check the type does not allow subclassing
and not (cls.__flags__ & Py_TPFLAGS_BASETYPE)
)
if platform.python_implementation() == 'PyPy':
try:
# pylint: disable-next=too-few-public-methods
class _(cls): # noqa: N801
pass

except (AssertionError, TypeError):
return True
return False
return not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE)
return False


def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
Expand All @@ -423,14 +434,23 @@ def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
if not is_structseq_class(cls):
raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.')

n_sequence_fields: int = cls.n_sequence_fields # type: ignore[attr-defined]
fields: list[str] = []
for name, member in vars(cls).items():
if len(fields) >= n_sequence_fields:
break
if isinstance(member, types.MemberDescriptorType):
fields.append(name)
return tuple(fields)
if platform.python_implementation() == 'PyPy':
# pylint: disable-next=import-error,import-outside-toplevel
from _structseq import structseqfield

indices_by_name = {
name: member.index
for name, member in vars(cls).items()
if isinstance(member, structseqfield)
}
fields = sorted(indices_by_name, key=indices_by_name.get) # type: ignore[arg-type]
else:
fields = [
name
for name, member in vars(cls).items()
if isinstance(member, types.MemberDescriptorType)
]
return tuple(fields[: cls.n_sequence_fields]) # type: ignore[attr-defined]


# Ensure that the behavior is consistent with C++ implementation
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
Expand Down
24 changes: 18 additions & 6 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1214,17 +1214,29 @@ std::string PyTreeSpec::ToStringImpl() const {

case PyTreeKind::StructSequence: {
py::object type = node.node_data;
auto* members = reinterpret_cast<PyTypeObject*>(type.ptr())->tp_members;
std::string kind = reinterpret_cast<PyTypeObject*>(type.ptr())->tp_name;
sstream << kind << "(";
auto fields = StructSequenceGetFields(type);
EXPECT_EQ(GET_SIZE<py::tuple>(fields),
node.arity,
"Number of fields and entries does not match.");
py::object module_name =
py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__));
if (!module_name.is_none()) [[likely]] {
std::string name = static_cast<std::string>(py::str(module_name));
if (!(name.empty() || name == "__main__" || name == "builtins" ||
name == "__builtins__")) [[likely]] {
sstream << name << ".";
}
}
sstream << static_cast<std::string>(
py::str(py::getattr(type, Py_Get_ID(__qualname__))));
sstream << "(";
bool first = true;
auto child_iter = agenda.end() - node.arity;
for (ssize_t i = 0; i < node.arity; ++i) {
for (const py::handle& field : fields) {
if (!first) [[likely]] {
sstream << ", ";
}
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
sstream << members[i].name << "=" << *child_iter;
sstream << static_cast<std::string>(py::str(field)) << "=" << *child_iter;
++child_iter;
first = false;
}
Expand Down
12 changes: 12 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import gc
import itertools
import platform
import sys
import time
from collections import OrderedDict, UserDict, defaultdict, deque, namedtuple
Expand All @@ -28,6 +29,13 @@
import optree


PYPY = platform.python_implementation() == 'PyPy'
skipif_pypy = pytest.mark.skipif(
PYPY,
reason='PyPy does not support weakref and refcount correctly',
)


def gc_collect():
for _ in range(3):
gc.collect()
Expand Down Expand Up @@ -114,6 +122,10 @@ class EmptyTuple(NamedTuple):
# time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*)
TimeStructTimeType = time.struct_time

if PYPY:
SysFloatInfoType.__module__ = 'sys'
TimeStructTimeType.__module__ = 'time'


class Vector3D:
def __init__(self, x, y, z):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2916,7 +2916,7 @@ def test_tree_max():
assert optree.tree_max(None, default=0) == 0
assert optree.tree_max(None, none_is_leaf=True) is None
assert optree.tree_max(None, default=0, key=operator.neg) == 0
with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")):
with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"):
assert optree.tree_max(None, default=0, key=operator.neg, none_is_leaf=True) is None


Expand All @@ -2935,7 +2935,7 @@ def test_tree_min():
assert optree.tree_min(None, default=0) == 0
assert optree.tree_min(None, none_is_leaf=True) is None
assert optree.tree_min(None, default=0, key=operator.neg) == 0
with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")):
with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"):
assert optree.tree_min(None, default=0, key=operator.neg, none_is_leaf=True) is None


Expand Down
3 changes: 2 additions & 1 deletion tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import pytest

import optree
from helpers import gc_collect
from helpers import gc_collect, skipif_pypy


def test_register_pytree_node_class_with_no_namespace():
Expand Down Expand Up @@ -765,6 +765,7 @@ def test_unregister_pytree_node_namedtuple():
assert treespec3 != treespec4


@skipif_pypy
def test_unregister_pytree_node_memory_leak(): # noqa: C901

@optree.register_pytree_node_class(namespace=optree.registry.__GLOBAL_NAMESPACE)
Expand Down
20 changes: 12 additions & 8 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import optree
from helpers import (
NAMESPACED_TREE,
PYPY,
TREE_STRINGS,
TREES,
MyAnotherDict,
Expand Down Expand Up @@ -225,13 +226,17 @@ def __repr__(self):
hashes.add(hash(treespec))
assert hash(other) == hash(other)
assert hash(treespec) == hash(other)
with pytest.raises(RecursionError):
assert treespec != other

if not PYPY:
with pytest.raises(RecursionError):
assert treespec != other

wr = weakref.ref(treespec)
del treespec, key, other
gc_collect()
assert wr() is None

if not PYPY:
assert wr() is None


def test_treeiter_self_referential():
Expand Down Expand Up @@ -260,7 +265,9 @@ def test_treeiter_self_referential():

del it, d
gc_collect()
assert wr() is None

if not PYPY:
assert wr() is None


def test_treespec_with_namespace():
Expand Down Expand Up @@ -406,10 +413,7 @@ def test_treespec_pickle_round_trip(tree, none_is_leaf, namespace):
try:
pickle.loads(pickle.dumps(tree))
except pickle.PicklingError:
with pytest.raises(
pickle.PicklingError,
match="Can't pickle .*: it's not the same object as .*",
):
with pytest.raises(pickle.PicklingError, match=r"Can't pickle .*:"):
pickle.loads(pickle.dumps(expected))
else:
actual = pickle.loads(pickle.dumps(expected))
Expand Down
13 changes: 12 additions & 1 deletion tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
import pytest

import optree
from helpers import CustomNamedTupleSubclass, CustomTuple, Vector2D, gc_collect, getrefcount
from helpers import (
CustomNamedTupleSubclass,
CustomTuple,
Vector2D,
gc_collect,
getrefcount,
skipif_pypy,
)


class FakeNamedTuple(tuple):
Expand Down Expand Up @@ -84,6 +91,7 @@ def test_is_namedtuple():
assert not optree.is_namedtuple_class(FakeStructSequence)


@skipif_pypy
def test_is_namedtuple_cache():
Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024

Expand Down Expand Up @@ -141,6 +149,7 @@ class Foo(metaclass=FooMeta):
assert wr() is None


@skipif_pypy
def test_namedtuple_fields_cache():
Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024

Expand Down Expand Up @@ -241,6 +250,7 @@ class MyTuple(optree.typing.structseq):
assert not optree.is_structseq_class(FakeStructSequence)


@skipif_pypy
def test_is_structseq_cache():
Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024

Expand Down Expand Up @@ -452,6 +462,7 @@ def test_structseq_fields():
optree.structseq_fields(FakeStructSequence)


@skipif_pypy
def test_structseq_fields_cache():
Point = namedtuple('Point', ('x', 'y')) # noqa: PYI024

Expand Down

0 comments on commit 5c6105e

Please sign in to comment.