Skip to content

Commit

Permalink
remove asyncio from contextlib async tests
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Dec 25, 2024
1 parent d9ed42b commit 0461429
Showing 1 changed file with 62 additions and 33 deletions.
95 changes: 62 additions & 33 deletions Lib/test/test_contextlib_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
import functools
from contextlib import (
asynccontextmanager, AbstractAsyncContextManager,
AsyncExitStack, nullcontext, aclosing, contextmanager)
Expand All @@ -8,14 +8,32 @@

from test.test_contextlib import TestBaseExitStack

support.requires_working_socket(module=True)

def tearDownModule():
asyncio._set_event_loop_policy(None)
def _run_async_fn(async_fn, /, *args, **kwargs):
coro = async_fn(*args, **kwargs)
gen = type(coro).__await__(coro)
try:
gen.send(None)
except StopIteration as e:
return e.value
else:
raise AssertionError("coroutine did not stop")
finally:
gen.close()


class TestAbstractAsyncContextManager(unittest.IsolatedAsyncioTestCase):
def _async_test(async_fn):
"""Decorator to turn an async function into a test case."""
@functools.wraps(async_fn)
def wrapper(*args, **kwargs):
return _run_async_fn(async_fn, *args, **kwargs)

return wrapper


class TestAbstractAsyncContextManager(unittest.TestCase):

@_async_test
async def test_enter(self):
class DefaultEnter(AbstractAsyncContextManager):
async def __aexit__(self, *args):
Expand All @@ -27,6 +45,7 @@ async def __aexit__(self, *args):
async with manager as context:
self.assertIs(manager, context)

@_async_test
async def test_slots(self):
class DefaultAsyncContextManager(AbstractAsyncContextManager):
__slots__ = ()
Expand All @@ -38,6 +57,7 @@ async def __aexit__(self, *args):
manager = DefaultAsyncContextManager()
manager.var = 42

@_async_test
async def test_async_gen_propagates_generator_exit(self):
# A regression test for https://bugs.python.org/issue33786.

Expand Down Expand Up @@ -88,8 +108,9 @@ class NoneAexit(ManagerFromScratch):
self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))


class AsyncContextManagerTestCase(unittest.IsolatedAsyncioTestCase):
class AsyncContextManagerTestCase(unittest.TestCase):

@_async_test
async def test_contextmanager_plain(self):
state = []
@asynccontextmanager
Expand All @@ -103,6 +124,7 @@ async def woohoo():
state.append(x)
self.assertEqual(state, [1, 42, 999])

@_async_test
async def test_contextmanager_finally(self):
state = []
@asynccontextmanager
Expand All @@ -120,6 +142,7 @@ async def woohoo():
raise ZeroDivisionError()
self.assertEqual(state, [1, 42, 999])

@_async_test
async def test_contextmanager_traceback(self):
@asynccontextmanager
async def f():
Expand Down Expand Up @@ -175,6 +198,7 @@ class StopAsyncIterationSubclass(StopAsyncIteration):
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise stop_exc')

@_async_test
async def test_contextmanager_no_reraise(self):
@asynccontextmanager
async def whee():
Expand All @@ -184,6 +208,7 @@ async def whee():
# Calling __aexit__ should not result in an exception
self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

@_async_test
async def test_contextmanager_trap_yield_after_throw(self):
@asynccontextmanager
async def whoo():
Expand All @@ -199,6 +224,7 @@ async def whoo():
# The "gen" attribute is an implementation detail.
self.assertFalse(ctx.gen.ag_suspended)

@_async_test
async def test_contextmanager_trap_no_yield(self):
@asynccontextmanager
async def whoo():
Expand All @@ -208,6 +234,7 @@ async def whoo():
with self.assertRaises(RuntimeError):
await ctx.__aenter__()

@_async_test
async def test_contextmanager_trap_second_yield(self):
@asynccontextmanager
async def whoo():
Expand All @@ -221,6 +248,7 @@ async def whoo():
# The "gen" attribute is an implementation detail.
self.assertFalse(ctx.gen.ag_suspended)

@_async_test
async def test_contextmanager_non_normalised(self):
@asynccontextmanager
async def whoo():
Expand All @@ -234,6 +262,7 @@ async def whoo():
with self.assertRaises(SyntaxError):
await ctx.__aexit__(RuntimeError, None, None)

@_async_test
async def test_contextmanager_except(self):
state = []
@asynccontextmanager
Expand All @@ -251,6 +280,7 @@ async def woohoo():
raise ZeroDivisionError(999)
self.assertEqual(state, [1, 42, 999])

@_async_test
async def test_contextmanager_except_stopiter(self):
@asynccontextmanager
async def woohoo():
Expand All @@ -277,6 +307,7 @@ class StopAsyncIterationSubclass(StopAsyncIteration):
else:
self.fail(f'{stop_exc} was suppressed')

@_async_test
async def test_contextmanager_wrap_runtimeerror(self):
@asynccontextmanager
async def woohoo():
Expand Down Expand Up @@ -321,12 +352,14 @@ def test_contextmanager_doc_attrib(self):
self.assertEqual(baz.__doc__, "Whee!")

@support.requires_docstrings
@_async_test
async def test_instance_docstring_given_cm_docstring(self):
baz = self._create_contextmanager_attribs()(None)
self.assertEqual(baz.__doc__, "Whee!")
async with baz:
pass # suppress warning

@_async_test
async def test_keywords(self):
# Ensure no keyword arguments are inhibited
@asynccontextmanager
Expand All @@ -335,6 +368,7 @@ async def woohoo(self, func, args, kwds):
async with woohoo(self=11, func=22, args=33, kwds=44) as target:
self.assertEqual(target, (11, 22, 33, 44))

@_async_test
async def test_recursive(self):
depth = 0
ncols = 0
Expand All @@ -361,6 +395,7 @@ async def recursive():
self.assertEqual(ncols, 10)
self.assertEqual(depth, 0)

@_async_test
async def test_decorator(self):
entered = False

Expand All @@ -379,6 +414,7 @@ async def test():
await test()
self.assertFalse(entered)

@_async_test
async def test_decorator_with_exception(self):
entered = False

Expand All @@ -401,6 +437,7 @@ async def test():
await test()
self.assertFalse(entered)

@_async_test
async def test_decorating_method(self):

@asynccontextmanager
Expand Down Expand Up @@ -435,14 +472,15 @@ async def method(self, a, b, c=None):
self.assertEqual(test.b, 2)


class AclosingTestCase(unittest.IsolatedAsyncioTestCase):
class AclosingTestCase(unittest.TestCase):

@support.requires_docstrings
def test_instance_docs(self):
cm_docstring = aclosing.__doc__
obj = aclosing(None)
self.assertEqual(obj.__doc__, cm_docstring)

@_async_test
async def test_aclosing(self):
state = []
class C:
Expand All @@ -454,6 +492,7 @@ async def aclose(self):
self.assertEqual(x, y)
self.assertEqual(state, [1])

@_async_test
async def test_aclosing_error(self):
state = []
class C:
Expand All @@ -467,6 +506,7 @@ async def aclose(self):
1 / 0
self.assertEqual(state, [1])

@_async_test
async def test_aclosing_bpo41229(self):
state = []

Expand All @@ -492,45 +532,27 @@ async def agenfunc():
self.assertEqual(state, [1])


class TestAsyncExitStack(TestBaseExitStack, unittest.IsolatedAsyncioTestCase):
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
class SyncAsyncExitStack(AsyncExitStack):
@staticmethod
def run_coroutine(coro):
loop = asyncio.new_event_loop()
t = loop.create_task(coro)
t.add_done_callback(lambda f: loop.stop())
loop.run_forever()

exc = t.exception()
if not exc:
return t.result()
else:
context = exc.__context__

try:
raise exc
except:
exc.__context__ = context
raise exc

def close(self):
return self.run_coroutine(self.aclose())
return _run_async_fn(self.aclose)

def __enter__(self):
return self.run_coroutine(self.__aenter__())
return _run_async_fn(self.__aenter__)

def __exit__(self, *exc_details):
return self.run_coroutine(self.__aexit__(*exc_details))
return _run_async_fn(self.__aexit__, *exc_details)

exit_stack = SyncAsyncExitStack
callback_error_internal_frames = [
('__exit__', 'return self.run_coroutine(self.__aexit__(*exc_details))'),
('run_coroutine', 'raise exc'),
('run_coroutine', 'raise exc'),
('__exit__', 'return _run_async_fn(self.__aexit__, *exc_details)'),
('_run_async_fn', 'gen.send(None)'),
('__aexit__', 'raise exc'),
('__aexit__', 'cb_suppress = cb(*exc_details)'),
]

@_async_test
async def test_async_callback(self):
expected = [
((), {}),
Expand Down Expand Up @@ -573,6 +595,7 @@ async def _exit(*args, **kwds):
stack.push_async_callback(callback=_exit, arg=3)
self.assertEqual(result, [])

@_async_test
async def test_async_push(self):
exc_raised = ZeroDivisionError
async def _expect_exc(exc_type, exc, exc_tb):
Expand Down Expand Up @@ -608,6 +631,7 @@ async def __aexit__(self, *exc_details):
self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
1/0

@_async_test
async def test_enter_async_context(self):
class TestCM(object):
async def __aenter__(self):
Expand All @@ -629,6 +653,7 @@ async def _exit():

self.assertEqual(result, [1, 2, 3, 4])

@_async_test
async def test_enter_async_context_errors(self):
class LacksEnterAndExit:
pass
Expand All @@ -648,6 +673,7 @@ async def __aenter__(self):
await stack.enter_async_context(LacksExit())
self.assertFalse(stack._exit_callbacks)

@_async_test
async def test_async_exit_exception_chaining(self):
# Ensure exception chaining matches the reference behaviour
async def raise_exc(exc):
Expand Down Expand Up @@ -679,6 +705,7 @@ async def suppress_exc(*exc_details):
self.assertIsInstance(inner_exc, ValueError)
self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

@_async_test
async def test_async_exit_exception_explicit_none_context(self):
# Ensure AsyncExitStack chaining matches actual nested `with` statements
# regarding explicit __context__ = None.
Expand Down Expand Up @@ -713,6 +740,7 @@ async def my_cm_with_exit_stack():
else:
self.fail("Expected IndexError, but no exception was raised")

@_async_test
async def test_instance_bypass_async(self):
class Example(object): pass
cm = Example()
Expand All @@ -725,7 +753,8 @@ class Example(object): pass
self.assertIs(stack._exit_callbacks[-1][1], cm)


class TestAsyncNullcontext(unittest.IsolatedAsyncioTestCase):
class TestAsyncNullcontext(unittest.TestCase):
@_async_test
async def test_async_nullcontext(self):
class C:
pass
Expand Down

0 comments on commit 0461429

Please sign in to comment.