From 1b88af2c3f5e5f992d0015f90927fd9c3e00bef0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 19 Aug 2024 11:33:43 -0500 Subject: [PATCH] Improve performance of WebSocketReader (#8736) --- CHANGES/8736.misc.rst | 1 + aiohttp/http_websocket.py | 195 ++++++++++++++++++++------------------ 2 files changed, 103 insertions(+), 93 deletions(-) create mode 100644 CHANGES/8736.misc.rst diff --git a/CHANGES/8736.misc.rst b/CHANGES/8736.misc.rst new file mode 100644 index 00000000000..34ed19aebba --- /dev/null +++ b/CHANGES/8736.misc.rst @@ -0,0 +1 @@ +Improved performance of the WebSocket reader -- by :user:`bdraco`. diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index acceb9bd293..10a90aa3106 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -85,6 +85,12 @@ class WSMsgType(IntEnum): ERROR = 0x102 +MESSAGE_TYPES_WITH_CONTENT: Final = ( + WSMsgType.BINARY, + WSMsgType.TEXT, + WSMsgType.CONTINUATION, +) + WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" @@ -304,17 +310,99 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]: return True, data try: - return self._feed_data(data) + self._feed_data(data) except Exception as exc: self._exc = exc set_exception(self.queue, exc) return True, b"" - def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: + return False, b"" + + def _feed_data(self, data: bytes) -> None: for fin, opcode, payload, compressed in self.parse_frame(data): - if compressed and not self._decompressobj: - self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) - if opcode == WSMsgType.CLOSE: + if opcode in MESSAGE_TYPES_WITH_CONTENT: + # load text/binary + is_continuation = opcode == WSMsgType.CONTINUATION + if not fin: + # got partial frame payload + if not is_continuation: + self._opcode = opcode + self._partial += payload + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) + continue + + has_partial = bool(self._partial) + if is_continuation: + if self._opcode is None: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Continuation frame for non started message", + ) + opcode = self._opcode + self._opcode = None + # previous frame was non finished + # we should get continuation opcode + elif has_partial: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "The opcode in non-fin frame is expected " + "to be zero, got {!r}".format(opcode), + ) + + if has_partial: + assembled_payload = self._partial + payload + self._partial.clear() + else: + assembled_payload = payload + + if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(assembled_payload), self._max_msg_size + ), + ) + + # Decompress process must to be done after all packets + # received. + if compressed: + if not self._decompressobj: + self._decompressobj = ZLibDecompressor( + suppress_deflate_header=True + ) + payload_merged = self._decompressobj.decompress_sync( + assembled_payload + _WS_DEFLATE_TRAILING, self._max_msg_size + ) + if self._decompressobj.unconsumed_tail: + left = len(self._decompressobj.unconsumed_tail) + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Decompressed message size {} exceeds limit {}".format( + self._max_msg_size + left, self._max_msg_size + ), + ) + else: + payload_merged = bytes(assembled_payload) + + if opcode == WSMsgType.TEXT: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, "")) + continue + + self.queue.feed_data(WSMessage(WSMsgType.BINARY, payload_merged, "")) + elif opcode == WSMsgType.CLOSE: if len(payload) >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: @@ -345,87 +433,10 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: elif opcode == WSMsgType.PONG: self.queue.feed_data(WSMessage(WSMsgType.PONG, payload, "")) - elif ( - opcode not in (WSMsgType.TEXT, WSMsgType.BINARY) - and self._opcode is None - ): + else: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" ) - else: - # load text/binary - if not fin: - # got partial frame payload - if opcode != WSMsgType.CONTINUATION: - self._opcode = opcode - self._partial.extend(payload) - if self._max_msg_size and len(self._partial) >= self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(self._partial), self._max_msg_size - ), - ) - else: - # previous frame was non finished - # we should get continuation opcode - if self._partial: - if opcode != WSMsgType.CONTINUATION: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "The opcode in non-fin frame is expected " - "to be zero, got {!r}".format(opcode), - ) - - if opcode == WSMsgType.CONTINUATION: - assert self._opcode is not None - opcode = self._opcode - self._opcode = None - - self._partial.extend(payload) - if self._max_msg_size and len(self._partial) >= self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(self._partial), self._max_msg_size - ), - ) - - # Decompress process must to be done after all packets - # received. - if compressed: - assert self._decompressobj is not None - self._partial.extend(_WS_DEFLATE_TRAILING) - payload_merged = self._decompressobj.decompress_sync( - self._partial, self._max_msg_size - ) - if self._decompressobj.unconsumed_tail: - left = len(self._decompressobj.unconsumed_tail) - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Decompressed message size {} exceeds limit {}".format( - self._max_msg_size + left, self._max_msg_size - ), - ) - else: - payload_merged = bytes(self._partial) - - self._partial.clear() - - if opcode == WSMsgType.TEXT: - try: - text = payload_merged.decode("utf-8") - self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, "")) - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - else: - self.queue.feed_data( - WSMessage(WSMsgType.BINARY, payload_merged, ""), - ) - - return False, b"" def parse_frame( self, buf: bytes @@ -505,23 +516,21 @@ def parse_frame( # read payload length if self._state is WSParserState.READ_PAYLOAD_LENGTH: - length = self._payload_length_flag - if length == 126: + length_flag = self._payload_length_flag + if length_flag == 126: if buf_length - start_pos < 2: break data = buf[start_pos : start_pos + 2] start_pos += 2 - length = UNPACK_LEN2(data)[0] - self._payload_length = length - elif length > 126: + self._payload_length = UNPACK_LEN2(data)[0] + elif length_flag > 126: if buf_length - start_pos < 8: break data = buf[start_pos : start_pos + 8] start_pos += 8 - length = UNPACK_LEN3(data)[0] - self._payload_length = length + self._payload_length = UNPACK_LEN3(data)[0] else: - self._payload_length = length + self._payload_length = length_flag self._state = ( WSParserState.READ_PAYLOAD_MASK @@ -544,11 +553,11 @@ def parse_frame( chunk_len = buf_length - start_pos if length >= chunk_len: self._payload_length = length - chunk_len - payload.extend(buf[start_pos:]) + payload += buf[start_pos:] start_pos = buf_length else: self._payload_length = 0 - payload.extend(buf[start_pos : start_pos + length]) + payload += buf[start_pos : start_pos + length] start_pos = start_pos + length if self._payload_length != 0: