diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index da7848c5a98..edafc7fe7f4 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -35,6 +35,7 @@ Arthur Darcet Ben Bader Benedikt Reinartz Boris Feld +Boyi Chen Brett Cannon Brian C. Lane Brian Muller diff --git a/aiohttp/client.py b/aiohttp/client.py index 57cccc9919e..fc5ad9fb983 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -27,6 +27,7 @@ from .helpers import (PY_35, CeilTimeout, ProxyInfo, TimeoutHandle, _BaseCoroMixin, deprecated_noop, sentinel) from .http import WS_KEY, WebSocketReader, WebSocketWriter +from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue @@ -370,7 +371,8 @@ def ws_connect(self, url, *, origin=None, headers=None, proxy=None, - proxy_auth=None): + proxy_auth=None, + compress=0): """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect(url, @@ -384,7 +386,8 @@ def ws_connect(self, url, *, origin=origin, headers=headers, proxy=proxy, - proxy_auth=proxy_auth)) + proxy_auth=proxy_auth, + compress=compress)) @asyncio.coroutine def _ws_connect(self, url, *, @@ -398,7 +401,8 @@ def _ws_connect(self, url, *, origin=None, headers=None, proxy=None, - proxy_auth=None): + proxy_auth=None, + compress=0): if headers is None: headers = CIMultiDict() @@ -420,6 +424,9 @@ def _ws_connect(self, url, *, headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) if origin is not None: headers[hdrs.ORIGIN] = origin + if compress: + extstr = ws_ext_gen(compress=compress) + headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr # send request resp = yield from self.get(url, headers=headers, @@ -478,12 +485,32 @@ def _ws_connect(self, url, *, protocol = proto break + # websocket compress + notakeover = False + if compress: + compress_hdrs = resp.headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) + if compress_hdrs: + try: + compress, notakeover = ws_ext_parse(compress_hdrs) + except WSHandshakeError as exc: + raise WSServerHandshakeError( + resp.request_info, + resp.history, + message=exc.args[0], + code=resp.status, + headers=resp.headers) + else: + compress = 0 + notakeover = False + proto = resp.connection.protocol reader = FlowControlDataQueue( proto, limit=2 ** 16, loop=self._loop) proto.set_parser(WebSocketReader(reader), reader) resp.connection.writer.set_tcp_nodelay(True) - writer = WebSocketWriter(resp.connection.writer, use_mask=True) + writer = WebSocketWriter( + resp.connection.writer, use_mask=True, + compress=compress, notakeover=notakeover) except Exception: resp.close() raise @@ -497,7 +524,9 @@ def _ws_connect(self, url, *, autoping, self._loop, receive_timeout=receive_timeout, - heartbeat=heartbeat) + heartbeat=heartbeat, + compress=compress, + client_notakeover=notakeover) def _prepare_headers(self, headers): """ Add default headers and transform it to CIMultiDict diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 05ef2272580..7d026fa1f7e 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -13,7 +13,8 @@ class ClientWebSocketResponse: def __init__(self, reader, writer, protocol, response, timeout, autoclose, autoping, loop, *, - receive_timeout=None, heartbeat=None): + receive_timeout=None, heartbeat=None, + compress=0, client_notakeover=False): self._response = response self._conn = response.connection @@ -35,6 +36,8 @@ def __init__(self, reader, writer, protocol, self._loop = loop self._waiting = None self._exception = None + self._compress = compress + self._client_notakeover = client_notakeover self._reset_heartbeat() @@ -82,6 +85,14 @@ def close_code(self): def protocol(self): return self._protocol + @property + def compress(self): + return self._compress + + @property + def client_notakeover(self): + return self._client_notakeover + def get_extra_info(self, name, default=None): """extra info from connection transport""" try: diff --git a/aiohttp/hdrs.py b/aiohttp/hdrs.py index e14fcfa43df..b49d79b1043 100644 --- a/aiohttp/hdrs.py +++ b/aiohttp/hdrs.py @@ -75,6 +75,7 @@ SEC_WEBSOCKET_ACCEPT = istr('SEC-WEBSOCKET-ACCEPT') SEC_WEBSOCKET_VERSION = istr('SEC-WEBSOCKET-VERSION') SEC_WEBSOCKET_PROTOCOL = istr('SEC-WEBSOCKET-PROTOCOL') +SEC_WEBSOCKET_EXTENSIONS = istr('SEC-WEBSOCKET-EXTENSIONS') SEC_WEBSOCKET_KEY = istr('SEC-WEBSOCKET-KEY') SEC_WEBSOCKET_KEY1 = istr('SEC-WEBSOCKET-KEY1') SERVER = istr('SERVER') diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index f853c0d1a86..37f20787039 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -6,7 +6,9 @@ import hashlib import json import random +import re import sys +import zlib from enum import IntEnum from struct import Struct @@ -106,6 +108,10 @@ def __init__(self, code, message): super().__init__(message) +class WSHandshakeError(Exception): + """WebSocket protocol handshake error.""" + + native_byteorder = sys.byteorder @@ -146,6 +152,89 @@ def _websocket_mask_python(mask, data): except ImportError: # pragma: no cover _websocket_mask = _websocket_mask_python +_WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xff, 0xff]) + + +_WS_EXT_RE = re.compile(r'^(?:;\s*(?:' + r'(server_no_context_takeover)|' + r'(client_no_context_takeover)|' + r'(server_max_window_bits(?:=(\d+))?)|' + r'(client_max_window_bits(?:=(\d+))?)))*$') + +_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?') + + +def ws_ext_parse(extstr, isserver=False): + if not extstr: + return 0, False + + compress = 0 + notakeover = False + for ext in _WS_EXT_RE_SPLIT.finditer(extstr): + defext = ext.group(1) + # Return compress = 15 when get `permessage-deflate` + if not defext: + compress = 15 + break + match = _WS_EXT_RE.match(defext) + if match: + compress = 15 + if isserver: + # Server never fail to detect compress handshake. + # Server does not need to send max wbit to client + if match.group(4): + compress = int(match.group(4)) + # Group3 must match if group4 matches + # Compress wbit 8 does not support in zlib + # If compress level not support, + # CONTINUE to next extension + if compress > 15 or compress < 9: + compress = 0 + continue + if match.group(1): + notakeover = True + # Ignore regex group 5 & 6 for client_max_window_bits + break + else: + if match.group(6): + compress = int(match.group(6)) + # Group5 must match if group6 matches + # Compress wbit 8 does not support in zlib + # If compress level not support, + # FAIL the parse progress + if compress > 15 or compress < 9: + raise WSHandshakeError('Invalid window size') + if match.group(2): + notakeover = True + # Ignore regex group 5 & 6 for client_max_window_bits + break + # Return Fail if client side and not match + elif not isserver: + raise WSHandshakeError('Extension for deflate not supported' + + ext.group(1)) + + return compress, notakeover + + +def ws_ext_gen(compress=15, isserver=False, + server_notakeover=False): + # client_notakeover=False not used for server + # compress wbit 8 does not support in zlib + if compress < 9 or compress > 15: + raise ValueError('Compress wbits must between 9 and 15, ' + 'zlib does not support wbits=8') + enabledext = ['permessage-deflate'] + if not isserver: + enabledext.append('client_max_window_bits') + + if compress < 15: + enabledext.append('server_max_window_bits=' + str(compress)) + if server_notakeover: + enabledext.append('server_no_context_takeover') + # if client_notakeover: + # enabledext.append('client_no_context_takeover') + return '; '.join(enabledext) + class WSParserState(IntEnum): READ_HEADER = 1 @@ -156,7 +245,7 @@ class WSParserState(IntEnum): class WebSocketReader: - def __init__(self, queue): + def __init__(self, queue, compress=True): self.queue = queue self._exc = None @@ -173,6 +262,9 @@ def __init__(self, queue): self._frame_mask = None self._payload_length = 0 self._payload_length_flag = 0 + self._compressed = None + self._decompressobj = None + self._compress = compress def feed_eof(self): self.queue.feed_eof() @@ -189,7 +281,9 @@ def feed_data(self, data): return True, b'' def _feed_data(self, data): - for fin, opcode, payload in self.parse_frame(data): + for fin, opcode, payload, compressed in self.parse_frame(data): + if compressed and not self._decompressobj: + self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS) if opcode == WSMsgType.CLOSE: if len(payload) >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] @@ -234,7 +328,6 @@ def _feed_data(self, data): # got partial frame payload if opcode != WSMsgType.CONTINUATION: self._opcode = opcode - self._partial.append(payload) else: # previous frame was non finished @@ -250,7 +343,16 @@ def _feed_data(self, data): opcode = self._opcode self._opcode = None - payload_merged = b''.join(self._partial) + payload + self._partial.append(payload) + + payload_merged = b''.join(self._partial) + + # Decompress process must to be done after all packets + # received. + if compressed: + payload_merged = self._decompressobj.decompress( + payload_merged + _WS_DEFLATE_TRAILING) + self._partial.clear() if opcode == WSMsgType.TEXT: @@ -300,7 +402,9 @@ def parse_frame(self, buf): # 1 bit, MUST be 0 unless negotiated otherwise # frame-rsv3 = %x0 ; # 1 bit, MUST be 0 unless negotiated otherwise - if rsv1 or rsv2 or rsv3: + # + # Remove rsv1 from this test for deflate development + if rsv2 or rsv3 or (rsv1 and not self._compress): raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, 'Received frame with non-zero reserved bits') @@ -321,6 +425,16 @@ def parse_frame(self, buf): 'Control frame payload cannot be ' 'larger than 125 bytes') + # Set compress status if last package is FIN + # OR set compress status if this is first fragment + # Raise error if not first fragment with rsv1 = 0x1 + if self._frame_fin or self._compressed is None: + self._compressed = True if rsv1 else False + elif rsv1: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + 'Received frame with non-zero reserved bits') + self._frame_fin = fin self._frame_opcode = opcode self._has_mask = has_mask @@ -390,8 +504,11 @@ def parse_frame(self, buf): if self._has_mask: _websocket_mask(self._frame_mask, payload) - frames.append( - (self._frame_fin, self._frame_opcode, payload)) + frames.append(( + self._frame_fin, + self._frame_opcode, + payload, + self._compressed)) self._frame_payload = bytearray() self._state = WSParserState.READ_HEADER @@ -406,20 +523,40 @@ def parse_frame(self, buf): class WebSocketWriter: def __init__(self, stream, *, - use_mask=False, limit=DEFAULT_LIMIT, random=random.Random()): + use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(), + compress=0, notakeover=False): self.stream = stream self.writer = stream.transport self.use_mask = use_mask self.randrange = random.randrange + self.compress = compress + self.notakeover = notakeover self._closing = False self._limit = limit self._output_size = 0 + self._compressobj = None def _send_frame(self, message, opcode): """Send a frame over the websocket with message as its payload.""" if self._closing: ws_logger.warning('websocket connection is closing.') + rsv = 0 + + # Only compress larger packets (disabled) + # Does small packet needs to be compressed? + # if self.compress and opcode < 8 and len(message) > 124: + if self.compress and opcode < 8: + if not self._compressobj: + self._compressobj = zlib.compressobj(wbits=-self.compress) + + message = self._compressobj.compress(message) + message = message + self._compressobj.flush( + zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH) + if message.endswith(_WS_DEFLATE_TRAILING): + message = message[:-4] + rsv = rsv | 0x40 + msg_length = len(message) use_mask = self.use_mask @@ -429,11 +566,11 @@ def _send_frame(self, message, opcode): mask_bit = 0 if msg_length < 126: - header = PACK_LEN1(0x80 | opcode, msg_length | mask_bit) + header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit) elif msg_length < (1 << 16): - header = PACK_LEN2(0x80 | opcode, 126 | mask_bit, msg_length) + header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length) else: - header = PACK_LEN3(0x80 | opcode, 127 | mask_bit, msg_length) + header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length) if use_mask: mask = self.randrange(0, 0xffffffff) mask = mask.to_bytes(4, 'big') @@ -488,8 +625,8 @@ def close(self, code=1000, message=b''): self._closing = True -def do_handshake(method, headers, stream, - protocols=(), write_buffer_size=DEFAULT_LIMIT): +def do_handshake(method, headers, stream, protocols=(), + write_buffer_size=DEFAULT_LIMIT, compress=True): """Prepare WebSocket handshake. It return HTTP response code, response headers, websocket parser, @@ -500,6 +637,8 @@ def do_handshake(method, headers, stream, which the server also knows. `write_buffer_size` max size of write buffer before `drain()` get called. + + `compress` enable or disable server side deflate extension support. """ # WebSocket accepts only GET if method.upper() != hdrs.METH_GET: @@ -556,6 +695,18 @@ def do_handshake(method, headers, stream, (hdrs.SEC_WEBSOCKET_ACCEPT, base64.b64encode( hashlib.sha1(key.encode() + WS_KEY).digest()).decode())] + notakeover = False + if compress: + extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) + # Server side always get return with no exception. + # If something happened, just drop compress extension + compress, notakeover = ws_ext_parse(extensions, isserver=True) + if compress: + enabledext = ws_ext_gen(compress=compress, isserver=True, + server_notakeover=notakeover) + response_headers.append((hdrs.SEC_WEBSOCKET_EXTENSIONS, + enabledext)) + if protocol: response_headers.append((hdrs.SEC_WEBSOCKET_PROTOCOL, protocol)) @@ -563,5 +714,8 @@ def do_handshake(method, headers, stream, return (101, response_headers, None, - WebSocketWriter(stream, limit=write_buffer_size), - protocol) + WebSocketWriter( + stream, limit=write_buffer_size, + compress=compress, notakeover=notakeover), + protocol, + compress) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index b5f30a2083d..1f3b19472f1 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -32,7 +32,7 @@ class WebSocketResponse(StreamResponse): def __init__(self, *, timeout=10.0, receive_timeout=None, autoclose=True, autoping=True, heartbeat=None, - protocols=()): + protocols=(), compress=True): super().__init__(status=101) self._protocols = protocols self._ws_protocol = None @@ -54,6 +54,7 @@ def __init__(self, *, if heartbeat is not None: self._pong_heartbeat = heartbeat/2.0 self._pong_response_cb = None + self._compress = compress def _cancel_heartbeat(self): if self._pong_response_cb is not None: @@ -103,9 +104,9 @@ def _pre_start(self, request): self._loop = request.loop try: - status, headers, _, writer, protocol = do_handshake( + status, headers, _, writer, protocol, compress = do_handshake( request.method, request.headers, request._protocol.writer, - self._protocols) + self._protocols, compress=self._compress) except HttpProcessingError as err: if err.code == 405: raise HTTPMethodNotAllowed( @@ -122,6 +123,7 @@ def _pre_start(self, request): for k, v in headers: self.headers[k] = v self.force_close() + self._compress = compress return protocol, writer def _post_start(self, request, protocol, writer): @@ -129,13 +131,14 @@ def _post_start(self, request, protocol, writer): self._writer = writer self._reader = FlowControlDataQueue( request._protocol, limit=2 ** 16, loop=self._loop) - request.protocol.set_parser(WebSocketReader(self._reader)) + request.protocol.set_parser(WebSocketReader( + self._reader, compress=self._compress)) def can_prepare(self, request): if self._writer is not None: raise RuntimeError('Already started') try: - _, _, _, _, protocol = do_handshake( + _, _, _, _, protocol, _ = do_handshake( request.method, request.headers, request._protocol.writer, self._protocols) except HttpProcessingError: @@ -155,6 +158,10 @@ def close_code(self): def ws_protocol(self): return self._ws_protocol + @property + def compress(self): + return self._compress + def exception(self): return self._exception diff --git a/changes/2273.feature b/changes/2273.feature new file mode 100644 index 00000000000..a2da4f5f422 --- /dev/null +++ b/changes/2273.feature @@ -0,0 +1,5 @@ +Add server support for WebSockets Per-Message Deflate. + +Add client option to add deflate compress header in WebSockets request header. +If calling ClientSession.ws_connect() with `compress=15` the client will +support deflate compress negotiation. diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 746d4393343..37788462a33 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -524,6 +524,10 @@ The client session supports the context manager protocol for self closing. :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) + :param int compress: Enable Per-Message Compress Extension support. + 0 for disable, 9 to 15 for window bit support. + Default value is 0. + .. versionadded:: 0.16 Add :meth:`ws_connect`. diff --git a/docs/web_reference.rst b/docs/web_reference.rst index e7cb203c711..f9e14b4f90a 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -862,8 +862,8 @@ WebSocketResponse ^^^^^^^^^^^^^^^^^ .. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \ - autoclose=True, \ - autoping=True, heartbeat=None, protocols=()) + autoclose=True, autoping=True, heartbeat=None, \ + protocols=(), compress=True) Class for handling server-side websockets, inherited from :class:`StreamResponse`. @@ -901,6 +901,9 @@ WebSocketResponse operations. Default value is None (no timeout for receive operation) + :param float compress: Enable per-message deflate extension support. + False for disabled, default value is True. + .. versionadded:: 0.19 The class supports ``async for`` statement for iterating over diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 1f85d1bd89d..e2f2716d5b9 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -507,3 +507,139 @@ def test_ws_connect_non_overlapped_protocols_2(ws_key, loop, key_data): assert res.protocol is None del res + + +@asyncio.coroutine +def test_ws_connect_deflate(loop, ws_key, key_data): + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: 'permessage-deflate', + } + with mock.patch('aiohttp.client.os') as m_os: + with mock.patch('aiohttp.client.ClientSession.get') as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = helpers.create_future(loop) + m_req.return_value.set_result(resp) + + res = yield from aiohttp.ClientSession(loop=loop).ws_connect( + 'http://test.org', compress=15) + + assert res.compress == 15 + assert res.client_notakeover is False + + +@asyncio.coroutine +def test_ws_connect_deflate_server_not_support(loop, ws_key, key_data): + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + with mock.patch('aiohttp.client.os') as m_os: + with mock.patch('aiohttp.client.ClientSession.get') as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = helpers.create_future(loop) + m_req.return_value.set_result(resp) + + res = yield from aiohttp.ClientSession(loop=loop).ws_connect( + 'http://test.org', compress=15) + + assert res.compress == 0 + assert res.client_notakeover is False + + +@asyncio.coroutine +def test_ws_connect_deflate_notakeover(loop, ws_key, key_data): + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: 'permessage-deflate; ' + 'client_no_context_takeover', + } + with mock.patch('aiohttp.client.os') as m_os: + with mock.patch('aiohttp.client.ClientSession.get') as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = helpers.create_future(loop) + m_req.return_value.set_result(resp) + + res = yield from aiohttp.ClientSession(loop=loop).ws_connect( + 'http://test.org', compress=15) + + assert res.compress == 15 + assert res.client_notakeover is True + + +@asyncio.coroutine +def test_ws_connect_deflate_client_wbits(loop, ws_key, key_data): + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: 'permessage-deflate; ' + 'client_max_window_bits=10', + } + with mock.patch('aiohttp.client.os') as m_os: + with mock.patch('aiohttp.client.ClientSession.get') as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = helpers.create_future(loop) + m_req.return_value.set_result(resp) + + res = yield from aiohttp.ClientSession(loop=loop).ws_connect( + 'http://test.org', compress=15) + + assert res.compress == 10 + assert res.client_notakeover is False + + +@asyncio.coroutine +def test_ws_connect_deflate_client_wbits_bad(loop, ws_key, key_data): + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: 'permessage-deflate; ' + 'client_max_window_bits=6', + } + with mock.patch('aiohttp.client.os') as m_os: + with mock.patch('aiohttp.client.ClientSession.get') as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = helpers.create_future(loop) + m_req.return_value.set_result(resp) + + with pytest.raises(client.WSServerHandshakeError): + yield from aiohttp.ClientSession(loop=loop).ws_connect( + 'http://test.org', compress=15) + + +@asyncio.coroutine +def test_ws_connect_deflate_server_ext_bad(loop, ws_key, key_data): + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: hdrs.WEBSOCKET, + hdrs.CONNECTION: hdrs.UPGRADE, + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_EXTENSIONS: 'permessage-deflate; bad', + } + with mock.patch('aiohttp.client.os') as m_os: + with mock.patch('aiohttp.client.ClientSession.get') as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = helpers.create_future(loop) + m_req.return_value.set_result(resp) + + with pytest.raises(client.WSServerHandshakeError): + yield from aiohttp.ClientSession(loop=loop).ws_connect( + 'http://test.org', compress=15) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 95a9fc8754d..e7aa069fc2e 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -616,3 +616,81 @@ def handler(request): yield from resp.receive() assert ping_received + + +@asyncio.coroutine +def test_send_recv_compress(loop, test_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + msg = yield from ws.receive_str() + yield from ws.send_str(msg+'/answer') + yield from ws.close() + return ws + + app = web.Application() + app.router.add_route('GET', '/', handler) + client = yield from test_client(app) + resp = yield from client.ws_connect('/', compress=15) + yield from resp.send_str('ask') + + assert resp.compress == 15 + + data = yield from resp.receive_str() + assert data == 'ask/answer' + + yield from resp.close() + assert resp.get_extra_info('socket') is None + + +@asyncio.coroutine +def test_send_recv_compress_wbits(loop, test_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + msg = yield from ws.receive_str() + yield from ws.send_str(msg+'/answer') + yield from ws.close() + return ws + + app = web.Application() + app.router.add_route('GET', '/', handler) + client = yield from test_client(app) + resp = yield from client.ws_connect('/', compress=9) + yield from resp.send_str('ask') + + # Client indicates supports wbits 15 + # Server supports wbit 15 for decode + assert resp.compress == 15 + + data = yield from resp.receive_str() + assert data == 'ask/answer' + + yield from resp.close() + assert resp.get_extra_info('socket') is None + + +@asyncio.coroutine +def test_send_recv_compress_wbit_error(loop, test_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + msg = yield from ws.receive_bytes() + yield from ws.send_bytes(msg+b'/answer') + yield from ws.close() + return ws + + app = web.Application() + app.router.add_route('GET', '/', handler) + client = yield from test_client(app) + with pytest.raises(ValueError): + yield from client.ws_connect('/', compress=1) diff --git a/tests/test_websocket_handshake.py b/tests/test_websocket_handshake.py index 15ec918cfe6..df45a7292ca 100644 --- a/tests/test_websocket_handshake.py +++ b/tests/test_websocket_handshake.py @@ -26,7 +26,8 @@ def message(): True, None, True, False, URL('/path')) -def gen_ws_headers(protocols=''): +def gen_ws_headers(protocols='', compress=0, extension_text='', + server_notakeover=False, client_notakeover=False): key = base64.b64encode(os.urandom(16)).decode() hdrs = [('Upgrade', 'websocket'), ('Connection', 'upgrade'), @@ -34,6 +35,17 @@ def gen_ws_headers(protocols=''): ('Sec-Websocket-Key', key)] if protocols: hdrs += [('Sec-Websocket-Protocol', protocols)] + if compress: + params = 'permessage-deflate' + if compress < 15: + params += '; server_max_window_bits=' + str(compress) + if server_notakeover: + params += '; server_no_context_takeover' + if client_notakeover: + params += '; client_no_context_takeover' + if extension_text: + params += '; ' + extension_text + hdrs += [('Sec-Websocket-Extensions', params)] return hdrs, key @@ -95,7 +107,7 @@ def test_handshake(message, transport): hdrs, sec_key = gen_ws_headers() message.headers.extend(hdrs) - status, headers, parser, writer, protocol = do_handshake( + status, headers, parser, writer, protocol, _ = do_handshake( message.method, message.headers, transport) assert status == 101 assert protocol is None @@ -111,7 +123,7 @@ def test_handshake_protocol(message, transport): proto = 'chat' message.headers.extend(gen_ws_headers(proto)[0]) - _, resp_headers, _, _, protocol = do_handshake( + _, resp_headers, _, _, protocol, _ = do_handshake( message.method, message.headers, transport, protocols=[proto]) @@ -129,7 +141,7 @@ def test_handshake_protocol_agreement(message, transport): server_protos = 'worse_proto,chat' message.headers.extend(gen_ws_headers(server_protos)[0]) - _, resp_headers, _, _, protocol = do_handshake( + _, resp_headers, _, _, protocol, _ = do_handshake( message.method, message.headers, transport, protocols=wanted_protos) @@ -142,10 +154,124 @@ def test_handshake_protocol_unsupported(log, message, transport): message.headers.extend(gen_ws_headers('test')[0]) with log('aiohttp.websocket') as ctx: - _, _, _, _, protocol = do_handshake( + _, _, _, _, protocol, _ = do_handshake( message.method, message.headers, transport, protocols=[proto]) assert protocol is None assert (ctx.records[-1].msg == 'Client protocols %r don’t overlap server-known ones %r') + + +def test_handshake_compress(message, transport): + hdrs, sec_key = gen_ws_headers(compress=15) + + message.headers.extend(hdrs) + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == 'permessage-deflate' + + assert compress == 15 + + +def test_handshake_compress_server_notakeover(message, transport): + hdrs, sec_key = gen_ws_headers(compress=15, server_notakeover=True) + + message.headers.extend(hdrs) + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == ( + 'permessage-deflate; server_no_context_takeover') + + assert compress == 15 + assert writer.notakeover is True + + +def test_handshake_compress_client_notakeover(message, transport): + hdrs, sec_key = gen_ws_headers(compress=15, client_notakeover=True) + + message.headers.extend(hdrs) + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == ( + 'permessage-deflate'), hdrs + + assert compress == 15 + + +def test_handshake_compress_wbits(message, transport): + hdrs, sec_key = gen_ws_headers(compress=9) + + message.headers.extend(hdrs) + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == ( + 'permessage-deflate; server_max_window_bits=9') + assert compress == 9 + + +def test_handshake_compress_wbits_error(message, transport): + hdrs, sec_key = gen_ws_headers(compress=6) + + message.headers.extend(hdrs) + + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' not in headers + assert compress == 0 + + +def test_handshake_compress_bad_ext(message, transport): + hdrs, sec_key = gen_ws_headers(compress=15, extension_text='bad') + + message.headers.extend(hdrs) + + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' not in headers + assert compress == 0 + + +def test_handshake_compress_multi_ext_bad(message, transport): + hdrs, sec_key = gen_ws_headers(compress=15, + extension_text='bad, permessage-deflate') + + message.headers.extend(hdrs) + + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == 'permessage-deflate' + + +def test_handshake_compress_multi_ext_wbits(message, transport): + hdrs, sec_key = gen_ws_headers(compress=6, + extension_text=', permessage-deflate') + + message.headers.extend(hdrs) + + status, headers, parser, writer, protocol, compress = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == 'permessage-deflate' + assert compress == 15 diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 89b7a7c6ef8..5ef9bd6ad76 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -73,41 +73,41 @@ def parser(out): def test_parse_frame(parser): parser.parse_frame(struct.pack('!BB', 0b00000001, 0b00000001)) res = parser.parse_frame(b'1') - fin, opcode, payload = res[0] + fin, opcode, payload, compress = res[0] - assert (0, 1, b'1') == (fin, opcode, payload) + assert (0, 1, b'1', False) == (fin, opcode, payload, not not compress) def test_parse_frame_length0(parser): - fin, opcode, payload = parser.parse_frame( + fin, opcode, payload, compress = parser.parse_frame( struct.pack('!BB', 0b00000001, 0b00000000))[0] - assert (0, 1, b'') == (fin, opcode, payload) + assert (0, 1, b'', False) == (fin, opcode, payload, not not compress) def test_parse_frame_length2(parser): parser.parse_frame(struct.pack('!BB', 0b00000001, 126)) parser.parse_frame(struct.pack('!H', 4)) res = parser.parse_frame(b'1234') - fin, opcode, payload = res[0] + fin, opcode, payload, compress = res[0] - assert (0, 1, b'1234') == (fin, opcode, payload) + assert (0, 1, b'1234', False) == (fin, opcode, payload, not not compress) def test_parse_frame_length4(parser): parser.parse_frame(struct.pack('!BB', 0b00000001, 127)) parser.parse_frame(struct.pack('!Q', 4)) - fin, opcode, payload = parser.parse_frame(b'1234')[0] + fin, opcode, payload, compress = parser.parse_frame(b'1234')[0] - assert (0, 1, b'1234') == (fin, opcode, payload) + assert (0, 1, b'1234', False) == (fin, opcode, payload, not not compress) def test_parse_frame_mask(parser): parser.parse_frame(struct.pack('!BB', 0b00000001, 0b10000001)) parser.parse_frame(b'0001') - fin, opcode, payload = parser.parse_frame(b'1')[0] + fin, opcode, payload, compress = parser.parse_frame(b'1')[0] - assert (0, 1, b'\x01') == (fin, opcode, payload) + assert (0, 1, b'\x01', False) == (fin, opcode, payload, not not compress) def test_parse_frame_header_reversed_bits(out, parser): @@ -136,7 +136,7 @@ def test_parse_frame_header_payload_size(out, parser): def test_ping_frame(out, parser): parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.PING, b'data')] + parser.parse_frame.return_value = [(1, WSMsgType.PING, b'data', False)] parser.feed_data(b'') res = out._buffer[0] @@ -145,7 +145,7 @@ def test_ping_frame(out, parser): def test_pong_frame(out, parser): parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.PONG, b'data')] + parser.parse_frame.return_value = [(1, WSMsgType.PONG, b'data', False)] parser.feed_data(b'') res = out._buffer[0] @@ -154,7 +154,7 @@ def test_pong_frame(out, parser): def test_close_frame(out, parser): parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'')] + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'', False)] parser.feed_data(b'') res = out._buffer[0] @@ -163,7 +163,7 @@ def test_close_frame(out, parser): def test_close_frame_info(out, parser): parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'0112345')] + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'0112345', False)] parser.feed_data(b'') res = out._buffer[0] @@ -172,7 +172,7 @@ def test_close_frame_info(out, parser): def test_close_frame_invalid(out, parser): parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'1')] + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'1', False)] parser.feed_data(b'') assert isinstance(out.exception(), WebSocketError) @@ -200,7 +200,7 @@ def test_close_frame_unicode_err(parser): def test_unknown_frame(out, parser): parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.CONTINUATION, b'')] + parser.parse_frame.return_value = [(1, WSMsgType.CONTINUATION, b'', False)] with pytest.raises(WebSocketError): parser.feed_data(b'') @@ -225,7 +225,7 @@ def test_simple_text_unicode_err(parser): def test_simple_binary(out, parser): parser.parse_frame = mock.Mock() - parser.parse_frame.return_value = [(1, WSMsgType.BINARY, b'binary')] + parser.parse_frame.return_value = [(1, WSMsgType.BINARY, b'binary', False)] parser.feed_data(b'') res = out._buffer[0] @@ -255,9 +255,9 @@ def test_continuation(out, parser): def test_continuation_with_ping(out, parser): parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.PING, b''), - (1, WSMsgType.CONTINUATION, b'line2'), + (0, WSMsgType.TEXT, b'line1', False), + (0, WSMsgType.PING, b'', False), + (1, WSMsgType.CONTINUATION, b'line2', False), ] data1 = build_frame(b'line1', WSMsgType.TEXT, is_fin=False) @@ -278,8 +278,8 @@ def test_continuation_with_ping(out, parser): def test_continuation_err(out, parser): parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (1, WSMsgType.TEXT, b'line2')] + (0, WSMsgType.TEXT, b'line1', False), + (1, WSMsgType.TEXT, b'line2', False)] with pytest.raises(WebSocketError): parser._feed_data(b'') @@ -288,10 +288,10 @@ def test_continuation_err(out, parser): def test_continuation_with_close(out, parser): parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), + (0, WSMsgType.TEXT, b'line1', False), (0, WSMsgType.CLOSE, - build_close_frame(1002, b'test', noheader=True)), - (1, WSMsgType.CONTINUATION, b'line2'), + build_close_frame(1002, b'test', noheader=True), False), + (1, WSMsgType.CONTINUATION, b'line2', False), ] parser.feed_data(b'') @@ -304,10 +304,10 @@ def test_continuation_with_close(out, parser): def test_continuation_with_close_unicode_err(out, parser): parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), + (0, WSMsgType.TEXT, b'line1', False), (0, WSMsgType.CLOSE, - build_close_frame(1000, b'\xf4\x90\x80\x80', noheader=True)), - (1, WSMsgType.CONTINUATION, b'line2')] + build_close_frame(1000, b'\xf4\x90\x80\x80', noheader=True), False), + (1, WSMsgType.CONTINUATION, b'line2', False)] with pytest.raises(WebSocketError) as ctx: parser._feed_data(b'') @@ -318,10 +318,10 @@ def test_continuation_with_close_unicode_err(out, parser): def test_continuation_with_close_bad_code(out, parser): parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), + (0, WSMsgType.TEXT, b'line1', False), (0, WSMsgType.CLOSE, - build_close_frame(1, b'test', noheader=True)), - (1, WSMsgType.CONTINUATION, b'line2')] + build_close_frame(1, b'test', noheader=True), False), + (1, WSMsgType.CONTINUATION, b'line2', False)] with pytest.raises(WebSocketError) as ctx: parser._feed_data(b'') @@ -332,9 +332,9 @@ def test_continuation_with_close_bad_code(out, parser): def test_continuation_with_close_bad_payload(out, parser): parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.CLOSE, b'1'), - (1, WSMsgType.CONTINUATION, b'line2')] + (0, WSMsgType.TEXT, b'line1', False), + (0, WSMsgType.CLOSE, b'1', False), + (1, WSMsgType.CONTINUATION, b'line2', False)] with pytest.raises(WebSocketError) as ctx: parser._feed_data(b'') @@ -345,9 +345,9 @@ def test_continuation_with_close_bad_payload(out, parser): def test_continuation_with_close_empty(out, parser): parser.parse_frame = mock.Mock() parser.parse_frame.return_value = [ - (0, WSMsgType.TEXT, b'line1'), - (0, WSMsgType.CLOSE, b''), - (1, WSMsgType.CONTINUATION, b'line2'), + (0, WSMsgType.TEXT, b'line1', False), + (0, WSMsgType.CLOSE, b'', False), + (1, WSMsgType.CONTINUATION, b'line2', False), ] parser.feed_data(b'') @@ -403,3 +403,57 @@ def test_msgtype_aliases(): assert aiohttp.WSMsgType.CLOSE == aiohttp.WSMsgType.close assert aiohttp.WSMsgType.CLOSED == aiohttp.WSMsgType.closed assert aiohttp.WSMsgType.ERROR == aiohttp.WSMsgType.error + + +def test_parse_compress_frame_single(parser): + parser.parse_frame(struct.pack('!BB', 0b11000001, 0b00000001)) + res = parser.parse_frame(b'1') + fin, opcode, payload, compress = res[0] + + assert (1, 1, b'1', True) == (fin, opcode, payload, not not compress) + + +def test_parse_compress_frame_multi(parser): + parser.parse_frame(struct.pack('!BB', 0b01000001, 126)) + parser.parse_frame(struct.pack('!H', 4)) + res = parser.parse_frame(b'1234') + fin, opcode, payload, compress = res[0] + assert (0, 1, b'1234', True) == (fin, opcode, payload, not not compress) + + parser.parse_frame(struct.pack('!BB', 0b10000001, 126)) + parser.parse_frame(struct.pack('!H', 4)) + res = parser.parse_frame(b'1234') + fin, opcode, payload, compress = res[0] + assert (1, 1, b'1234', True) == (fin, opcode, payload, not not compress) + + parser.parse_frame(struct.pack('!BB', 0b10000001, 126)) + parser.parse_frame(struct.pack('!H', 4)) + res = parser.parse_frame(b'1234') + fin, opcode, payload, compress = res[0] + assert (1, 1, b'1234', False) == (fin, opcode, payload, not not compress) + + +def test_parse_compress_error_frame(parser): + parser.parse_frame(struct.pack('!BB', 0b01000001, 0b00000001)) + parser.parse_frame(b'1') + + with pytest.raises(WebSocketError) as ctx: + parser.parse_frame(struct.pack('!BB', 0b11000001, 0b00000001)) + parser.parse_frame(b'1') + + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + + +@pytest.fixture() +def parser_no_compress(out): + return WebSocketReader(out, compress=False) + + +def test_parse_no_compress_frame_single(parser_no_compress): + + with pytest.raises(WebSocketError) as ctx: + parser_no_compress.parse_frame(struct.pack( + '!BB', 0b11000001, 0b00000001)) + parser_no_compress.parse_frame(b'1') + + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 59bc9734fbf..342586b4659 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -66,3 +66,19 @@ def test_send_text_masked(stream, writer): random=random.Random(123)) writer.send(b'text') stream.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') + + +def test_send_compress_text(stream, writer): + writer = WebSocketWriter(stream, compress=15) + writer.send(b'text') + stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.send(b'text') + stream.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00') + + +def test_send_compress_text_notakeover(stream, writer): + writer = WebSocketWriter(stream, compress=15, notakeover=True) + writer.send(b'text') + stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.send(b'text') + stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')