diff --git a/aiohttp/client.py b/aiohttp/client.py index 0268d0f43a8..1f0d6ec2ab8 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -413,6 +413,9 @@ def _ws_connect(self, url, *, extstr = ws_ext_gen(compress=compress) if extstr: headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr + else: + raise ValueError('Compress level must between 8 and 15') + # send request resp = yield from self.get(url, headers=headers, @@ -474,19 +477,19 @@ def _ws_connect(self, url, *, # websocket compress notakeover = False if compress: - compress, notakeover = ws_ext_parse( + compress, _, notakeover = ws_ext_parse( resp.headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] ) if compress == 0: pass - elif compress < 0: + elif compress == -1: raise WSServerHandshakeError( resp.request_info, resp.history, message='Invalid deflate extension', code=resp.status, headers=resp.headers) - elif compress < 8 or compress > 15: + elif compress == -2: raise WSServerHandshakeError( resp.request_info, resp.history, diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index f5367d30b46..eb0ae23fea2 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -152,35 +152,37 @@ def _websocket_mask_python(mask, data): def extensions_parse(extstr): if not extstr: - return 0, False + return 0, False, False - extensions = [[s.strip() for s in s1.split(';')] - for s1 in extstr.split(',')] + extensions = [s.strip() for s in extstr.split(',')] compress = 0 - compress_notakeover = False + server_notakeover = False + client_notakeover = False for ext in extensions: - if ext[0] == 'permessage-deflate': + if ext.startswith('permessage-deflate'): compress = 15 - for param in ext[1:]: + for param in [s.strip() for s in ext.split(';')][1:]: if param.startswith('server_max_window_bits'): compress = int(param.split('=')[1]) elif param == 'server_no_context_takeover': - compress_notakeover = True - # Ignore Client Takeover - elif param not in ('client_no_context_takeover', - 'client_max_window_bits'): - return -1, False - if compress > 15: - raise HttpBadRequest( - message='Handshake error: PMCE window > 15') from None + server_notakeover = True + elif param == 'client_no_context_takeover': + client_notakeover = True + # Ignore Client window bits + elif param != 'client_max_window_bits': + return -1, False, False + # compress wbit 8 does not support in zlib + if compress > 15 or compress < 9: + return -2, False, False break - return compress, compress_notakeover + return compress, server_notakeover, client_notakeover def extensions_gen(compress=0, server_notakeover=False, client_notakeover=False): - if compress < 8 or compress > 15: + # compress wbit 8 does not support in zlib + if compress < 9 or compress > 15: return False enabledext = 'permessage-deflate' if compress < 15: @@ -648,18 +650,15 @@ def do_handshake(method, headers, stream, (hdrs.SEC_WEBSOCKET_ACCEPT, base64.b64encode( hashlib.sha1(key.encode() + WS_KEY).digest()).decode())] - compress = 0 - compress_notakeover = False - extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) - compress, compress_notakeover = extensions_parse(extensions) + compress, compress_notakeover, _ = extensions_parse(extensions) if compress: - if compress < 0: + if compress == -1: raise HttpBadRequest( message='Handshake error: PMCE bad extensions') from None - if compress > 15: + if compress == -2: raise HttpBadRequest( - message='Handshake error: PMCE window > 15') from None + message='Handshake error: PMCE window not in range') from None enabledext = extensions_gen(compress=compress, server_notakeover=compress_notakeover) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 82d143345c0..6f813067ace 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -644,3 +644,51 @@ def handler(request): yield from resp.close() assert resp.get_extra_info('socket') is None + + +@asyncio.coroutine +def test_send_recv_compress(loop, test_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + msg = yield from ws.receive_str() + yield from ws.send_str(msg+'/answer') + yield from ws.close() + return ws + + app = web.Application() + app.router.add_route('GET', '/', handler) + client = yield from test_client(app) + resp = yield from client.ws_connect('/', compress=9) + yield from resp.send_str('ask') + + assert resp.compress == 9 + + data = yield from resp.receive_str() + assert data == 'ask/answer' + + yield from resp.close() + assert resp.get_extra_info('socket') is None + + +@asyncio.coroutine +def test_send_recv_compress_error(loop, test_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + msg = yield from ws.receive_bytes() + yield from ws.send_bytes(msg+b'/answer') + yield from ws.close() + return ws + + app = web.Application() + app.router.add_route('GET', '/', handler) + client = yield from test_client(app) + with pytest.raises(ValueError): + resp = yield from client.ws_connect('/', compress=1) diff --git a/tests/test_websocket_handshake.py b/tests/test_websocket_handshake.py index 6e5ea311629..84eb142e263 100644 --- a/tests/test_websocket_handshake.py +++ b/tests/test_websocket_handshake.py @@ -26,7 +26,8 @@ def message(): True, None, True, False, URL('/path')) -def gen_ws_headers(protocols='', compress=0, compress_notakeover=False): +def gen_ws_headers(protocols='', compress=0, + server_notakeover=False, client_notakeover=False): key = base64.b64encode(os.urandom(16)).decode() hdrs = [('Upgrade', 'websocket'), ('Connection', 'upgrade'), @@ -38,8 +39,10 @@ def gen_ws_headers(protocols='', compress=0, compress_notakeover=False): params = 'permessage-deflate' if compress < 15: params += '; server_max_window_bits=' + str(compress) - if compress_notakeover: + if server_notakeover: params += '; server_no_context_takeover' + if client_notakeover: + params += '; client_no_context_takeover' hdrs += [('Sec-Websocket-Extensions', params)] return hdrs, key @@ -172,8 +175,8 @@ def test_handshake_compress(message, transport): assert writer.compress == 15 -def test_handshake_compress_notakeover(message, transport): - hdrs, sec_key = gen_ws_headers(compress=15, compress_notakeover=True) +def test_handshake_compress_server_notakeover(message, transport): + hdrs, sec_key = gen_ws_headers(compress=15, server_notakeover=True) message.headers.extend(hdrs) status, headers, parser, writer, protocol = do_handshake( @@ -186,3 +189,32 @@ def test_handshake_compress_notakeover(message, transport): assert writer.compress == 15 assert writer.notakeover is True + + +def test_handshake_compress_client_notakeover(message, transport): + hdrs, sec_key = gen_ws_headers(compress=15, client_notakeover=True) + + message.headers.extend(hdrs) + status, headers, parser, writer, protocol = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == ( + 'permessage-deflate'), hdrs + + assert writer.compress == 15 + assert writer.notakeover is True + + +def test_handshake_compress_level(message, transport): + hdrs, sec_key = gen_ws_headers(compress=9) + + message.headers.extend(hdrs) + status, headers, parser, writer, protocol = do_handshake( + message.method, message.headers, transport) + + headers = dict(headers) + assert 'Sec-Websocket-Extensions' in headers + assert headers['Sec-Websocket-Extensions'] == ( + 'permessage-deflate; ' + 'server_max_window_bits=9') diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 6ec5550c045..0daca60021c 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -431,3 +431,14 @@ def test_parse_compress_frame_multi(parser): res = parser.parse_frame(b'1234') fin, opcode, payload, compress = res[0] assert (1, 1, b'1234', False) == (fin, opcode, payload, not not compress) + + +def test_parse_compress_error_frame(parser): + parser.parse_frame(struct.pack('!BB', 0b01000001, 0b00000001)) + parser.parse_frame(b'1') + + with pytest.raises(WebSocketError) as ctx: + parser.parse_frame(struct.pack('!BB', 0b11000001, 0b00000001)) + parser.parse_frame(b'1') + + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR