Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't cancel web handler on disconnection #4080

Merged
merged 16 commits into from
Oct 1, 2019
18 changes: 11 additions & 7 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .log import ws_logger
from .streams import DataQueue

__all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
Expand Down Expand Up @@ -552,8 +551,8 @@ def __init__(self, protocol: BaseProtocol, transport: asyncio.Transport, *,
async def _send_frame(self, message: bytes, opcode: int,
compress: Optional[int]=None) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing:
ws_logger.warning('websocket connection is closing.')
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ConnectionResetError('Cannot write to closing transport')

rsv = 0

Expand Down Expand Up @@ -595,21 +594,26 @@ async def _send_frame(self, message: bytes, opcode: int,
mask = mask.to_bytes(4, 'big')
message = bytearray(message)
_websocket_mask(mask, message)
self.transport.write(header + mask + message)
self._write(header + mask + message)
self._output_size += len(header) + len(mask) + len(message)
else:
if len(message) > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
self._write(header)
self._write(message)
else:
self.transport.write(header + message)
self._write(header + message)

self._output_size += len(header) + len(message)

if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()

def _write(self, data: bytes) -> None:
if self.transport is None or self.transport.is_closing():
raise ConnectionResetError('Cannot write to closing transport')
self.transport.write(data)

async def pong(self, message: bytes=b'') -> None:
"""Send pong message."""
if isinstance(message, str):
Expand Down
19 changes: 15 additions & 4 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ class RequestHandler(BaseProtocol):
'_waiter', '_error_handler', '_task_handler',
'_upgrade', '_payload_parser', '_request_parser',
'_reading_paused', 'logger', 'access_log',
'access_logger', '_close', '_force_close')
'access_logger', '_close', '_force_close',
'_current_request')

def __init__(self, manager: 'Server', *,
loop: asyncio.AbstractEventLoop,
Expand All @@ -156,6 +157,7 @@ def __init__(self, manager: 'Server', *,

self._request_count = 0
self._keepalive = False
self._current_request = None # type: Optional[BaseRequest]
self._manager = manager # type: Optional[Server]
self._request_handler = manager.request_handler # type: Optional[_RequestHandler] # noqa
self._request_factory = manager.request_factory # type: Optional[_RequestFactory] # noqa
Expand Down Expand Up @@ -225,6 +227,9 @@ async def shutdown(self, timeout: Optional[float]=15.0) -> None:
not self._error_handler.done()):
await self._error_handler

if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

if (self._task_handler is not None and
not self._task_handler.done()):
await self._task_handler
Expand Down Expand Up @@ -264,8 +269,10 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()

if self._task_handler is not None:
self._task_handler.cancel()
if self._current_request is not None:
if exc is None:
exc = ConnectionResetError("Connetion lost")
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
self._current_request._cancel(exc)

if self._error_handler is not None:
self._error_handler.cancel()
Expand Down Expand Up @@ -402,7 +409,11 @@ async def _handle_request(self,
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
resp = await self._request_handler(request)
try:
self._current_request = request
resp = await self._request_handler(request)
finally:
self._current_request = None
except HTTPException as exc:
resp = Response(status=exc.status,
reason=exc.reason,
Expand Down
3 changes: 3 additions & 0 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,9 @@ def __eq__(self, other: object) -> bool:
async def _prepare_hook(self, response: StreamResponse) -> None:
return

def _cancel(self, exc: BaseException) -> None:
self._payload.set_exception(exc)


class Request(BaseRequest):

Expand Down
4 changes: 4 additions & 0 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,7 @@ async def __anext__(self) -> WSMessage:
WSMsgType.CLOSED):
raise StopAsyncIteration # NOQA
return msg

def _cancel(self, exc: BaseException) -> None:
if self._reader is not None:
self._reader.set_exception(exc)
10 changes: 3 additions & 7 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import aiohttp
from aiohttp import client, hdrs
from aiohttp.http import WS_KEY
from aiohttp.log import ws_logger
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro

Expand Down Expand Up @@ -364,7 +363,7 @@ async def test_close_exc2(loop, ws_key, key_data) -> None:
await resp.close()


async def test_send_data_after_close(ws_key, key_data, loop, mocker) -> None:
async def test_send_data_after_close(ws_key, key_data, loop) -> None:
webknjaz marked this conversation as resolved.
Show resolved Hide resolved
resp = mock.Mock()
resp.status = 101
resp.headers = {
Expand All @@ -382,16 +381,13 @@ async def test_send_data_after_close(ws_key, key_data, loop, mocker) -> None:
'http://test.org')
resp._writer._closing = True

mocker.spy(ws_logger, 'warning')

for meth, args in ((resp.ping, ()),
(resp.pong, ()),
(resp.send_str, ('s',)),
(resp.send_bytes, (b'b',)),
(resp.send_json, ({},))):
await meth(*args)
assert ws_logger.warning.called
ws_logger.warning.reset_mock()
with pytest.raises(ConnectionResetError):
await meth(*args)


async def test_send_data_type_errors(ws_key, key_data, loop) -> None:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,11 @@ async def test_two_data_received_without_waking_up_start_task(srv) -> None:
async def test_client_disconnect(aiohttp_server) -> None:

async def handler(request):
await request.content.read(10)
buf = b""
with pytest.raises(ConnectionError):
while len(buf) < 10:
buf += await request.content.read(10)
webknjaz marked this conversation as resolved.
Show resolved Hide resolved
# return with closed transport means premature client disconnection
return web.Response()

loop = asyncio.get_event_loop()
Expand Down
96 changes: 15 additions & 81 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from multidict import CIMultiDict

from aiohttp import WSMsgType, signals
from aiohttp.log import ws_logger
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro, make_mocked_request
from aiohttp.web import HTTPBadRequest, WebSocketResponse
Expand Down Expand Up @@ -198,52 +197,48 @@ def test_closed_after_ctor() -> None:
assert ws.close_code is None


async def test_send_str_closed(make_request, mocker) -> None:
async def test_send_str_closed(make_request) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

mocker.spy(ws_logger, 'warning')
await ws.send_str('string')
assert ws_logger.warning.called
with pytest.raises(ConnectionError):
await ws.send_str('string')


async def test_send_bytes_closed(make_request, mocker) -> None:
async def test_send_bytes_closed(make_request) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

mocker.spy(ws_logger, 'warning')
await ws.send_bytes(b'bytes')
assert ws_logger.warning.called
with pytest.raises(ConnectionError):
await ws.send_bytes(b'bytes')


async def test_send_json_closed(make_request, mocker) -> None:
async def test_send_json_closed(make_request) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

mocker.spy(ws_logger, 'warning')
await ws.send_json({'type': 'json'})
assert ws_logger.warning.called
with pytest.raises(ConnectionError):
await ws.send_json({'type': 'json'})


async def test_ping_closed(make_request, mocker) -> None:
async def test_ping_closed(make_request) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

mocker.spy(ws_logger, 'warning')
await ws.ping()
assert ws_logger.warning.called
with pytest.raises(ConnectionError):
await ws.ping()


async def test_pong_closed(make_request, mocker) -> None:
Expand All @@ -253,9 +248,8 @@ async def test_pong_closed(make_request, mocker) -> None:
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
await ws.close()

mocker.spy(ws_logger, 'warning')
await ws.pong()
assert ws_logger.warning.called
with pytest.raises(ConnectionError):
await ws.pong()


async def test_close_idempotent(make_request) -> None:
Expand Down Expand Up @@ -326,40 +320,6 @@ async def test_receive_eofstream_in_reader(make_request, loop) -> None:
assert ws.closed


async def test_receive_exc_in_reader(make_request, loop) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)

ws._reader = mock.Mock()
exc = ValueError()
res = loop.create_future()
res.set_exception(exc)
ws._reader.read = make_mocked_coro(res)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.ERROR
assert msg.data is exc
assert ws.exception() is exc


async def test_receive_cancelled(make_request, loop) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
await ws.prepare(req)

ws._reader = mock.Mock()
res = loop.create_future()
res.set_exception(asyncio.CancelledError())
ws._reader.read = make_mocked_coro(res)

with pytest.raises(asyncio.CancelledError):
await ws.receive()


async def test_receive_timeouterror(make_request, loop) -> None:
req = make_request('GET', '/')
ws = WebSocketResponse()
Expand Down Expand Up @@ -400,33 +360,7 @@ async def test_concurrent_receive(make_request) -> None:
await ws.receive()


async def test_close_exc(make_request, loop, mocker) -> None:
req = make_request('GET', '/')

ws = WebSocketResponse()
await ws.prepare(req)

ws._reader = mock.Mock()
exc = ValueError()
ws._reader.read.return_value = loop.create_future()
ws._reader.read.return_value.set_exception(exc)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

await ws.close()
assert ws.closed
assert ws.exception() is exc

ws._closed = False
ws._reader.read.return_value = loop.create_future()
ws._reader.read.return_value.set_exception(asyncio.CancelledError())
with pytest.raises(asyncio.CancelledError):
await ws.close()
assert ws.close_code == 1006


async def test_close_exc2(make_request) -> None:
async def test_close_exc(make_request) -> None:

req = make_request('GET', '/')
ws = WebSocketResponse()
Expand Down
26 changes: 2 additions & 24 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,6 @@ async def handler(request):
await asyncio.sleep(0.08)
msg = await ws._reader.read()
assert msg.type == WSMsgType.CLOSE
await ws.send_str('hang')

# i am not sure what do we test here
# under uvloop this code raises RuntimeError
try:
await asyncio.sleep(0.08)
await ws.send_str('hang')
await asyncio.sleep(0.08)
await ws.send_str('hang')
await asyncio.sleep(0.08)
await ws.send_str('hang')
except RuntimeError:
pass

await asyncio.sleep(0.08)
assert (await aborted)
Expand Down Expand Up @@ -665,19 +652,12 @@ async def handler(request):


async def test_heartbeat_no_pong(loop, aiohttp_client, ceil) -> None:
cancelled = False

async def handler(request):
nonlocal cancelled

ws = web.WebSocketResponse(heartbeat=0.05)
await ws.prepare(request)

try:
await ws.receive()
except asyncio.CancelledError:
cancelled = True

await ws.receive()
return ws

app = web.Application()
Expand All @@ -687,9 +667,7 @@ async def handler(request):
ws = await client.ws_connect('/', autoping=False)
msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.PING
await ws.receive()

assert cancelled
await ws.close()


async def test_server_ws_async_for(loop, aiohttp_server) -> None:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def protocol():

@pytest.fixture
def transport():
return mock.Mock()
ret = mock.Mock()
ret.is_closing.return_value = False
return ret


@pytest.fixture
Expand Down