Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
Client deflate support should now more complete
  • Loading branch information
fanthos committed Sep 17, 2017
1 parent 3c6029e commit c55c5c0
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 30 deletions.
9 changes: 6 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ def _ws_connect(self, url, *,
extstr = ws_ext_gen(compress=compress)
if extstr:
headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr
else:
raise ValueError('Compress level must between 8 and 15')


# send request
resp = yield from self.get(url, headers=headers,
Expand Down Expand Up @@ -474,19 +477,19 @@ def _ws_connect(self, url, *,
# websocket compress
notakeover = False
if compress:
compress, notakeover = ws_ext_parse(
compress, _, notakeover = ws_ext_parse(
resp.headers[hdrs.SEC_WEBSOCKET_EXTENSIONS]
)
if compress == 0:
pass
elif compress < 0:
elif compress == -1:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
message='Invalid deflate extension',
code=resp.status,
headers=resp.headers)
elif compress < 8 or compress > 15:
elif compress == -2:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
Expand Down
45 changes: 22 additions & 23 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,35 +152,37 @@ def _websocket_mask_python(mask, data):

def extensions_parse(extstr):
if not extstr:
return 0, False
return 0, False, False

extensions = [[s.strip() for s in s1.split(';')]
for s1 in extstr.split(',')]
extensions = [s.strip() for s in extstr.split(',')]
compress = 0
compress_notakeover = False
server_notakeover = False
client_notakeover = False
for ext in extensions:
if ext[0] == 'permessage-deflate':
if ext.startswith('permessage-deflate'):
compress = 15
for param in ext[1:]:
for param in [s.strip() for s in ext.split(';')][1:]:
if param.startswith('server_max_window_bits'):
compress = int(param.split('=')[1])
elif param == 'server_no_context_takeover':
compress_notakeover = True
# Ignore Client Takeover
elif param not in ('client_no_context_takeover',
'client_max_window_bits'):
return -1, False
if compress > 15:
raise HttpBadRequest(
message='Handshake error: PMCE window > 15') from None
server_notakeover = True
elif param == 'client_no_context_takeover':
client_notakeover = True
# Ignore Client window bits
elif param != 'client_max_window_bits':
return -1, False, False
# compress wbit 8 does not support in zlib
if compress > 15 or compress < 9:
return -2, False, False
break

return compress, compress_notakeover
return compress, server_notakeover, client_notakeover


def extensions_gen(compress=0, server_notakeover=False,
client_notakeover=False):
if compress < 8 or compress > 15:
# compress wbit 8 does not support in zlib
if compress < 9 or compress > 15:
return False
enabledext = 'permessage-deflate'
if compress < 15:
Expand Down Expand Up @@ -648,18 +650,15 @@ def do_handshake(method, headers, stream,
(hdrs.SEC_WEBSOCKET_ACCEPT, base64.b64encode(
hashlib.sha1(key.encode() + WS_KEY).digest()).decode())]

compress = 0
compress_notakeover = False

extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
compress, compress_notakeover = extensions_parse(extensions)
compress, compress_notakeover, _ = extensions_parse(extensions)
if compress:
if compress < 0:
if compress == -1:
raise HttpBadRequest(
message='Handshake error: PMCE bad extensions') from None
if compress > 15:
if compress == -2:
raise HttpBadRequest(
message='Handshake error: PMCE window > 15') from None
message='Handshake error: PMCE window not in range') from None

enabledext = extensions_gen(compress=compress,
server_notakeover=compress_notakeover)
Expand Down
48 changes: 48 additions & 0 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,51 @@ def handler(request):

yield from resp.close()
assert resp.get_extra_info('socket') is None


@asyncio.coroutine
def test_send_recv_compress(loop, test_client):

@asyncio.coroutine
def handler(request):
ws = web.WebSocketResponse()
yield from ws.prepare(request)

msg = yield from ws.receive_str()
yield from ws.send_str(msg+'/answer')
yield from ws.close()
return ws

app = web.Application()
app.router.add_route('GET', '/', handler)
client = yield from test_client(app)
resp = yield from client.ws_connect('/', compress=9)
yield from resp.send_str('ask')

assert resp.compress == 9

data = yield from resp.receive_str()
assert data == 'ask/answer'

yield from resp.close()
assert resp.get_extra_info('socket') is None


@asyncio.coroutine
def test_send_recv_compress_error(loop, test_client):

@asyncio.coroutine
def handler(request):
ws = web.WebSocketResponse()
yield from ws.prepare(request)

msg = yield from ws.receive_bytes()
yield from ws.send_bytes(msg+b'/answer')
yield from ws.close()
return ws

app = web.Application()
app.router.add_route('GET', '/', handler)
client = yield from test_client(app)
with pytest.raises(ValueError):
resp = yield from client.ws_connect('/', compress=1)
40 changes: 36 additions & 4 deletions tests/test_websocket_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def message():
True, None, True, False, URL('/path'))


def gen_ws_headers(protocols='', compress=0, compress_notakeover=False):
def gen_ws_headers(protocols='', compress=0,
server_notakeover=False, client_notakeover=False):
key = base64.b64encode(os.urandom(16)).decode()
hdrs = [('Upgrade', 'websocket'),
('Connection', 'upgrade'),
Expand All @@ -38,8 +39,10 @@ def gen_ws_headers(protocols='', compress=0, compress_notakeover=False):
params = 'permessage-deflate'
if compress < 15:
params += '; server_max_window_bits=' + str(compress)
if compress_notakeover:
if server_notakeover:
params += '; server_no_context_takeover'
if client_notakeover:
params += '; client_no_context_takeover'
hdrs += [('Sec-Websocket-Extensions', params)]
return hdrs, key

Expand Down Expand Up @@ -172,8 +175,8 @@ def test_handshake_compress(message, transport):
assert writer.compress == 15


def test_handshake_compress_notakeover(message, transport):
hdrs, sec_key = gen_ws_headers(compress=15, compress_notakeover=True)
def test_handshake_compress_server_notakeover(message, transport):
hdrs, sec_key = gen_ws_headers(compress=15, server_notakeover=True)

message.headers.extend(hdrs)
status, headers, parser, writer, protocol = do_handshake(
Expand All @@ -186,3 +189,32 @@ def test_handshake_compress_notakeover(message, transport):

assert writer.compress == 15
assert writer.notakeover is True


def test_handshake_compress_client_notakeover(message, transport):
hdrs, sec_key = gen_ws_headers(compress=15, client_notakeover=True)

message.headers.extend(hdrs)
status, headers, parser, writer, protocol = do_handshake(
message.method, message.headers, transport)

headers = dict(headers)
assert 'Sec-Websocket-Extensions' in headers
assert headers['Sec-Websocket-Extensions'] == (
'permessage-deflate'), hdrs

assert writer.compress == 15
assert writer.notakeover is True


def test_handshake_compress_level(message, transport):
hdrs, sec_key = gen_ws_headers(compress=9)

message.headers.extend(hdrs)
status, headers, parser, writer, protocol = do_handshake(
message.method, message.headers, transport)

headers = dict(headers)
assert 'Sec-Websocket-Extensions' in headers
assert headers['Sec-Websocket-Extensions'] == (
'permessage-deflate; ' + 'server_max_window_bits=9')
11 changes: 11 additions & 0 deletions tests/test_websocket_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,14 @@ def test_parse_compress_frame_multi(parser):
res = parser.parse_frame(b'1234')
fin, opcode, payload, compress = res[0]
assert (1, 1, b'1234', False) == (fin, opcode, payload, not not compress)


def test_parse_compress_error_frame(parser):
parser.parse_frame(struct.pack('!BB', 0b01000001, 0b00000001))
parser.parse_frame(b'1')

with pytest.raises(WebSocketError) as ctx:
parser.parse_frame(struct.pack('!BB', 0b11000001, 0b00000001))
parser.parse_frame(b'1')

assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR

0 comments on commit c55c5c0

Please sign in to comment.