Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websockets refactoring #2836

Merged
merged 10 commits into from
Mar 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/2836.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Websockets refactoring, all websocket writer methods are converted
into coroutines.
9 changes: 6 additions & 3 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, reader, writer, protocol,
self._heartbeat = heartbeat
self._heartbeat_cb = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat/2.0
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb = None
self._loop = loop
self._waiting = None
Expand Down Expand Up @@ -61,7 +61,10 @@ def _reset_heartbeat(self):

def _send_heartbeat(self):
if self._heartbeat is not None and not self._closed:
self._writer.ping()
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping())

if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
Expand Down Expand Up @@ -137,7 +140,7 @@ async def close(self, *, code=1000, message=b''):
self._cancel_heartbeat()
self._closed = True
try:
self._writer.close(code, message)
await self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close()
Expand Down
26 changes: 12 additions & 14 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from enum import IntEnum
from struct import Struct

from .helpers import NO_EXTENSIONS, noop
from .helpers import NO_EXTENSIONS
from .log import ws_logger


Expand Down Expand Up @@ -527,7 +527,7 @@ def __init__(self, protocol, transport, *,
self._output_size = 0
self._compressobj = None

def _send_frame(self, message, opcode, compress=None):
async def _send_frame(self, message, opcode, compress=None):
"""Send a frame over the websocket with message as its payload."""
if self._closing:
ws_logger.warning('websocket connection is closing.')
Expand Down Expand Up @@ -585,37 +585,35 @@ def _send_frame(self, message, opcode, compress=None):

if self._output_size > self._limit:
self._output_size = 0
return self.protocol._drain_helper()
await self.protocol._drain_helper()

return noop()

def pong(self, message=b''):
async def pong(self, message=b''):
"""Send pong message."""
if isinstance(message, str):
message = message.encode('utf-8')
return self._send_frame(message, WSMsgType.PONG)
return await self._send_frame(message, WSMsgType.PONG)

def ping(self, message=b''):
async def ping(self, message=b''):
"""Send ping message."""
if isinstance(message, str):
message = message.encode('utf-8')
return self._send_frame(message, WSMsgType.PING)
return await self._send_frame(message, WSMsgType.PING)

def send(self, message, binary=False, compress=None):
async def send(self, message, binary=False, compress=None):
"""Send a frame over the websocket with message as its payload."""
if isinstance(message, str):
message = message.encode('utf-8')
if binary:
return self._send_frame(message, WSMsgType.BINARY, compress)
return await self._send_frame(message, WSMsgType.BINARY, compress)
else:
return self._send_frame(message, WSMsgType.TEXT, compress)
return await self._send_frame(message, WSMsgType.TEXT, compress)

def close(self, code=1000, message=b''):
async def close(self, code=1000, message=b''):
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode('utf-8')
try:
return self._send_frame(
return await self._send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
finally:
self._closing = True
9 changes: 6 additions & 3 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, *,
self._heartbeat = heartbeat
self._heartbeat_cb = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat/2.0
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb = None
self._compress = compress

Expand All @@ -80,7 +80,10 @@ def _reset_heartbeat(self):

def _send_heartbeat(self):
if self._heartbeat is not None and not self._closed:
self._writer.ping()
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping())

if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
Expand Down Expand Up @@ -286,7 +289,7 @@ async def close(self, *, code=1000, message=b''):
if not self._closed:
self._closed = True
try:
self._writer.close(code, message)
await self._writer.close(code, message)
await self._payload_writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = 1006
Expand Down
13 changes: 10 additions & 3 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ async def test_close(loop, ws_key, key_data):
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)
writer = WebSocketWriter.return_value = mock.Mock()
writer = mock.Mock()
WebSocketWriter.return_value = writer
writer.close = make_mocked_coro()

session = aiohttp.ClientSession(loop=loop)
resp = await session.ws_connect(
Expand Down Expand Up @@ -280,7 +282,9 @@ async def test_close_exc(loop, ws_key, key_data):
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)
WebSocketWriter.return_value = mock.Mock()
writer = mock.Mock()
WebSocketWriter.return_value = writer
writer.close = make_mocked_coro()

session = aiohttp.ClientSession(loop=loop)
resp = await session.ws_connect('http://test.org')
Expand Down Expand Up @@ -400,7 +404,10 @@ async def test_reader_read_exception(ws_key, key_data, loop):
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(hresp)
WebSocketWriter.return_value = mock.Mock()

writer = mock.Mock()
WebSocketWriter.return_value = writer
writer.close = make_mocked_coro()

session = aiohttp.ClientSession(loop=loop)
resp = await session.ws_connect('http://test.org')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ async def handler(request):

client = await aiohttp_client(app)
resp = await client.ws_connect('/', heartbeat=0.01)

await asyncio.sleep(0.1)
await resp.receive()
await resp.close()

Expand Down
61 changes: 32 additions & 29 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
import pytest

from aiohttp.http import WebSocketWriter
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
def protocol():
return mock.Mock()
ret = mock.Mock()
ret._drain_helper = make_mocked_coro()
return ret


@pytest.fixture
Expand All @@ -21,83 +24,83 @@ def writer(protocol, transport):
return WebSocketWriter(protocol, transport, use_mask=False)


def test_pong(writer):
writer.pong()
async def test_pong(writer):
await writer.pong()
writer.transport.write.assert_called_with(b'\x8a\x00')


def test_ping(writer):
writer.ping()
async def test_ping(writer):
await writer.ping()
writer.transport.write.assert_called_with(b'\x89\x00')


def test_send_text(writer):
writer.send(b'text')
async def test_send_text(writer):
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\x81\x04text')


def test_send_binary(writer):
writer.send('binary', True)
async def test_send_binary(writer):
await writer.send('binary', True)
writer.transport.write.assert_called_with(b'\x82\x06binary')


def test_send_binary_long(writer):
writer.send(b'b' * 127, True)
async def test_send_binary_long(writer):
await writer.send(b'b' * 127, True)
assert writer.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')


def test_send_binary_very_long(writer):
writer.send(b'b' * 65537, True)
async def test_send_binary_very_long(writer):
await writer.send(b'b' * 65537, True)
assert (writer.transport.write.call_args_list[0][0][0] ==
b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01')
assert writer.transport.write.call_args_list[1][0][0] == b'b' * 65537


def test_close(writer):
writer.close(1001, 'msg')
async def test_close(writer):
await writer.close(1001, 'msg')
writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')

writer.close(1001, b'msg')
await writer.close(1001, b'msg')
writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')

# Test that Service Restart close code is also supported
writer.close(1012, b'msg')
await writer.close(1012, b'msg')
writer.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg')


def test_send_text_masked(protocol, transport):
async def test_send_text_masked(protocol, transport):
writer = WebSocketWriter(protocol,
transport,
use_mask=True,
random=random.Random(123))
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12')


def test_send_compress_text(protocol, transport):
async def test_send_compress_text(protocol, transport):
writer = WebSocketWriter(protocol, transport, compress=15)
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00')


def test_send_compress_text_notakeover(protocol, transport):
async def test_send_compress_text_notakeover(protocol, transport):
writer = WebSocketWriter(protocol,
transport,
compress=15,
notakeover=True)
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')


def test_send_compress_text_per_message(protocol, transport):
async def test_send_compress_text_per_message(protocol, transport):
writer = WebSocketWriter(protocol, transport)
writer.send(b'text', compress=15)
await writer.send(b'text', compress=15)
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
writer.send(b'text')
await writer.send(b'text')
writer.transport.write.assert_called_with(b'\x81\x04text')
writer.send(b'text', compress=15)
await writer.send(b'text', compress=15)
writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')