From 6574aa6810259e367069270bf2a554e44500ce50 Mon Sep 17 00:00:00 2001
From: "J. Nick Koston"
Date: Wed, 22 Nov 2023 12:41:52 +0100
Subject: [PATCH] Restore async concurrency safety to websocket compressor

---
 CHANGES/7865.bugfix            |  1 +
 aiohttp/compression_utils.py   | 22 +++++----
 aiohttp/http_websocket.py      | 82 +++++++++++++++++++++++-----------
 tests/test_websocket_writer.py | 61 ++++++++++++++++++++++++-
 4 files changed, 130 insertions(+), 36 deletions(-)
 create mode 100644 CHANGES/7865.bugfix

diff --git a/CHANGES/7865.bugfix b/CHANGES/7865.bugfix
new file mode 100644
index 00000000000..9a46e124486
--- /dev/null
+++ b/CHANGES/7865.bugfix
@@ -0,0 +1 @@
+Restore async concurrency safety to websocket compressor
diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py
index 52791fe5015..d75a83913d8 100644
--- a/aiohttp/compression_utils.py
+++ b/aiohttp/compression_utils.py
@@ -37,6 +37,12 @@ def __init__(
         self._executor = executor
         self._max_sync_chunk_size = max_sync_chunk_size
 
+    def should_run_in_executor(self, data: bytes) -> bool:
+        return (
+            self._max_sync_chunk_size is not None
+            and len(data) > self._max_sync_chunk_size
+        )
+
 
 class ZLibCompressor(ZlibBaseHandler):
     def __init__(
@@ -48,7 +54,7 @@ def __init__(
         strategy: int = zlib.Z_DEFAULT_STRATEGY,
         executor: Optional[Executor] = None,
         max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
-    ):
+    ) -> None:
         super().__init__(
             mode=encoding_to_mode(encoding, suppress_deflate_header)
             if wbits is None
@@ -66,14 +72,14 @@ def __init__(
     def compress_sync(self, data: bytes) -> bytes:
         return self._compressor.compress(data)
 
+    async def compress_executor(self, data: bytes) -> bytes:
+        return await asyncio.get_event_loop().run_in_executor(
+            self._executor, self.compress_sync, data
+        )
+
     async def compress(self, data: bytes) -> bytes:
-        if (
-            self._max_sync_chunk_size is not None
-            and len(data) > self._max_sync_chunk_size
-        ):
-            return await asyncio.get_event_loop().run_in_executor(
-                self._executor, self.compress_sync, data
-            )
+        if self.should_run_in_executor(data):
+            return await self.compress_executor(data)
         return self.compress_sync(data)
 
     def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py
index 3f124e177c1..ce3e1f01600 100644
--- a/aiohttp/http_websocket.py
+++ b/aiohttp/http_websocket.py
@@ -610,6 +610,7 @@ def __init__(
         self._limit = limit
         self._output_size = 0
         self._compressobj: Any = None  # actually compressobj
+        self._compress_lock: Optional[asyncio.Lock] = None
 
     async def _send_frame(
         self, message: bytes, opcode: int, compress: Optional[int] = None
@@ -626,28 +627,57 @@ async def _send_frame(
         if (compress or self.compress) and opcode < 8:
             if compress:
                 # Do not set self._compress if compressing is for this frame
-                compressobj = ZLibCompressor(
-                    level=zlib.Z_BEST_SPEED,
-                    wbits=-compress,
-                    max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
-                )
+                compressobj = self._make_compress_obj(compress)
             else:  # self.compress
                 if not self._compressobj:
-                    self._compressobj = ZLibCompressor(
-                        level=zlib.Z_BEST_SPEED,
-                        wbits=-self.compress,
-                        max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
-                    )
+                    self._compressobj = self._make_compress_obj(self.compress)
                 compressobj = self._compressobj
 
-            message = await compressobj.compress(message)
-            message += compressobj.flush(
-                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
-            )
-            if message.endswith(_WS_DEFLATE_TRAILING):
-                message = message[:-4]
-            rsv = rsv | 0x40
+            if not compressobj.should_run_in_executor(message):
+                # Compress the message without awaiting when it is below the
+                # max sync chunk size so the frame is written before the next
+                # message can be compressed.
+                message = compressobj.compress_sync(message)
+                self._write_compressed_message(message, opcode, compressobj, rsv)
+            else:
+                # Since we are compressing in an executor and the await returns
+                # control to the event loop, we need to hold a lock so the
+                # compressed message is written before the next message is
+                # compressed. This keeps the messages in the correct order and
+                # avoids violating context takeover.
+                if not self._compress_lock:
+                    self._compress_lock = asyncio.Lock()
+                async with self._compress_lock:
+                    message = await compressobj.compress(message)
+                    self._write_compressed_message(message, opcode, compressobj, rsv)
+
+        else:
+            self._write_message(message, opcode, rsv)
+
+        if self._output_size > self._limit:
+            self._output_size = 0
+            await self.protocol._drain_helper()
+
+    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
+        return ZLibCompressor(
+            level=zlib.Z_BEST_SPEED,
+            wbits=-compress,
+            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
+        )
+
+    def _write_compressed_message(
+        self, message: bytes, opcode: int, compressobj: ZLibCompressor, rsv: int
+    ) -> None:
+        message += compressobj.flush(
+            zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
+        )
+        if message.endswith(_WS_DEFLATE_TRAILING):
+            message = message[:-4]
+        rsv = rsv | 0x40
+        self._write_message(message, opcode, rsv)
+
+    def _write_message(self, message: bytes, opcode: int, rsv: int) -> None:
+        """Write a complete websocket frame to the transport."""
         msg_length = len(message)
         use_mask = self.use_mask
@@ -662,6 +692,7 @@ async def _send_frame(
             header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
         else:
             header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
+
         if use_mask:
             mask = self.randrange(0, 0xFFFFFFFF)
             mask = mask.to_bytes(4, "big")
@@ -669,18 +700,15 @@ async def _send_frame(
             _websocket_mask(mask, message)
             self._write(header + mask + message)
             self._output_size += len(header) + len(mask) + len(message)
-        else:
-            if len(message) > MSG_SIZE:
-                self._write(header)
-                self._write(message)
-            else:
-                self._write(header + message)
+            return
 
-            self._output_size += len(header) + len(message)
+        if len(message) > MSG_SIZE:
+            self._write(header)
+            self._write(message)
+        else:
+            self._write(header + message)
 
-        if self._output_size > self._limit:
-            self._output_size = 0
-            await self.protocol._drain_helper()
+        self._output_size += len(header) + len(message)
 
     def _write(self, data: bytes) -> None:
         if self.transport.is_closing():
diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py
index 82bed546170..57ff74cab02 100644
--- a/tests/test_websocket_writer.py
+++ b/tests/test_websocket_writer.py
@@ -1,11 +1,13 @@
 # type: ignore
+import asyncio
 import random
 from typing import Any
 from unittest import mock
 
 import pytest
 
-from aiohttp.http import WebSocketWriter
+from aiohttp import DataQueue, WSMessage
+from aiohttp.http import WebSocketReader, WebSocketWriter
 from aiohttp.test_utils import make_mocked_coro
 
 
@@ -106,3 +108,60 @@ async def test_send_compress_text_per_message(protocol: Any, transport: Any) ->
     writer.transport.write.assert_called_with(b"\x81\x04text")
     await writer.send(b"text", compress=15)
     writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00")
+
+
+@mock.patch("aiohttp.http_websocket.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 16)
+async def test_concurrent_messages_with_executor(protocol: Any, transport: Any) -> None:
+    """Ensure messages are compressed correctly when there are multiple concurrent writers.
+
+    This test generates messages large enough that they will
+    be compressed in the executor.
+    """
+    writer = WebSocketWriter(protocol, transport, compress=15)
+    queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
+    reader = WebSocketReader(queue, 50000)
+    writers = []
+    payloads = []
+    msg_length = 16 + 1
+    for count in range(1, 64 + 1):
+        payload = bytes((count,)) * msg_length
+        payloads.append(payload)
+        writers.append(writer.send(payload, binary=True))
+    await asyncio.gather(*writers)
+    for call in writer.transport.write.call_args_list:
+        call_bytes = call[0][0]
+        result, _ = reader.feed_data(call_bytes)
+        assert result is False
+        msg = await queue.read()
+        bytes_data: bytes = msg.data
+        assert len(bytes_data) == msg_length
+        assert bytes_data == bytes_data[0:1] * msg_length
+
+
+async def test_concurrent_messages_without_executor(
+    protocol: Any, transport: Any
+) -> None:
+    """Ensure messages are compressed correctly when there are multiple concurrent writers.
+
+    This test generates messages that are small enough that
+    they will not be compressed in the executor.
+    """
+    writer = WebSocketWriter(protocol, transport, compress=15)
+    queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
+    reader = WebSocketReader(queue, 50000)
+    writers = []
+    payloads = []
+    msg_length = 16 + 1
+    for count in range(1, 64 + 1):
+        payload = bytes((count,)) * msg_length
+        payloads.append(payload)
+        writers.append(writer.send(payload, binary=True))
+    await asyncio.gather(*writers)
+    for call in writer.transport.write.call_args_list:
+        call_bytes = call[0][0]
+        result, _ = reader.feed_data(call_bytes)
+        assert result is False
+        msg = await queue.read()
+        bytes_data: bytes = msg.data
+        assert len(bytes_data) == msg_length
+        assert bytes_data == bytes_data[0:1] * msg_length