Skip to content

Commit

Permalink
Fix interoperability with native async generators
Browse files Browse the repository at this point in the history
The missing pieces were:
- don't incref when unwrapping (my unconfirmed belief is that ctypes increfs
  automatically when you extract a field of py_object type)
- call PyObject_GC_Track when wrapping

The behavior still differs from the old version, which called
_PyAsyncGenValueWrapperNew directly (the old approach that doesn't work on
Windows due to symbol visibility). With that simpler call, refcounts track the
obvious references without needing to call gc.collect(). Here, it's possible
for the refcount to seem over-inflated and then fix itself when a GC occurs.
I haven't looked into this enough to be sure, but I suspect the difference
boils down to our inability to allocate from the freelist that the C allocation
and deallocation functions use.
  • Loading branch information
oremanj committed Feb 11, 2018
1 parent 17521f8 commit 95b5d14
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 116 deletions.
152 changes: 75 additions & 77 deletions async_generator/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,87 +8,85 @@
import collections.abc


class YieldWrapper:
    """Plain-Python box that carries a single value in ``payload``.

    Instances are created by ``_wrap``, recognized by ``_is_wrapped`` and
    opened again by ``_unwrap``.
    """

    def __init__(self, payload):
        # The boxed value; stored as-is, no copying or validation.
        self.payload = payload


def _wrap(value):
    """Return a new ``YieldWrapper`` box whose payload is *value*."""
    return YieldWrapper(value)


def _is_wrapped(box):
    """Return True if *box* is a ``YieldWrapper`` instance, else False."""
    return isinstance(box, YieldWrapper)


def _unwrap(box):
    """Return the value stored in *box* (its ``payload`` attribute).

    Raises AttributeError if *box* has no ``payload`` attribute.
    """
    return box.payload


# This is the magic code that lets you use yield_ and yield_from_ with native
# generators.
#
# The old version worked great on Linux and MacOS, but not on Windows, because
# it depended on _PyAsyncGenValueWrapperNew. The new version segfaults
# everywhere, and I'm not sure why -- probably my lack of understanding
# of ctypes and refcounts.
#
# There are also some commented out tests that should be re-enabled if this is
# fixed:
#
# if sys.version_info >= (3, 6):
# # Use the same box type that the interpreter uses internally. This allows
# # yield_ and (more importantly!) yield_from_ to work in built-in
# # generators.
# import ctypes # mua ha ha.
#
# # We used to call _PyAsyncGenValueWrapperNew to create and set up new
# # wrapper objects, but that symbol isn't available on Windows:
# #
# # https://github.com/python-trio/async_generator/issues/5
# #
# # Fortunately, the type object is available, but it means we have to do
# # this the hard way.
#
# # We don't actually need to access this, but we need to make a ctypes
# # structure so we can call addressof.
# class _ctypes_PyTypeObject(ctypes.Structure):
# pass
# _PyAsyncGenWrappedValue_Type_ptr = ctypes.addressof(
# _ctypes_PyTypeObject.in_dll(
# ctypes.pythonapi, "_PyAsyncGenWrappedValue_Type"))
# _PyObject_GC_New = ctypes.pythonapi._PyObject_GC_New
# _PyObject_GC_New.restype = ctypes.py_object
# _PyObject_GC_New.argtypes = (ctypes.c_void_p,)
#
# _Py_IncRef = ctypes.pythonapi.Py_IncRef
# _Py_IncRef.restype = None
# _Py_IncRef.argtypes = (ctypes.py_object,)
#
# class _ctypes_PyAsyncGenWrappedValue(ctypes.Structure):
# _fields_ = [
# ('PyObject_HEAD', ctypes.c_byte * object().__sizeof__()),
# ('agw_val', ctypes.py_object),
# ]
# def _wrap(value):
# box = _PyObject_GC_New(_PyAsyncGenWrappedValue_Type_ptr)
# raw = ctypes.cast(ctypes.c_void_p(id(box)),
# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue))
# raw.contents.agw_val = value
# _Py_IncRef(value)
# return box
#
# def _unwrap(box):
# assert _is_wrapped(box)
# raw = ctypes.cast(ctypes.c_void_p(id(box)),
# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue))
# value = raw.contents.agw_val
# _Py_IncRef(value)
# return value
#
# _PyAsyncGenWrappedValue_Type = type(_wrap(1))
# def _is_wrapped(box):
# return isinstance(box, _PyAsyncGenWrappedValue_Type)
# The old version depended on _PyAsyncGenValueWrapperNew, a symbol that isn't
# available on Windows. The new version gets around this by inlining most of
# the code of _PyAsyncGenValueWrapperNew (skipping the freelist maintenance,
# but that's OK since it's just an optimization).

# Decide once, at import time, which box implementation to use.  The native
# box is only usable when ctypes can reach the interpreter's own wrapper type;
# otherwise we fall back to the plain-Python YieldWrapper instead of failing
# to import.
_native_box_available = False

if sys.version_info >= (3, 6):
    # Use the same box type that the interpreter uses internally. This allows
    # yield_ and (more importantly!) yield_from_ to work in built-in
    # generators.
    try:
        import ctypes  # mua ha ha.

        # We used to call _PyAsyncGenValueWrapperNew to create and set up new
        # wrapper objects, but that symbol isn't available on Windows:
        #
        #   https://github.com/python-trio/async_generator/issues/5
        #
        # Fortunately, the type object is available, but it means we have to
        # do this the hard way.

        # We don't actually need to access this, but we need to make a ctypes
        # structure so we can call addressof.
        class _ctypes_PyTypeObject(ctypes.Structure):
            pass

        # in_dll raises ValueError when the symbol is not exported.
        _PyAsyncGenWrappedValue_Type_ptr = ctypes.addressof(
            _ctypes_PyTypeObject.in_dll(
                ctypes.pythonapi, "_PyAsyncGenWrappedValue_Type"
            )
        )

        _PyObject_GC_New = ctypes.pythonapi._PyObject_GC_New
        _PyObject_GC_New.restype = ctypes.py_object
        _PyObject_GC_New.argtypes = (ctypes.c_void_p,)

        _PyObject_GC_Track = ctypes.pythonapi.PyObject_GC_Track
        _PyObject_GC_Track.restype = None
        _PyObject_GC_Track.argtypes = (ctypes.py_object,)

        _Py_IncRef = ctypes.pythonapi.Py_IncRef
        _Py_IncRef.restype = None
        _Py_IncRef.argtypes = (ctypes.py_object,)

        # Mirror of the C wrapper object: an opaque object header (sized like
        # a bare ``object()``) followed by the pointer to the wrapped value.
        class _ctypes_PyAsyncGenWrappedValue(ctypes.Structure):
            _fields_ = [
                ('PyObject_HEAD', ctypes.c_byte * object().__sizeof__()),
                ('agw_val', ctypes.py_object),
            ]
    except (ImportError, AttributeError, ValueError):
        # ctypes itself, ctypes.pythonapi, or one of the required symbols is
        # missing on this interpreter; use the pure-Python box below.
        pass
    else:
        _native_box_available = True

if _native_box_available:

    def _wrap(value):
        """Return a native async-generator wrapper object holding *value*."""
        box = _PyObject_GC_New(_PyAsyncGenWrappedValue_Type_ptr)
        raw = ctypes.cast(
            ctypes.c_void_p(id(box)),
            ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue),
        )
        raw.contents.agw_val = value
        # The box now refers to value, so take a reference on its behalf.
        _Py_IncRef(value)
        # The object was created without going through the normal C
        # constructor, so it has to be registered with the GC by hand.
        _PyObject_GC_Track(box)
        return box

    def _unwrap(box):
        """Return the value held by a box created by ``_wrap``.

        Raises TypeError if *box* is not a native wrapper object.  This is an
        explicit check rather than an ``assert`` because the code below reads
        raw memory at ``id(box)`` and must never run on another object type,
        even under ``python -O``.
        """
        if not _is_wrapped(box):
            raise TypeError(
                "expected a wrapped value, got {!r}".format(type(box))
            )
        raw = ctypes.cast(
            ctypes.c_void_p(id(box)),
            ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue),
        )
        # No explicit incref here: reading a py_object field through ctypes
        # appears to hand back a new reference already (unconfirmed; see the
        # commit that introduced this code).
        return raw.contents.agw_val

    # Obtain the wrapper type as a real Python type object by boxing a dummy.
    _PyAsyncGenWrappedValue_Type = type(_wrap(1))

    def _is_wrapped(box):
        """Return True if *box* is a native wrapper object."""
        return isinstance(box, _PyAsyncGenWrappedValue_Type)

else:

    class YieldWrapper:
        """Pure-Python box carrying a single value in ``payload``."""

        def __init__(self, payload):
            self.payload = payload

    def _wrap(value):
        """Return a ``YieldWrapper`` holding *value*."""
        return YieldWrapper(value)

    def _is_wrapped(box):
        """Return True if *box* is a ``YieldWrapper`` instance."""
        return isinstance(box, YieldWrapper)

    def _unwrap(box):
        """Return the value stored in *box*."""
        return box.payload


# The magic @coroutine decorator is how you write the bottom level of
Expand Down
87 changes: 48 additions & 39 deletions async_generator/_tests/test_async_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,12 @@ async def native_async_range(count):
for i in range(count):
yield i
# XX uncomment if/when we re-enable the ctypes hacks:
# async def native_async_range_twice(count):
# # make sure yield_from_ works inside a native async generator
# await yield_from_(async_range(count))
# yield None
# # make sure we can yield_from_ a native async generator
# await yield_from_(native_async_range(count))
async def native_async_range_twice(count):
# make sure yield_from_ works inside a native async generator
await yield_from_(async_range(count))
yield None
# make sure we can yield_from_ a native async generator
await yield_from_(native_async_range(count))
"""
)

Expand All @@ -382,11 +381,10 @@ async def yield_from_native():

assert await collect(yield_from_native()) == [0, 1, 2]

# XX uncomment if/when we re-enable the ctypes hacks:
# if sys.version_info >= (3, 6):
# assert await collect(native_async_range_twice(3)) == [
# 0, 1, 2, None, 0, 1, 2,
# ]
if sys.version_info >= (3, 6):
assert await collect(native_async_range_twice(3)) == [
0, 1, 2, None, 0, 1, 2,
]


@async_generator
Expand Down Expand Up @@ -706,33 +704,44 @@ async def f():

@pytest.mark.skipif(not hasattr(sys, "getrefcount"), reason="CPython only")
def test_refcnt():
    """Check that wrapping/unwrapping keeps reference counts balanced.

    The scenario is run with a ``gc.collect()`` inserted at different points
    (or not at all) because intermediate refcounts may look inflated until a
    collection happens; only lower bounds are asserted before a collection.
    """
    import gc
    import random

    def test(collection_point):
        # collection_point selects where gc.collect() is called:
        # 0 = never, 1 = after wrapping, 2 = after unwrapping,
        # 3 = after every box has been dropped.
        x = object()
        base_count = sys.getrefcount(x)
        l = [_impl._wrap(x) for _ in range(100)]
        box = l[0]
        assert sys.getrefcount(x) >= base_count + 100
        # References to box: the list, the local name, getrefcount's argument.
        assert sys.getrefcount(box) == 3
        if collection_point == 1:
            gc.collect()
            assert sys.getrefcount(x) == base_count + 100
        l2 = [_impl._unwrap(b) for b in l]
        assert sys.getrefcount(x) >= base_count + 200
        assert sys.getrefcount(box) == 3
        if collection_point == 2:
            gc.collect()
            assert sys.getrefcount(x) == base_count + 200
        assert all(id(unwrapped) == id(x) for unwrapped in l2)
        del l
        assert sys.getrefcount(box) == 2
        del l2
        assert sys.getrefcount(box) == 2
        del box
        if collection_point == 3:
            gc.collect()
        assert sys.getrefcount(x) == base_count

    # Same collection point many times in a row ...
    for collection_point in (0, 1, 2, 3):
        for _ in range(20):
            test(collection_point)

    # ... cycling through the collection points ...
    for _ in range(20):
        for collection_point in (0, 1, 2, 3):
            test(collection_point)

    # ... and in random order.
    for _ in range(100):
        test(random.randrange(4))


################################################################
Expand Down

0 comments on commit 95b5d14

Please sign in to comment.