diff --git a/async_generator/_impl.py b/async_generator/_impl.py index cf2c783..e8dd021 100644 --- a/async_generator/_impl.py +++ b/async_generator/_impl.py @@ -8,87 +8,85 @@ import collections.abc -class YieldWrapper: - def __init__(self, payload): - self.payload = payload - - -def _wrap(value): - return YieldWrapper(value) - - -def _is_wrapped(box): - return isinstance(box, YieldWrapper) - - -def _unwrap(box): - return box.payload - - # This is the magic code that lets you use yield_ and yield_from_ with native # generators. # # The old version worked great on Linux and MacOS, but not on Windows, because -# it depended on _PyAsyncGenValueWrapperNew. The new version segfaults -# everywhere, and I'm not sure why -- probably my lack of understanding -# of ctypes and refcounts. -# -# There are also some commented out tests that should be re-enabled if this is -# fixed: -# -# if sys.version_info >= (3, 6): -# # Use the same box type that the interpreter uses internally. This allows -# # yield_ and (more importantly!) yield_from_ to work in built-in -# # generators. -# import ctypes # mua ha ha. -# -# # We used to call _PyAsyncGenValueWrapperNew to create and set up new -# # wrapper objects, but that symbol isn't available on Windows: -# # -# # https://github.com/python-trio/async_generator/issues/5 -# # -# # Fortunately, the type object is available, but it means we have to do -# # this the hard way. -# -# # We don't actually need to access this, but we need to make a ctypes -# # structure so we can call addressof. -# class _ctypes_PyTypeObject(ctypes.Structure): -# pass -# _PyAsyncGenWrappedValue_Type_ptr = ctypes.addressof( -# _ctypes_PyTypeObject.in_dll( -# ctypes.pythonapi, "_PyAsyncGenWrappedValue_Type")) -# _PyObject_GC_New = ctypes.pythonapi._PyObject_GC_New -# _PyObject_GC_New.restype = ctypes.py_object -# _PyObject_GC_New.argtypes = (ctypes.c_void_p,) -# -# _Py_IncRef = ctypes.pythonapi.Py_IncRef -# _Py_IncRef.restype = None -# _Py_IncRef.argtypes = (ctypes.py_object,) -# -# class _ctypes_PyAsyncGenWrappedValue(ctypes.Structure): -# _fields_ = [ -# ('PyObject_HEAD', ctypes.c_byte * object().__sizeof__()), -# ('agw_val', ctypes.py_object), -# ] -# def _wrap(value): -# box = _PyObject_GC_New(_PyAsyncGenWrappedValue_Type_ptr) -# raw = ctypes.cast(ctypes.c_void_p(id(box)), -# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue)) -# raw.contents.agw_val = value -# _Py_IncRef(value) -# return box -# -# def _unwrap(box): -# assert _is_wrapped(box) -# raw = ctypes.cast(ctypes.c_void_p(id(box)), -# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue)) -# value = raw.contents.agw_val -# _Py_IncRef(value) -# return value -# -# _PyAsyncGenWrappedValue_Type = type(_wrap(1)) -# def _is_wrapped(box): -# return isinstance(box, _PyAsyncGenWrappedValue_Type) +# it depended on _PyAsyncGenValueWrapperNew. The new version gets around this +# by inlining most of the code of _PyAsyncGenValueWrapperNew (skipping the freelist +# maintenance, but that's OK since it's just an optimization). + +if sys.version_info >= (3, 6): + # Use the same box type that the interpreter uses internally. This allows + # yield_ and (more importantly!) yield_from_ to work in built-in + # generators. + import ctypes # mua ha ha. + + # We used to call _PyAsyncGenValueWrapperNew to create and set up new + # wrapper objects, but that symbol isn't available on Windows: + # + # https://github.com/python-trio/async_generator/issues/5 + # + # Fortunately, the type object is available, but it means we have to do + # this the hard way. + + # We don't actually need to access this, but we need to make a ctypes + # structure so we can call addressof. + class _ctypes_PyTypeObject(ctypes.Structure): + pass + _PyAsyncGenWrappedValue_Type_ptr = ctypes.addressof( + _ctypes_PyTypeObject.in_dll( + ctypes.pythonapi, "_PyAsyncGenWrappedValue_Type")) + _PyObject_GC_New = ctypes.pythonapi._PyObject_GC_New + _PyObject_GC_New.restype = ctypes.py_object + _PyObject_GC_New.argtypes = (ctypes.c_void_p,) + + _PyObject_GC_Track = ctypes.pythonapi.PyObject_GC_Track + _PyObject_GC_Track.restype = None + _PyObject_GC_Track.argtypes = (ctypes.py_object,) + + _Py_IncRef = ctypes.pythonapi.Py_IncRef + _Py_IncRef.restype = None + _Py_IncRef.argtypes = (ctypes.py_object,) + + class _ctypes_PyAsyncGenWrappedValue(ctypes.Structure): + _fields_ = [ + ('PyObject_HEAD', ctypes.c_byte * object().__sizeof__()), + ('agw_val', ctypes.py_object), + ] + def _wrap(value): + box = _PyObject_GC_New(_PyAsyncGenWrappedValue_Type_ptr) + raw = ctypes.cast(ctypes.c_void_p(id(box)), + ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue)) + raw.contents.agw_val = value + _Py_IncRef(value) + _PyObject_GC_Track(box) + return box + + def _unwrap(box): + assert _is_wrapped(box) + raw = ctypes.cast(ctypes.c_void_p(id(box)), + ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue)) + value = raw.contents.agw_val + return value + + _PyAsyncGenWrappedValue_Type = type(_wrap(1)) + def _is_wrapped(box): + return isinstance(box, _PyAsyncGenWrappedValue_Type) + +else: + class YieldWrapper: + def __init__(self, payload): + self.payload = payload + + def _wrap(value): + return YieldWrapper(value) + + def _is_wrapped(box): + return isinstance(box, YieldWrapper) + + def _unwrap(box): + return box.payload # The magic @coroutine decorator is how you write the bottom level of diff --git a/async_generator/_tests/test_async_generator.py b/async_generator/_tests/test_async_generator.py index eaf7c92..fc2024c 100644 --- a/async_generator/_tests/test_async_generator.py +++ b/async_generator/_tests/test_async_generator.py @@ -352,13 +352,12 @@ async def native_async_range(count): for i in range(count): yield i -# XX uncomment if/when we re-enable the ctypes hacks: -# async def native_async_range_twice(count): -# # make sure yield_from_ works inside a native async generator -# await yield_from_(async_range(count)) -# yield None -# # make sure we can yield_from_ a native async generator -# await yield_from_(native_async_range(count)) +async def native_async_range_twice(count): + # make sure yield_from_ works inside a native async generator + await yield_from_(async_range(count)) + yield None + # make sure we can yield_from_ a native async generator + await yield_from_(native_async_range(count)) """ ) @@ -382,11 +381,10 @@ async def yield_from_native(): assert await collect(yield_from_native()) == [0, 1, 2] - # XX uncomment if/when we re-enable the ctypes hacks: - # if sys.version_info >= (3, 6): - # assert await collect(native_async_range_twice(3)) == [ - # 0, 1, 2, None, 0, 1, 2, - # ] + if sys.version_info >= (3, 6): + assert await collect(native_async_range_twice(3)) == [ + 0, 1, 2, None, 0, 1, 2, + ] @async_generator @@ -706,33 +704,44 @@ async def f(): @pytest.mark.skipif(not hasattr(sys, "getrefcount"), reason="CPython only") def test_refcnt(): - x = object() - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - base_count = sys.getrefcount(x) - l = [_impl._wrap(x) for _ in range(100)] - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - assert sys.getrefcount(x) >= base_count + 100 - l2 = [_impl._unwrap(box) for box in l] - assert sys.getrefcount(x) >= base_count + 200 - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - del l - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - del l2 - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - print(sys.getrefcount(x)) - assert sys.getrefcount(x) == base_count - print(sys.getrefcount(x)) + def test(collection_point): + import sys, gc + x = object() + base_count = sys.getrefcount(x) + l = [_impl._wrap(x) for _ in range(100)] + box = l[0] + assert sys.getrefcount(x) >= base_count + 100 + assert sys.getrefcount(box) == 3 + if collection_point == 1: + gc.collect() + assert sys.getrefcount(x) == base_count + 100 + l2 = [_impl._unwrap(box) for box in l] + assert sys.getrefcount(x) >= base_count + 200 + assert sys.getrefcount(box) == 3 + if collection_point == 2: + gc.collect() + assert sys.getrefcount(x) == base_count + 200 + assert all(id(unwrapped) == id(x) for unwrapped in l2) + del l + assert sys.getrefcount(box) == 2 + del l2 + assert sys.getrefcount(box) == 2 + del box + if collection_point == 3: + gc.collect() + assert sys.getrefcount(x) == base_count + + for collection_point in (0, 1, 2, 3): + for _ in range(20): + test(collection_point) + + for _ in range(20): + for collection_point in (0, 1, 2, 3): + test(collection_point) + + import random + for _ in range(100): + test(random.randrange(4)) ################################################################