diff --git a/CHANGES/3333.removal b/CHANGES/3333.removal new file mode 100644 index 00000000000..0e8b5b2f3a8 --- /dev/null +++ b/CHANGES/3333.removal @@ -0,0 +1 @@ +Replace asyncio.get_event_loop() with get_running_loop() for encourage creation of aiohttp public objects inside a coroutine diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index d3db40b4a57..a84a65abef4 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -230,6 +230,7 @@ Yannick PĂ©roux Ye Cao Yegor Roganov Young-Ho Cha +Yunhao Zhang Yuriy Shatrov Yury Selivanov Yusuke Tsutsumi diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index b93d7ce2250..d1a60d46ebb 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,6 +1,7 @@ import asyncio from typing import Optional, cast +from .helpers import get_running_loop from .log import internal_logger @@ -9,10 +10,7 @@ class BaseProtocol(asyncio.Protocol): '_connection_lost', 'transport') def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None: - if loop is None: - self._loop = asyncio.get_event_loop() - else: - self._loop = loop + self._loop = get_running_loop(loop) self._paused = False self._drain_waiter = None # type: Optional[asyncio.Future[None]] self._connection_lost = False diff --git a/aiohttp/client.py b/aiohttp/client.py index 540faad6b1e..5dd2d166134 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -29,7 +29,8 @@ from .connector import BaseConnector, TCPConnector from .cookiejar import CookieJar from .helpers import (DEBUG, PY_36, CeilTimeout, TimeoutHandle, - proxies_from_env, sentinel, strip_auth_from_url) + get_running_loop, proxies_from_env, sentinel, + strip_auth_from_url) from .http import WS_KEY, WebSocketReader, WebSocketWriter from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue @@ -107,7 +108,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None, loop = connector._loop else: implicit_loop = True - loop = asyncio.get_event_loop() + loop = get_running_loop() if connector is None: connector = TCPConnector(loop=loop) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 67ba098788c..31f4e239347 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -19,7 +19,8 @@ ClientResponseError, ContentTypeError, InvalidURL, ServerFingerprintMismatch) from .formdata import FormData -from .helpers import PY_36, HeadersMixin, TimerNoop, noop, reify, set_result +from .helpers import (PY_36, HeadersMixin, TimerNoop, get_running_loop, noop, + reify, set_result) from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter from .log import client_logger from .streams import StreamReader # noqa @@ -196,8 +197,7 @@ def __init__(self, method, url, *, proxy_headers=None, traces=None): - if loop is None: - loop = asyncio.get_event_loop() + loop = get_running_loop(loop) assert isinstance(url, URL), url assert isinstance(proxy, (URL, type(None))), proxy diff --git a/aiohttp/connector.py b/aiohttp/connector.py index aa4dfd7a467..a9909b2eacf 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -22,7 +22,8 @@ ssl_errors) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params -from .helpers import PY_36, CeilTimeout, is_ip_address, noop, sentinel +from .helpers import (PY_36, CeilTimeout, get_running_loop, is_ip_address, + noop, sentinel) from .locks import EventResultOrError from .resolver import DefaultResolver @@ -170,8 +171,7 @@ def __init__(self, *, keepalive_timeout=sentinel, if keepalive_timeout is sentinel: keepalive_timeout = 15.0 - if loop is None: - loop = asyncio.get_event_loop() + loop = get_running_loop(loop) self._closed = False if loop.get_debug(): diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 995b3af0732..315a22c1ed5 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -13,6 +13,7 @@ import re import sys import time +import warnings import weakref from collections import namedtuple from contextlib import suppress @@ -230,6 +231,15 @@ def current_task(loop: Optional[asyncio.AbstractEventLoop]=None) -> asyncio.Task return asyncio.Task.current_task(loop=loop) # type: ignore +def get_running_loop(loop: Optional[asyncio.AbstractEventLoop] = None) -> asyncio.AbstractEventLoop: # type: ignore # noqa + if loop is None: + loop = asyncio.get_event_loop() + if loop.is_running(): + warnings.warn("The object should be created from async function", + DeprecationWarning, stacklevel=3) + return loop + + def isasyncgenfunction(obj: Any) -> bool: func = getattr(inspect, 'isasyncgenfunction', None) if func is not None: diff --git a/aiohttp/resolver.py b/aiohttp/resolver.py index d39bc9fcbbd..50d6cb6fee7 100644 --- a/aiohttp/resolver.py +++ b/aiohttp/resolver.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional from .abc import AbstractResolver +from .helpers import get_running_loop __all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver') @@ -22,8 +23,7 @@ class ThreadedResolver(AbstractResolver): """ def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None: - if loop is None: - loop = asyncio.get_event_loop() + loop = get_running_loop(loop) self._loop = loop async def resolve(self, host: str, port: int=0, @@ -50,8 +50,7 @@ class AsyncResolver(AbstractResolver): def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None, *args: Any, **kwargs: Any) -> None: - if loop is None: - loop = asyncio.get_event_loop() + loop = get_running_loop(loop) if aiodns is None: raise RuntimeError("Resolver requires aiodns library") diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 7cdb97abc61..24c38d62e47 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -4,7 +4,8 @@ from typing import Awaitable, Callable, Optional, Tuple from .base_protocol import BaseProtocol -from .helpers import BaseTimerContext, set_exception, set_result +from .helpers import (BaseTimerContext, get_running_loop, set_exception, + set_result) from .log import internal_logger @@ -111,9 +112,7 @@ def __init__(self, protocol: BaseProtocol, self._protocol = protocol self._low_water = limit self._high_water = limit * 2 - if loop is None: - loop = asyncio.get_event_loop() - self._loop = loop + self._loop = get_running_loop(loop) self._size = 0 self._cursor = 0 self._http_chunk_splits = None # type: Optional[List[int]] diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index 2160a161509..e0a08077ce9 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -9,7 +9,7 @@ from . import hdrs from .abc import AbstractAccessLogger, AbstractMatchInfo, AbstractRouter from .frozenlist import FrozenList -from .helpers import DEBUG, AccessLogger +from .helpers import DEBUG, AccessLogger, get_running_loop from .log import web_logger from .signals import Signal from .web_middlewares import _fix_request_current_app @@ -150,8 +150,7 @@ def loop(self): return self._loop def _set_loop(self, loop): - if loop is None: - loop = asyncio.get_event_loop() + loop = get_running_loop(loop) if self._loop is not None and self._loop is not loop: raise RuntimeError( "web.Application instance initialized with different loop") diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 18ee2e8adb9..8ff3d7ba959 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -1,10 +1,10 @@ -import asyncio import signal import socket from abc import ABC, abstractmethod from yarl import URL +from .helpers import get_running_loop from .web_app import Application @@ -81,7 +81,7 @@ def name(self): async def start(self): await super().start() - loop = asyncio.get_event_loop() + loop = get_running_loop() self._server = await loop.create_server( self._runner.server, self._host, self._port, ssl=self._ssl_context, backlog=self._backlog, @@ -106,7 +106,7 @@ def name(self): async def start(self): await super().start() - loop = asyncio.get_event_loop() + loop = get_running_loop() self._server = await loop.create_unix_server( self._runner.server, self._path, ssl=self._ssl_context, backlog=self._backlog) @@ -135,7 +135,7 @@ def name(self): async def start(self): await super().start() - loop = asyncio.get_event_loop() + loop = get_running_loop() self._server = await loop.create_server( self._runner.server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog) @@ -165,7 +165,7 @@ def sites(self): return set(self._sites) async def setup(self): - loop = asyncio.get_event_loop() + loop = get_running_loop() if self._handle_signals: try: @@ -182,7 +182,7 @@ async def shutdown(self): pass # pragma: no cover async def cleanup(self): - loop = asyncio.get_event_loop() + loop = get_running_loop() if self._server is None: # no started yet, do nothing @@ -269,7 +269,7 @@ async def shutdown(self): await self._app.shutdown() async def _make_server(self): - loop = asyncio.get_event_loop() + loop = get_running_loop() self._app._set_loop(loop) self._app.on_startup.freeze() await self._app.startup() diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 3620bd11547..360d33c4480 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -1,6 +1,7 @@ """Low level HTTP server.""" import asyncio +from .helpers import get_running_loop from .web_protocol import RequestHandler from .web_request import BaseRequest @@ -11,9 +12,7 @@ class Server: def __init__(self, handler, *, request_factory=None, loop=None, **kwargs): - if loop is None: - loop = asyncio.get_event_loop() - self._loop = loop + self._loop = get_running_loop(loop) self._connections = {} self._kwargs = kwargs self.requests_count = 0 diff --git a/aiohttp/worker.py b/aiohttp/worker.py index a4b32c48752..16d67f449a0 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -12,7 +12,7 @@ from aiohttp import web -from .helpers import AccessLogger, set_result +from .helpers import AccessLogger, get_running_loop, set_result try: @@ -41,7 +41,7 @@ def __init__(self, *args, **kw): # pragma: no cover def init_process(self): # create new event_loop after fork - asyncio.get_event_loop().close() + get_running_loop().close() self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -198,10 +198,10 @@ def init_process(self): # Close any existing event loop before setting a # new policy. - asyncio.get_event_loop().close() + get_running_loop().close() # Setup uvloop policy, so that every - # asyncio.get_event_loop() will create an instance + # get_running_loop() will create an instance # of uvloop event loop. asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -215,10 +215,10 @@ def init_process(self): # pragma: no cover # Close any existing event loop before setting a # new policy. - asyncio.get_event_loop().close() + get_running_loop().close() # Setup tokio policy, so that every - # asyncio.get_event_loop() will create an instance + # get_running_loop() will create an instance # of tokio event loop. asyncio.set_event_loop_policy(tokio.EventLoopPolicy())