diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 5ddc141d8fb..9f9f8fb7808 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -1,4 +1,5 @@ """Various helper functions""" +import asyncio import base64 import datetime import functools @@ -11,7 +12,7 @@ from . import hdrs, multidict from .errors import InvalidURL -__all__ = ('BasicAuth', 'FormData', 'parse_mimetype') +__all__ = ('BasicAuth', 'FormData', 'parse_mimetype', 'Timeout') class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])): @@ -460,3 +461,43 @@ def requote_uri(uri): # there may be unquoted '%'s in the URI. We need to make sure they're # properly quoted so they do not cause issues elsewhere. return quote(uri, safe=safe_without_percent) + + +class Timeout: + """Timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> with aiohttp.Timeout(0.001): + >>> async with aiohttp.get('https://github.com') as r: + >>> await r.text() + + + :param timeout: timeout value in seconds + :param loop: asyncio compatible event loop + """ + def __init__(self, timeout, *, loop=None): + self._timeout = timeout + if loop is None: + loop = asyncio.get_event_loop() + self._loop = loop + self._task = None + self._cancelled = False + self._cancel_handler = None + + def __enter__(self): + self._task = asyncio.Task.current_task(loop=self._loop) + self._cancel_handler = self._loop.call_later( + self._timeout, self._cancel_task) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is asyncio.CancelledError and self._cancelled: + self._task = None + raise asyncio.TimeoutError + self._cancel_handler.cancel() + self._task = None + + def _cancel_task(self): + self._cancelled = self._task.cancel() diff --git a/docs/client.rst b/docs/client.rst index 461df90e858..89d6af30228 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -539,6 +539,11 @@ time to wait for a response from a server:: File "", line 1, in asyncio.TimeoutError() +Or wrap your client call in :class:`Timeout` context manager:: + + with aiohttp.Timeout(0.001): + async with aiohttp.get('https://github.com') as r: + await r.text() .. warning:: diff --git a/tests/test_timeout.py b/tests/test_timeout.py new file mode 100644 index 00000000000..426e9a97461 --- /dev/null +++ b/tests/test_timeout.py @@ -0,0 +1,122 @@ +import asyncio +import time + +import pytest +from aiohttp.helpers import Timeout + + +def test_timeout(loop): + canceled_raised = False + + @asyncio.coroutine + def long_running_task(): + try: + yield from asyncio.sleep(10, loop=loop) + except asyncio.CancelledError: + nonlocal canceled_raised + canceled_raised = True + raise + + @asyncio.coroutine + def run(): + with pytest.raises(asyncio.TimeoutError): + with Timeout(0.01, loop=loop) as t: + yield from long_running_task() + assert t._loop is loop + assert canceled_raised, 'CancelledError was not raised' + + loop.run_until_complete(run()) + + +def test_timeout_finish_in_time(loop): + @asyncio.coroutine + def long_running_task(): + yield from asyncio.sleep(0.01, loop=loop) + return 'done' + + @asyncio.coroutine + def run(): + with Timeout(0.1, loop=loop): + resp = yield from long_running_task() + assert resp == 'done' + + loop.run_until_complete(run()) + + +def test_timeout_gloabal_loop(loop): + asyncio.set_event_loop(loop) + + @asyncio.coroutine + def run(): + with Timeout(0.1) as t: + yield from asyncio.sleep(0.01) + assert t._loop is loop + + loop.run_until_complete(run()) + + +def test_timeout_not_relevant_exception(loop): + @asyncio.coroutine + def run(): + with pytest.raises(KeyError): + with Timeout(0.1, loop=loop): + raise KeyError + + loop.run_until_complete(run()) + + +def test_timeout_canceled_error_is_converted_to_timeout(loop): + @asyncio.coroutine + def run(): + with pytest.raises(asyncio.CancelledError): + with Timeout(0.001, loop=loop): + raise asyncio.CancelledError + + loop.run_until_complete(run()) + + +def test_timeout_blocking_loop(loop): + @asyncio.coroutine + def long_running_task(): + time.sleep(0.1) + return 'done' + + @asyncio.coroutine + def run(): + with Timeout(0.01, loop=loop): + result = yield from long_running_task() + assert result == 'done' + + loop.run_until_complete(run()) + + +def test_for_race_conditions(loop): + @asyncio.coroutine + def run(): + fut = asyncio.Future(loop=loop) + loop.call_later(0.1, fut.set_result('done')) + with Timeout(0.2, loop=loop): + resp = yield from fut + assert resp == 'done' + + loop.run_until_complete(run()) + + +def test_timeout_time(loop): + @asyncio.coroutine + def go(): + foo_running = None + + start = loop.time() + with pytest.raises(asyncio.TimeoutError): + with Timeout(0.1, loop=loop): + foo_running = True + try: + yield from asyncio.sleep(0.2, loop=loop) + finally: + foo_running = False + + assert abs(0.1 - (loop.time() - start)) < 0.01 + assert not foo_running + + loop.run_until_complete(go())