From ab22313f98c9ab33a6fb1eed985e63b7fb78df3d Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Mon, 12 Jun 2017 21:02:00 -0700 Subject: [PATCH] Clearer error message if async context manager used synchronously Otherwise it says it does not have `__enter__` which is obvious for advanced pythonista. Still a tiny bit clearer is likely better. I guess this _might_ fool static type checkers, but I'm unaware of any that would flag that. --- trio/_core/tests/test_run.py | 10 ++++++++++ trio/_util.py | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 5f4f232b1f..1aeaaba868 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -89,6 +89,16 @@ async def child(x): await task.wait() assert task.result.unwrap() == 20 +async def test_nursery_warn_use_async_with(): + with pytest.raises(RuntimeError) as excinfo: + on = _core.open_nursery() + with on as nursery: + pass # pragma: no-cover + excinfo.match(r"use 'async with open_nursery\(...\)', not 'with open_nursery\(...\)'") + + # avoid unawaited coro. + async with on: + pass async def test_child_crash_basic(): exc = ValueError("uh oh") diff --git a/trio/_util.py b/trio/_util.py index 0986e02440..69b759e478 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -83,6 +83,7 @@ async def __aiter__(*args, **kwargs): # Copyright © 2001-2017 Python Software Foundation; All Rights Reserved class _AsyncGeneratorContextManager: def __init__(self, func, args, kwds): + self._func_name = func.__name__ self._agen = func(*args, **kwds).__aiter__() async def __aenter__(self): @@ -135,6 +136,13 @@ async def __aexit__(self, type, value, traceback): if sys.exc_info()[1] is not value: raise + def __enter__(self): + raise RuntimeError("use 'async with {func_name}(...)', not 'with {func_name}(...)'".format(func_name=self._func_name)) + + def __exit__(self): + assert False, """Never called, but should be defined""" # pragma: no-cover + + def acontextmanager(func): """Like @contextmanager, but async.""" if not async_generator.isasyncgenfunction(func):