Skip to content

Commit

Permalink
Restore async concurrency safety to websocket compressor
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Nov 22, 2023
1 parent ccf74bb commit 6574aa6
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGES/7865.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Restore async concurrency safety to websocket compressor
22 changes: 14 additions & 8 deletions aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def __init__(
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size

def should_run_in_executor(self, data: bytes) -> bool:
    """Return True when *data* is large enough to compress off the event loop.

    A ``None`` threshold disables executor offloading entirely, so this
    always returns False in that case.
    """
    threshold = self._max_sync_chunk_size
    return threshold is not None and len(data) > threshold


class ZLibCompressor(ZlibBaseHandler):
def __init__(
Expand All @@ -48,7 +54,7 @@ def __init__(
strategy: int = zlib.Z_DEFAULT_STRATEGY,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
) -> None:
super().__init__(
mode=encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
Expand All @@ -66,14 +72,14 @@ def __init__(
def compress_sync(self, data: bytes) -> bytes:
return self._compressor.compress(data)

async def compress_executor(self, data: bytes) -> bytes:
    """Compress *data* in the executor so the event loop is not blocked.

    Intended for payloads above the sync chunk threshold; delegates the
    actual work to ``compress_sync`` on ``self._executor`` (``None``
    means the loop's default executor).
    """
    # get_running_loop() is the correct call inside a coroutine; the
    # implicit-loop behavior of get_event_loop() here is deprecated.
    return await asyncio.get_running_loop().run_in_executor(
        self._executor, self.compress_sync, data
    )

async def compress(self, data: bytes) -> bytes:
    """Compress *data*, offloading large payloads to the executor.

    Small payloads are compressed synchronously to avoid the latency of
    an executor round trip; anything above the sync chunk threshold is
    delegated to ``compress_executor``.
    """
    # Consolidated: the inline size check duplicated by the diff residue
    # is expressed once via should_run_in_executor().
    if self.should_run_in_executor(data):
        return await self.compress_executor(data)
    return self.compress_sync(data)

def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
Expand Down
82 changes: 55 additions & 27 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ def __init__(
self._limit = limit
self._output_size = 0
self._compressobj: Any = None # actually compressobj
self._compress_lock: Optional[asyncio.Lock] = None

async def _send_frame(
self, message: bytes, opcode: int, compress: Optional[int] = None
Expand All @@ -626,28 +627,57 @@ async def _send_frame(
if (compress or self.compress) and opcode < 8:
if compress:
# Do not set self._compress if compressing is for this frame
compressobj = ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
compressobj = self._make_compress_obj(compress)
else: # self.compress
if not self._compressobj:
self._compressobj = ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-self.compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
self._compressobj = self._make_compress_obj(self.compress)
compressobj = self._compressobj

message = await compressobj.compress(message)
message += compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
)
if message.endswith(_WS_DEFLATE_TRAILING):
message = message[:-4]
rsv = rsv | 0x40
if not compressobj.should_run_in_executor(message):
# Compress message if it is smaller than max sync chunk size
# without awaiting to ensure that the message written before
# the next message can be compressed.
message = compressobj.compress_sync(message)
self._write_compressed_message(message, opcode, compressobj, rsv)
else:
# Since we are compressing in an executor, and the await returns
# control to the event loop we need to hold a lock to ensure that
# the compressed message is written before the next message is
# compressed to ensure that the messages are written in the correct
# order and context takeover is not violated.
if not self._compress_lock:
self._compress_lock = asyncio.Lock()
async with self._compress_lock:
message = await compressobj.compress(message)
self._write_compressed_message(message, opcode, compressobj, rsv)

else:
self._write_message(message, opcode, rsv)

if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()

def _make_compress_obj(self, compress: int) -> ZLibCompressor:
    """Build a websocket deflate compressor for the given window size."""
    # Negative wbits selects a raw deflate stream (no zlib header),
    # as required for permessage-deflate frames.
    window_bits = -compress
    return ZLibCompressor(
        level=zlib.Z_BEST_SPEED,
        wbits=window_bits,
        max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
    )

def _write_compressed_message(
    self, message: bytes, opcode: int, compressobj: ZLibCompressor, rsv: int
) -> None:
    """Flush the compressor, strip the deflate trailer, and write the frame."""
    # A full flush resets compressor state when context takeover is
    # disabled; otherwise a sync flush preserves the shared window.
    flush_mode = zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
    message += compressobj.flush(flush_mode)
    # Drop the 4-byte sync-flush trailer before transmission.
    if message.endswith(_WS_DEFLATE_TRAILING):
        message = message[:-4]
    # 0x40 marks the frame as compressed (RSV1 bit).
    self._write_message(message, opcode, rsv | 0x40)

def _write_message(self, message: bytes, opcode: int, rsv: int) -> None:
"""Write a complete websocket frame to the transport."""
msg_length = len(message)

use_mask = self.use_mask
Expand All @@ -662,25 +692,23 @@ async def _send_frame(
header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
else:
header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)

if use_mask:
mask = self.randrange(0, 0xFFFFFFFF)
mask = mask.to_bytes(4, "big")
message = bytearray(message)
_websocket_mask(mask, message)
self._write(header + mask + message)
self._output_size += len(header) + len(mask) + len(message)
else:
if len(message) > MSG_SIZE:
self._write(header)
self._write(message)
else:
self._write(header + message)
return

self._output_size += len(header) + len(message)
if len(message) > MSG_SIZE:
self._write(header)
self._write(message)
else:
self._write(header + message)

if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()
self._output_size += len(header) + len(message)

def _write(self, data: bytes) -> None:
if self.transport.is_closing():
Expand Down
61 changes: 60 additions & 1 deletion tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# type: ignore
import asyncio
import random
from typing import Any
from unittest import mock

import pytest

from aiohttp.http import WebSocketWriter
from aiohttp import DataQueue, WSMessage
from aiohttp.http import WebSocketReader, WebSocketWriter
from aiohttp.test_utils import make_mocked_coro


Expand Down Expand Up @@ -106,3 +108,60 @@ async def test_send_compress_text_per_message(protocol: Any, transport: Any) ->
writer.transport.write.assert_called_with(b"\x81\x04text")
await writer.send(b"text", compress=15)
writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00")


@mock.patch("aiohttp.http_websocket.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 16)
async def test_concurrent_messages_with_executor(protocol: Any, transport: Any) -> None:
    """Ensure messages are compressed correctly with multiple concurrent writers.

    This test generates messages large enough (above the patched 16-byte
    threshold) that they will be compressed in the executor.
    """
    writer = WebSocketWriter(protocol, transport, compress=15)
    queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
    reader = WebSocketReader(queue, 50000)
    writers = []
    msg_length = 16 + 1  # one byte above the patched executor threshold
    for count in range(1, 64 + 1):
        payload = bytes((count,)) * msg_length
        writers.append(writer.send(payload, binary=True))
    await asyncio.gather(*writers)
    # Feed every frame back through the reader: each message must
    # round-trip intact and in order (context takeover not violated).
    for call in writer.transport.write.call_args_list:
        call_bytes = call[0][0]
        result, _ = reader.feed_data(call_bytes)
        assert result is False
        msg = await queue.read()
        bytes_data: bytes = msg.data
        assert len(bytes_data) == msg_length
        assert bytes_data == bytes_data[0:1] * msg_length


async def test_concurrent_messages_without_executor(
    protocol: Any, transport: Any
) -> None:
    """Ensure messages are compressed correctly with multiple concurrent writers.

    This test generates messages that are small enough that they will
    not be compressed in the executor.
    """
    writer = WebSocketWriter(protocol, transport, compress=15)
    queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
    reader = WebSocketReader(queue, 50000)
    writers = []
    msg_length = 16 + 1  # below the default executor threshold
    for count in range(1, 64 + 1):
        payload = bytes((count,)) * msg_length
        writers.append(writer.send(payload, binary=True))
    await asyncio.gather(*writers)
    # Feed every frame back through the reader: each message must
    # round-trip intact and in order.
    for call in writer.transport.write.call_args_list:
        call_bytes = call[0][0]
        result, _ = reader.feed_data(call_bytes)
        assert result is False
        msg = await queue.read()
        bytes_data: bytes = msg.data
        assert len(bytes_data) == msg_length
        assert bytes_data == bytes_data[0:1] * msg_length

0 comments on commit 6574aa6

Please sign in to comment.