Skip to content

Commit

Permalink
feat: add PyPy support (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Jun 23, 2024
1 parent 8145df0 commit 789f93d
Show file tree
Hide file tree
Showing 14 changed files with 148 additions and 55 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
python-version:
["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"]
include:
- os: macos-13
python-version: "3.7"
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/set_cibw_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# pylint: disable=missing-module-docstring

import os
import sys
import platform


# pylint: disable-next=consider-using-f-string
CIBW_BUILD = 'CIBW_BUILD=*cp%d%d-*' % sys.version_info[:2]
MAJOR, MINOR, *_ = platform.python_version_tuple()
CIBW_BUILD = f'CIBW_BUILD=*{platform.python_implementation().lower()[0]}p{MAJOR}{MINOR}-*'

print(CIBW_BUILD)
with open(os.getenv('GITHUB_ENV'), mode='a', encoding='utf-8') as file:
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
python-version:
["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"]
include:
- os: macos-13
python-version: "3.7"
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add PyPy support by [@XuehaiPan](https://github.com/XuehaiPan) in [#145](https://github.com/metaopt/optree/pull/145).
- Add 32-bit wheels for Linux and Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#141](https://github.com/metaopt/optree/pull/141).
- Add Linux ppc64le and s390x wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/optree/pull/138).
- Add accessor APIs `tree_flatten_with_accessor` and `PyTreeSpec.accessors` by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/optree/pull/108).
Expand Down
4 changes: 4 additions & 0 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ using ssize_t = py::ssize_t;
// Provide a fallback of 1000 when Py_C_RECURSION_LIMIT has not been defined yet.
#if !defined(Py_C_RECURSION_LIMIT)
#define Py_C_RECURSION_LIMIT 1000
#endif
// Upper bound on recursion depth: the smaller of Py_C_RECURSION_LIMIT and a fixed cap.
// The cap is 500 when PYPY_VERSION is defined (PyPy builds) and 1000 otherwise.
#if defined(PYPY_VERSION)
constexpr ssize_t MAX_RECURSION_DEPTH = std::min(500, Py_C_RECURSION_LIMIT);
#else
constexpr ssize_t MAX_RECURSION_DEPTH = std::min(1000, Py_C_RECURSION_LIMIT);
#endif

// Test whether the given object is a leaf node.
bool IsLeaf(const py::object &object,
Expand Down
68 changes: 47 additions & 21 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <structmember.h> // PyMemberDef
#endif

#include <pybind11/eval.h> // pybind11::exec
#include <pybind11/pybind11.h>

#include <exception> // std::rethrow_exception, std::current_exception
Expand All @@ -34,24 +35,22 @@ limitations under the License.
#include <vector> // std::vector

namespace py = pybind11;
using size_t = py::size_t;
using ssize_t = py::ssize_t;

// The maximum size of the type cache.
constexpr ssize_t MAX_TYPE_CACHE_SIZE = 4096;
constexpr py::ssize_t MAX_TYPE_CACHE_SIZE = 4096;

// boost::hash_combine
template <class T>
inline void HashCombine(py::size_t& seed, const T& v) { // NOLINT[runtime/references]
std::hash<T> hasher{};
// NOLINTNEXTLINE[cppcoreguidelines-avoid-magic-numbers]
seed ^= (hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
seed ^= (hasher(v) + 0x9E3779B9 + (seed << 6) + (seed >> 2));
}
template <class T>
inline void HashCombine(py::ssize_t& seed, const T& v) { // NOLINT[runtime/references]
std::hash<T> hasher{};
// NOLINTNEXTLINE[cppcoreguidelines-avoid-magic-numbers]
seed ^= (hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
seed ^= (hasher(v) + 0x9E3779B9 + (seed << 6) + (seed >> 2));
}

class TypeHash {
Expand Down Expand Up @@ -127,6 +126,7 @@ constexpr bool NONE_IS_NODE = false;
#define Py_Get_ID(name) (Py_ID_##name())

Py_Declare_ID(optree);
Py_Declare_ID(__main__); // __main__
Py_Declare_ID(__module__); // type.__module__
Py_Declare_ID(__qualname__); // type.__qualname__
Py_Declare_ID(__name__); // type.__name__
Expand Down Expand Up @@ -180,39 +180,39 @@ inline std::vector<T> reserved_vector(const py::size_t& size) {
}

template <typename Sized = py::object>
inline ssize_t GetSize(const py::handle& sized) {
inline py::ssize_t GetSize(const py::handle& sized) {
return py::ssize_t_cast(py::len(sized));
}
template <>
inline ssize_t GetSize<py::tuple>(const py::handle& sized) {
inline py::ssize_t GetSize<py::tuple>(const py::handle& sized) {
return PyTuple_Size(sized.ptr());
}
template <>
inline ssize_t GetSize<py::list>(const py::handle& sized) {
inline py::ssize_t GetSize<py::list>(const py::handle& sized) {
return PyList_Size(sized.ptr());
}
template <>
inline ssize_t GetSize<py::dict>(const py::handle& sized) {
inline py::ssize_t GetSize<py::dict>(const py::handle& sized) {
return PyDict_Size(sized.ptr());
}

template <typename Sized = py::object>
inline ssize_t GET_SIZE(const py::handle& sized) {
inline py::ssize_t GET_SIZE(const py::handle& sized) {
return py::ssize_t_cast(py::len(sized));
}
template <>
inline ssize_t GET_SIZE<py::tuple>(const py::handle& sized) {
inline py::ssize_t GET_SIZE<py::tuple>(const py::handle& sized) {
return PyTuple_GET_SIZE(sized.ptr());
}
template <>
inline ssize_t GET_SIZE<py::list>(const py::handle& sized) {
inline py::ssize_t GET_SIZE<py::list>(const py::handle& sized) {
return PyList_GET_SIZE(sized.ptr());
}
#ifndef PyDict_GET_SIZE
#define PyDict_GET_SIZE PyDict_GetSize
#define PyDict_GET_SIZE PyDict_Size
#endif
template <>
inline ssize_t GET_SIZE<py::dict>(const py::handle& sized) {
inline py::ssize_t GET_SIZE<py::dict>(const py::handle& sized) {
return PyDict_GET_SIZE(sized.ptr());
}

Expand Down Expand Up @@ -421,7 +421,6 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) {
// n_fields, n_sequence_fields, n_unnamed_fields attributes.
auto* type_object = reinterpret_cast<PyTypeObject*>(type.ptr());
if (PyType_FastSubclass(type_object, Py_TPFLAGS_TUPLE_SUBCLASS) &&
!static_cast<bool>(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE)) &&
type_object->tp_bases != nullptr &&
static_cast<bool>(PyTuple_CheckExact(type_object->tp_bases)) &&
PyTuple_GET_SIZE(type_object->tp_bases) == 1 &&
Expand All @@ -441,7 +440,16 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) {
return false;
}
}
return true;
#ifdef PYPY_VERSION
try {
py::exec("class _(cls): pass", py::dict(py::arg("cls") = type));
} catch (py::error_already_set& ex) {
return (ex.matches(PyExc_AssertionError) || ex.matches(PyExc_TypeError));
}
return false;
#else
return (!static_cast<bool>(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE)));
#endif
}
return false;
}
Expand Down Expand Up @@ -480,14 +488,32 @@ inline void AssertExactStructSequence(const py::handle& object) {
}
}
inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) {
const auto n_sequence_fields = py::cast<ssize_t>(getattr(type, Py_Get_ID(n_sequence_fields)));
#ifdef PYPY_VERSION
py::list fields{};
py::exec(
R"py(
from _structseq import structseqfield
indices_by_name = {
name: member.index
for name, member in vars(cls).items()
if isinstance(member, structseqfield)
}
fields.extend(sorted(indices_by_name, key=indices_by_name.get)[:cls.n_sequence_fields])
)py",
py::dict(py::arg("cls") = type, py::arg("fields") = fields));
return py::tuple{fields};
#else
const auto n_sequence_fields =
py::cast<py::ssize_t>(getattr(type, Py_Get_ID(n_sequence_fields)));
auto* members = reinterpret_cast<PyTypeObject*>(type.ptr())->tp_members;
py::tuple fields{n_sequence_fields};
for (ssize_t i = 0; i < n_sequence_fields; ++i) {
for (py::ssize_t i = 0; i < n_sequence_fields; ++i) {
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
SET_ITEM<py::tuple>(fields, i, py::str(members[i].name));
}
return fields;
#endif
}
inline py::tuple StructSequenceGetFields(const py::handle& object) {
py::handle type;
Expand Down Expand Up @@ -573,12 +599,12 @@ inline py::list SortedDictKeys(const py::dict& dict) {
}

inline bool DictKeysEqual(const py::list& /*unique*/ keys, const py::dict& dict) {
ssize_t list_len = GET_SIZE<py::list>(keys);
ssize_t dict_len = GET_SIZE<py::dict>(dict);
py::ssize_t list_len = GET_SIZE<py::list>(keys);
py::ssize_t dict_len = GET_SIZE<py::dict>(dict);
if (list_len != dict_len) [[likely]] { // assumes keys are unique
return false;
}
for (ssize_t i = 0; i < list_len; ++i) {
for (py::ssize_t i = 0; i < list_len; ++i) {
py::object key = GET_ITEM_BORROW<py::list>(keys, i);
int result = PyDict_Contains(dict.ptr(), key.ptr());
if (result == -1) [[unlikely]] {
Expand Down
42 changes: 31 additions & 11 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import platform
import types
from collections.abc import Hashable
from typing import (
Expand Down Expand Up @@ -399,17 +400,27 @@ def is_structseq_instance(obj: object) -> bool:

def is_structseq_class(cls: type) -> bool:
"""Return whether the class is a class of PyStructSequence."""
return (
if (
isinstance(cls, type)
# Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
and cls.__bases__ == (tuple,)
# Check PyStructSequence members
and isinstance(getattr(cls, 'n_fields', None), int)
and isinstance(getattr(cls, 'n_sequence_fields', None), int)
and isinstance(getattr(cls, 'n_unnamed_fields', None), int)
):
# Check the type does not allow subclassing
and not (cls.__flags__ & Py_TPFLAGS_BASETYPE)
)
if platform.python_implementation() == 'PyPy':
try:
# pylint: disable-next=too-few-public-methods
class _(cls): # noqa: N801
pass

except (AssertionError, TypeError):
return True
return False
return not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE)
return False


def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
Expand All @@ -423,14 +434,23 @@ def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
if not is_structseq_class(cls):
raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.')

n_sequence_fields: int = cls.n_sequence_fields # type: ignore[attr-defined]
fields: list[str] = []
for name, member in vars(cls).items():
if len(fields) >= n_sequence_fields:
break
if isinstance(member, types.MemberDescriptorType):
fields.append(name)
return tuple(fields)
if platform.python_implementation() == 'PyPy':
# pylint: disable-next=import-error,import-outside-toplevel
from _structseq import structseqfield

indices_by_name = {
name: member.index
for name, member in vars(cls).items()
if isinstance(member, structseqfield)
}
fields = sorted(indices_by_name, key=indices_by_name.get) # type: ignore[arg-type]
else:
fields = [
name
for name, member in vars(cls).items()
if isinstance(member, types.MemberDescriptorType)
]
return tuple(fields[: cls.n_sequence_fields]) # type: ignore[attr-defined]


# Ensure that the behavior is consistent with C++ implementation
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
Expand Down
23 changes: 17 additions & 6 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1214,17 +1214,28 @@ std::string PyTreeSpec::ToStringImpl() const {

case PyTreeKind::StructSequence: {
py::object type = node.node_data;
auto* members = reinterpret_cast<PyTypeObject*>(type.ptr())->tp_members;
std::string kind = reinterpret_cast<PyTypeObject*>(type.ptr())->tp_name;
sstream << kind << "(";
auto fields = StructSequenceGetFields(type);
EXPECT_EQ(GET_SIZE<py::tuple>(fields),
node.arity,
"Number of fields and entries does not match.");
py::object module_name =
py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__));
if (!module_name.is_none()) [[likely]] {
std::string name = static_cast<std::string>(py::str(module_name));
if (!(name.empty() || name == "__main__" || name == "builtins" ||
name == "__builtins__")) [[likely]] {
sstream << name << ".";
}
}
py::object qualname = py::getattr(type, Py_Get_ID(__qualname__));
sstream << static_cast<std::string>(py::str(qualname)) << "(";
bool first = true;
auto child_iter = agenda.end() - node.arity;
for (ssize_t i = 0; i < node.arity; ++i) {
for (const py::handle& field : fields) {
if (!first) [[likely]] {
sstream << ", ";
}
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
sstream << members[i].name << "=" << *child_iter;
sstream << static_cast<std::string>(py::str(field)) << "=" << *child_iter;
++child_iter;
first = false;
}
Expand Down
12 changes: 12 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import gc
import itertools
import platform
import sys
import time
from collections import OrderedDict, UserDict, defaultdict, deque, namedtuple
Expand All @@ -28,6 +29,13 @@
import optree


# True when `platform.python_implementation()` reports 'PyPy' for the running interpreter.
PYPY = platform.python_implementation() == 'PyPy'
# pytest marker that skips the decorated test whenever PYPY is true.
skipif_pypy = pytest.mark.skipif(
PYPY,
reason='PyPy does not support weakref and refcount correctly',
)


def gc_collect():
for _ in range(3):
gc.collect()
Expand Down Expand Up @@ -114,6 +122,10 @@ class EmptyTuple(NamedTuple):
# time.struct_time(tm_year=*, tm_mon=*, tm_mday=*, tm_hour=*, tm_min=*, tm_sec=*, tm_wday=*, tm_yday=*, tm_isdst=*)
TimeStructTimeType = time.struct_time

# On PyPy only, overwrite the `__module__` attribute of these two types with fixed names.
# NOTE(review): presumably PyPy reports different module names for them than CPython
# does, which would change their printed names in the tests — confirm.
if PYPY:
SysFloatInfoType.__module__ = 'sys'
TimeStructTimeType.__module__ = 'time'


class Vector3D:
def __init__(self, x, y, z):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2916,7 +2916,7 @@ def test_tree_max():
assert optree.tree_max(None, default=0) == 0
assert optree.tree_max(None, none_is_leaf=True) is None
assert optree.tree_max(None, default=0, key=operator.neg) == 0
with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")):
with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"):
assert optree.tree_max(None, default=0, key=operator.neg, none_is_leaf=True) is None


Expand All @@ -2935,7 +2935,7 @@ def test_tree_min():
assert optree.tree_min(None, default=0) == 0
assert optree.tree_min(None, none_is_leaf=True) is None
assert optree.tree_min(None, default=0, key=operator.neg) == 0
with pytest.raises(TypeError, match=re.escape("bad operand type for unary -: 'NoneType'")):
with pytest.raises(TypeError, match=".*operand type for unary .+: 'NoneType'"):
assert optree.tree_min(None, default=0, key=operator.neg, none_is_leaf=True) is None


Expand Down
Loading

0 comments on commit 789f93d

Please sign in to comment.