Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encourage creation of aiohttp public objects inside a coroutine #3333

Closed
wants to merge 17 commits into from
Closed
1 change: 1 addition & 0 deletions CHANGES/3333.removal
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replace asyncio.get_event_loop() with get_running_loop() for encourage creation of aiohttp public objects inside a coroutine
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ Yannick Péroux
Ye Cao
Yegor Roganov
Young-Ho Cha
Yunhao Zhang
Yuriy Shatrov
Yury Selivanov
Yusuke Tsutsumi
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast

from .helpers import get_running_loop
from .log import internal_logger


Expand All @@ -10,7 +11,7 @@ class BaseProtocol(asyncio.Protocol):

def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
if loop is None:
self._loop = asyncio.get_event_loop()
self._loop = get_running_loop()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is replacing the entire if loop is None block with self._loop = get_running_loop(loop).
The change doesn't only remove several lines but always check a loop (explicit or implicit) for running.

Please apply it everywhere.

Copy link
Contributor Author

@zhangyunhao116 zhangyunhao116 Oct 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got you.I will do this.
BTW,is there any plan to use asyncio.get_running_loop() which is a new api in python3.7 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually yes, but we should support Python 3.5 and 3.6 for a long time

else:
self._loop = loop
self._paused = False
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from .connector import BaseConnector, TCPConnector
from .cookiejar import CookieJar
from .helpers import (DEBUG, PY_36, CeilTimeout, TimeoutHandle,
proxies_from_env, sentinel, strip_auth_from_url)
get_running_loop, proxies_from_env, sentinel,
strip_auth_from_url)
from .http import WS_KEY, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
Expand Down Expand Up @@ -107,7 +108,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
loop = connector._loop
else:
implicit_loop = True
loop = asyncio.get_event_loop()
loop = get_running_loop()

if connector is None:
connector = TCPConnector(loop=loop)
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
ClientResponseError, ContentTypeError,
InvalidURL, ServerFingerprintMismatch)
from .formdata import FormData
from .helpers import PY_36, HeadersMixin, TimerNoop, noop, reify, set_result
from .helpers import (PY_36, HeadersMixin, TimerNoop, get_running_loop, noop,
reify, set_result)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .log import client_logger
from .streams import StreamReader # noqa
Expand Down Expand Up @@ -197,7 +198,7 @@ def __init__(self, method, url, *,
traces=None):

if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop()

assert isinstance(url, URL), url
assert isinstance(proxy, (URL, type(None))), proxy
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
ssl_errors)
from .client_proto import ResponseHandler
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
from .helpers import PY_36, CeilTimeout, is_ip_address, noop, sentinel
from .helpers import (PY_36, CeilTimeout, get_running_loop, is_ip_address,
noop, sentinel)
from .locks import EventResultOrError
from .resolver import DefaultResolver

Expand Down Expand Up @@ -171,7 +172,7 @@ def __init__(self, *, keepalive_timeout=sentinel,
keepalive_timeout = 15.0

if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop()

self._closed = False
if loop.get_debug():
Expand Down
10 changes: 10 additions & 0 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import re
import sys
import time
import warnings
import weakref
from collections import namedtuple
from contextlib import suppress
Expand Down Expand Up @@ -230,6 +231,15 @@ def current_task(loop: Optional[asyncio.AbstractEventLoop]=None) -> asyncio.Task
return asyncio.Task.current_task(loop=loop) # type: ignore


def get_running_loop(loop: Optional[asyncio.AbstractEventLoop]=None) -> asyncio.AbstractEventLoop: # type: ignore # noqa
if loop is None:
loop = asyncio.get_event_loop()
if not loop.is_running():
warnings.warn("The object should be created from async function",
DeprecationWarning, stacklevel=3)
return loop


def isasyncgenfunction(obj: Any) -> bool:
func = getattr(inspect, 'isasyncgenfunction', None)
if func is not None:
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional

from .abc import AbstractResolver
from .helpers import get_running_loop


__all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
Expand All @@ -23,7 +24,7 @@ class ThreadedResolver(AbstractResolver):

def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop()
self._loop = loop

async def resolve(self, host: str, port: int=0,
Expand Down Expand Up @@ -51,7 +52,7 @@ class AsyncResolver(AbstractResolver):
def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None,
*args: Any, **kwargs: Any) -> None:
if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop()

if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")
Expand Down
5 changes: 3 additions & 2 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Awaitable, Callable, Optional, Tuple

from .base_protocol import BaseProtocol
from .helpers import BaseTimerContext, set_exception, set_result
from .helpers import (BaseTimerContext, get_running_loop, set_exception,
set_result)
from .log import internal_logger


Expand Down Expand Up @@ -112,7 +113,7 @@ def __init__(self, protocol: BaseProtocol,
self._low_water = limit
self._high_water = limit * 2
if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop()
self._loop = loop
self._size = 0
self._cursor = 0
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/web_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from . import hdrs
from .abc import AbstractAccessLogger, AbstractMatchInfo, AbstractRouter
from .frozenlist import FrozenList
from .helpers import DEBUG, AccessLogger
from .helpers import DEBUG, AccessLogger, get_running_loop
from .log import web_logger
from .signals import Signal
from .web_middlewares import _fix_request_current_app
Expand Down Expand Up @@ -151,7 +151,7 @@ def loop(self):

def _set_loop(self, loop):
if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop()
if self._loop is not None and self._loop is not loop:
raise RuntimeError(
"web.Application instance initialized with different loop")
Expand Down
14 changes: 7 additions & 7 deletions aiohttp/web_runner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import signal
import socket
from abc import ABC, abstractmethod

from yarl import URL

from .helpers import get_running_loop
from .web_app import Application


Expand Down Expand Up @@ -81,7 +81,7 @@ def name(self):

async def start(self):
await super().start()
loop = asyncio.get_event_loop()
loop = get_running_loop()
self._server = await loop.create_server(
self._runner.server, self._host, self._port,
ssl=self._ssl_context, backlog=self._backlog,
Expand All @@ -106,7 +106,7 @@ def name(self):

async def start(self):
await super().start()
loop = asyncio.get_event_loop()
loop = get_running_loop()
self._server = await loop.create_unix_server(
self._runner.server, self._path,
ssl=self._ssl_context, backlog=self._backlog)
Expand Down Expand Up @@ -135,7 +135,7 @@ def name(self):

async def start(self):
await super().start()
loop = asyncio.get_event_loop()
loop = get_running_loop()
self._server = await loop.create_server(
self._runner.server, sock=self._sock,
ssl=self._ssl_context, backlog=self._backlog)
Expand Down Expand Up @@ -165,7 +165,7 @@ def sites(self):
return set(self._sites)

async def setup(self):
loop = asyncio.get_event_loop()
loop = get_running_loop()

if self._handle_signals:
try:
Expand All @@ -182,7 +182,7 @@ async def shutdown(self):
pass # pragma: no cover

async def cleanup(self):
loop = asyncio.get_event_loop()
loop = get_running_loop()

if self._server is None:
# no started yet, do nothing
Expand Down Expand Up @@ -269,7 +269,7 @@ async def shutdown(self):
await self._app.shutdown()

async def _make_server(self):
loop = asyncio.get_event_loop()
loop = get_running_loop()
self._app._set_loop(loop)
self._app.on_startup.freeze()
await self._app.startup()
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Low level HTTP server."""
import asyncio

from .helpers import get_running_loop
from .web_protocol import RequestHandler
from .web_request import BaseRequest

Expand All @@ -12,7 +13,7 @@ class Server:

def __init__(self, handler, *, request_factory=None, loop=None, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop()
self._loop = loop
self._connections = {}
self._kwargs = kwargs
Expand Down
12 changes: 6 additions & 6 deletions aiohttp/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from aiohttp import web

from .helpers import AccessLogger, set_result
from .helpers import AccessLogger, get_running_loop, set_result


try:
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(self, *args, **kw): # pragma: no cover

def init_process(self):
# create new event_loop after fork
asyncio.get_event_loop().close()
get_running_loop().close()

self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
Expand Down Expand Up @@ -198,10 +198,10 @@ def init_process(self):

# Close any existing event loop before setting a
# new policy.
asyncio.get_event_loop().close()
get_running_loop().close()

# Setup uvloop policy, so that every
# asyncio.get_event_loop() will create an instance
# get_running_loop() will create an instance
# of uvloop event loop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

Expand All @@ -215,10 +215,10 @@ def init_process(self): # pragma: no cover

# Close any existing event loop before setting a
# new policy.
asyncio.get_event_loop().close()
get_running_loop().close()

# Setup tokio policy, so that every
# asyncio.get_event_loop() will create an instance
# get_running_loop() will create an instance
# of tokio event loop.
asyncio.set_event_loop_policy(tokio.EventLoopPolicy())

Expand Down
8 changes: 6 additions & 2 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,12 @@ def test_host_header_ipv6_with_port(make_request) -> None:

def test_default_loop(loop) -> None:
asyncio.set_event_loop(loop)
req = ClientRequest('get', URL('http://python.org/'))
assert req.loop is loop
with pytest.warns(DeprecationWarning) as warning_checker:
req = ClientRequest('get', URL('http://python.org/'))
assert req.loop is loop
assert len(warning_checker) == 1
msg = str(warning_checker.list[0].message)
assert msg == "The object should be created from async function"


def test_default_headers_useragent(make_request) -> None:
Expand Down
18 changes: 13 additions & 5 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,14 @@ def test_context_manager(loop) -> None:

def test_ctor_loop() -> None:
with mock.patch('aiohttp.connector.asyncio') as m_asyncio:
session = aiohttp.BaseConnector()

with pytest.warns(DeprecationWarning) as warning_checker:
session = aiohttp.BaseConnector()
assert session._loop is m_asyncio.get_event_loop.return_value

assert len(warning_checker) == 1
msg = str(warning_checker.list[0].message)
assert msg == "The object should be created from async function"


def test_close(loop) -> None:
proto = mock.Mock()
Expand Down Expand Up @@ -1363,9 +1367,13 @@ def test_close_cancels_cleanup_closed_handle(loop) -> None:
def test_ctor_with_default_loop() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
conn = aiohttp.BaseConnector()
assert loop is conn._loop
loop.close()
with pytest.warns(DeprecationWarning) as warning_checker:
conn = aiohttp.BaseConnector()
assert loop is conn._loop
loop.close()
assert len(warning_checker) == 1
msg = str(warning_checker.list[0].message)
assert msg == "The object should be created from async function"


async def test_connect_with_limit(loop, key) -> None:
Expand Down
8 changes: 6 additions & 2 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,12 @@ async def test_close_for_async_resolver(loop) -> None:

def test_default_loop_for_threaded_resolver(loop) -> None:
asyncio.set_event_loop(loop)
resolver = ThreadedResolver()
assert resolver._loop is loop
with pytest.warns(DeprecationWarning) as warning_checker:
resolver = ThreadedResolver()
assert resolver._loop is loop
assert len(warning_checker) == 1
msg = str(warning_checker.list[0].message)
assert msg == "The object should be created from async function"


@pytest.mark.skipif(aiodns is None, reason="aiodns required")
Expand Down
9 changes: 6 additions & 3 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@ async def test_create_waiter(self) -> None:
def test_ctor_global_loop(self) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
stream = streams.StreamReader(mock.Mock(_reading_paused=False))

assert stream._loop is loop
with pytest.warns(DeprecationWarning) as warning_checker:
stream = streams.StreamReader(mock.Mock(_reading_paused=False))
assert stream._loop is loop
assert len(warning_checker) == 1
msg = str(warning_checker.list[0].message)
assert msg == "The object should be created from async function"

async def test_at_eof(self) -> None:
stream = self._make_one()
Expand Down
Loading