diff --git a/aiohttp/web.py b/aiohttp/web.py index fd37dca23b4..64b33a5ab1d 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -171,6 +171,7 @@ def __init__(self, *, logger=web_logger, loop=None, self._on_pre_signal = PreSignal() self._on_post_signal = PostSignal() self._on_response_prepare = Signal(self) + self._on_startup = Signal(self) self._on_shutdown = Signal(self) self._on_cleanup = Signal(self) @@ -190,6 +191,10 @@ def on_pre_signal(self): def on_post_signal(self): return self._on_post_signal + @property + def on_startup(self): + return self._on_startup + @property def on_shutdown(self): return self._on_shutdown @@ -214,6 +219,14 @@ def make_handler(self, **kwargs): return self._handler_factory( self, self.router, loop=self.loop, **kwargs) + @asyncio.coroutine + def startup(self): + """Causes on_startup signal + + Should be called in the event loop along with the request handler. + """ + yield from self.on_startup.send(self) + @asyncio.coroutine def shutdown(self): """Causes on_shutdown signal @@ -267,9 +280,11 @@ def run_app(app, *, host='0.0.0.0', port=None, loop = app.loop handler = app.make_handler() - srv = loop.run_until_complete(loop.create_server(handler, host, port, - ssl=ssl_context, - backlog=backlog)) + server = loop.create_server(handler, host, port, ssl=ssl_context, + backlog=backlog) + srv, startup_res = loop.run_until_complete(asyncio.gather(server, + app.startup(), + loop=loop)) scheme = 'https' if ssl_context else 'http' print("======== Running on {scheme}://{host}:{port}/ ========\n" diff --git a/aiohttp/worker.py b/aiohttp/worker.py index 607f783daef..7acad99d553 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -34,7 +34,9 @@ def run(self): self._runner = ensure_future(self._run(), loop=self.loop) try: - self.loop.run_until_complete(self._runner) + self.loop.run_until_complete(asyncio.gather(self._runner, + self.wsgi.startup(), + loop=self.loop)) finally: self.loop.close() diff --git a/tests/test_worker.py b/tests/test_worker.py deleted file mode 100644 index 9f05264cfc6..00000000000 --- a/tests/test_worker.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Tests for aiohttp/worker.py""" -import asyncio -import sys -from unittest import mock - -import pytest - -from aiohttp import helpers - -base_worker = pytest.importorskip('aiohttp.worker') -pytest.importorskip('uvloop') - - -class BaseTestWorker: - - def __init__(self): - self.servers = [] - self.exit_code = 0 - self.cfg = mock.Mock() - self.cfg.graceful_timeout = 100 - - -class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): - pass - - -class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): - - def __init__(self): - if sys.version_info < (3, 5) \ - or sys.platform in ('win32', 'cygwin', 'cli'): - raise pytest.skip("uvloop requires Python 3.5 and *nix.") - - super().__init__() - - -@pytest.fixture(params=[AsyncioWorker, UvloopWorker]) -def worker(request): - return request.param() - - -def test_init_process(worker): - with mock.patch('aiohttp.worker.asyncio') as m_asyncio: - try: - worker.init_process() - except TypeError: - pass - - assert m_asyncio.get_event_loop.return_value.close.called - assert m_asyncio.new_event_loop.called - assert m_asyncio.set_event_loop.called - - -def test_run(worker, loop): - worker.loop = loop - worker._run = mock.Mock( - wraps=asyncio.coroutine(lambda: None)) - with pytest.raises(SystemExit): - worker.run() - - assert worker._run.called - assert loop.is_closed() - - -def test_handle_quit(worker): - worker.handle_quit(object(), object()) - assert not worker.alive - assert worker.exit_code == 0 - - -def test_handle_abort(worker): - worker.handle_abort(object(), object()) - assert not worker.alive - assert worker.exit_code == 1 - - -def test_init_signals(worker): - worker.loop = mock.Mock() - worker.init_signals() - assert worker.loop.add_signal_handler.called - - -def test_make_handler(worker): - worker.wsgi = mock.Mock() - worker.loop = mock.Mock() - worker.log = mock.Mock() - worker.cfg = mock.Mock() - - f = worker.make_handler(worker.wsgi) - assert f is worker.wsgi.make_handler.return_value - - -def test__run_ok(worker, loop): - worker.ppid = 1 - worker.alive = True - worker.servers = {} - sock = mock.Mock() - sock.cfg_addr = ('localhost', 8080) - worker.sockets = [sock] - worker.wsgi = mock.Mock() - worker.close = mock.Mock() - worker.close.return_value = helpers.create_future(loop) - worker.close.return_value.set_result(()) - worker.log = mock.Mock() - worker.notify = mock.Mock() - worker.loop = loop - ret = helpers.create_future(loop) - loop.create_server = mock.Mock( - wraps=asyncio.coroutine(lambda *a, **kw: ret)) - ret.set_result(sock) - worker.wsgi.make_handler.return_value.num_connections = 1 - worker.cfg.max_requests = 100 - worker.cfg.is_ssl = True - - ssl_context = mock.Mock() - with mock.patch('ssl.SSLContext', return_value=ssl_context): - with mock.patch('aiohttp.worker.asyncio') as m_asyncio: - m_asyncio.sleep = mock.Mock( - wraps=asyncio.coroutine(lambda *a, **kw: None)) - loop.run_until_complete(worker._run()) - - assert worker.notify.called - assert worker.log.info.called - - args, kwargs = loop.create_server.call_args - assert 'ssl' in kwargs - ctx = kwargs['ssl'] - assert ctx is ssl_context - - -def test__run_exc(worker, loop): - with mock.patch('aiohttp.worker.os') as m_os: - m_os.getpid.return_value = 1 - m_os.getppid.return_value = 1 - - worker.servers = [mock.Mock()] - worker.ppid = 1 - worker.alive = True - worker.sockets = [] - worker.log = mock.Mock() - worker.loop = mock.Mock() - worker.notify = mock.Mock() - worker.cfg.is_ssl = False - - with mock.patch('aiohttp.worker.asyncio.sleep') as m_sleep: - slp = helpers.create_future(loop) - slp.set_exception(KeyboardInterrupt) - m_sleep.return_value = slp - - worker.close = mock.Mock() - worker.close.return_value = helpers.create_future(loop) - worker.close.return_value.set_result(1) - - loop.run_until_complete(worker._run()) - - assert m_sleep.called - assert worker.close.called - - -def test_close(worker, loop): - srv = mock.Mock() - handler = mock.Mock() - worker.servers = {srv: handler} - worker.log = mock.Mock() - worker.loop = loop - app = worker.wsgi = mock.Mock() - app.finish.return_value = helpers.create_future(loop) - app.finish.return_value.set_result(1) - handler.connections = [object()] - handler.finish_connections.return_value = helpers.create_future(loop) - handler.finish_connections.return_value.set_result(1) - - app.shutdown.return_value = helpers.create_future(loop) - app.shutdown.return_value.set_result(None) - - loop.run_until_complete(worker.close()) - app.shutdown.assert_called_with() - app.finish.assert_called_with() - handler.finish_connections.assert_called_with(timeout=95.0) - srv.close.assert_called_with() - assert worker.servers is None - - loop.run_until_complete(worker.close())