From c81e3e093873eadbb82f8b71a5c035d5a1fc4f8a Mon Sep 17 00:00:00 2001 From: Alexander Mohr Date: Wed, 22 Aug 2018 13:01:21 -0700 Subject: [PATCH] add thread helper from https://github.com/aio-libs/aiohttp/pull/3205 --- aiohttp/web_response.py | 68 +++++++++++++++++++++++++++++--------- tests/test_web_response.py | 43 +++++++++++++++++++++++- 2 files changed, 95 insertions(+), 16 deletions(-) diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index ad9ea0c2a51..9411228654b 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -1,3 +1,4 @@ +import asyncio import collections import datetime import enum @@ -16,7 +17,8 @@ from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11 -__all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response') +__all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response', + 'async_json_response') class ContentCoding(enum.Enum): @@ -271,7 +273,7 @@ def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE): ctype = self._content_type self.headers[CONTENT_TYPE] = ctype - def _do_start_compression(self, coding): + async def _do_start_compression(self, coding): if coding != ContentCoding.identity: self.headers[hdrs.CONTENT_ENCODING] = coding.value self._payload_writer.enable_compression(coding.value) @@ -279,15 +281,15 @@ def _do_start_compression(self, coding): # remove the header self._headers.popall(hdrs.CONTENT_LENGTH, None) - def _start_compression(self, request): + async def _start_compression(self, request): if self._compression_force: - self._do_start_compression(self._compression_force) + await self._do_start_compression(self._compression_force) else: accept_encoding = request.headers.get( hdrs.ACCEPT_ENCODING, '').lower() for coding in ContentCoding: if coding.value in accept_encoding: - self._do_start_compression(coding) + await self._do_start_compression(coding) return async def prepare(self, request): @@ -326,7 +328,7 @@ async def _start(self, request, headers.add(SET_COOKIE, value) if self._compression: - self._start_compression(request) + await self._start_compression(request) if self._chunked: if version != HttpVersion11: @@ -437,7 +439,7 @@ class Response(StreamResponse): def __init__(self, *, body=None, status=200, reason=None, text=None, headers=None, content_type=None, - charset=None): + charset=None, zlib_thread_size=None): if body is not None and text is not None: raise ValueError("body and text are not allowed together") @@ -489,6 +491,7 @@ def __init__(self, *, body=None, status=200, self.body = body self._compressed_body = None + self._zlib_thread_size = zlib_thread_size @property def body(self): @@ -575,7 +578,9 @@ def content_length(self): def content_length(self, value): raise RuntimeError("Content length is set automatically") - async def write_eof(self): + async def write_eof(self, data=b''): + assert not data + if self._eof_sent: return if self._compressed_body is not None: @@ -594,7 +599,7 @@ async def write_eof(self): else: await super().write_eof() - async def _start(self, request): + async def _start(self, request, *args, **kwargs): if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: if not self._body_payload: if self._body is not None: @@ -602,19 +607,29 @@ async def _start(self, request): else: self._headers[hdrs.CONTENT_LENGTH] = '0' - return await super()._start(request) + return await super()._start(request, *args, **kwargs) - def _do_start_compression(self, coding): + def _compress_body(self, zlib_mode): + compressobj = zlib.compressobj(wbits=zlib_mode) + self._compressed_body = \ + compressobj.compress(self._body) + compressobj.flush() + + async def _do_start_compression(self, coding): if self._body_payload or self._chunked: - return super()._do_start_compression(coding) + return await super()._do_start_compression(coding) if coding != ContentCoding.identity: # Instead of using _payload_writer.enable_compression, # compress the whole body zlib_mode = (16 + zlib.MAX_WBITS if coding.value == 'gzip' else -zlib.MAX_WBITS) - compressobj = zlib.compressobj(wbits=zlib_mode) - self._compressed_body = compressobj.compress(self._body) +\ - compressobj.flush() + + if self._zlib_thread_size is not None and \ + len(self._body) > self._zlib_thread_size: + await asyncio.get_event_loop().run_in_executor( + None, self._compress_body, zlib_mode) + else: + self._compress_body(zlib_mode) + self._headers[hdrs.CONTENT_ENCODING] = coding.value self._headers[hdrs.CONTENT_LENGTH] = \ str(len(self._compressed_body)) @@ -632,3 +647,26 @@ def json_response(data=sentinel, *, text=None, body=None, status=200, text = dumps(data) return Response(text=text, body=body, status=status, reason=reason, headers=headers, content_type=content_type) + + +async def async_json_response(data=sentinel, *, text=None, body=None, + status=200, reason=None, headers=None, + content_type='application/json', + dumps=json.dumps, executor_body_size=None): + if data is not sentinel: + if text or body: + raise ValueError( + "only one of data, text, or body should be specified" + ) + else: + if asyncio.iscoroutinefunction(dumps): + text = await dumps(data) + elif executor_body_size is not None and \ + len(data) > executor_body_size: + loop = asyncio.get_event_loop() + text = await loop.run_in_executor(None, dumps, data) + else: + text = dumps(data) + return Response(text=text, body=body, status=status, reason=reason, + headers=headers, content_type=content_type, + zlib_thread_size=executor_body_size) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 12a7f25deb2..161da97fc6f 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -1,5 +1,6 @@ import collections import datetime +import gzip import json import re from unittest import mock @@ -10,7 +11,8 @@ from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs, signals from aiohttp.payload import BytesPayload from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.web import ContentCoding, Response, StreamResponse, json_response +from aiohttp.web import ContentCoding, Response, StreamResponse,\ + json_response, async_json_response def make_request(method, path, headers=CIMultiDict(), @@ -408,6 +410,17 @@ async def test_force_compression_no_accept_gzip(): assert 'gzip' == resp.headers.get(hdrs.CONTENT_ENCODING) +async def test_change_content_threaded_compression_enabled(): + req = make_request('GET', '/') + body_thread_size = 1024 + body = b'answer' * body_thread_size + resp = Response(body=body, + zlib_thread_size=body_thread_size) + resp.enable_compression(ContentCoding.gzip) + await resp.prepare(req) + assert gzip.decompress(resp._compressed_body) == body + + async def test_change_content_length_if_compression_enabled(): req = make_request('GET', '/') resp = Response(body=b'answer') @@ -1095,6 +1108,34 @@ def test_response_with_content_length_header_without_body(): assert resp.content_length == 123 +async def test_async_json_small_response(): + text = 'jaysawn' + resp = await async_json_response(text=json.dumps(text)) + assert resp.text == json.dumps(text) + + resp = await async_json_response(text) + assert resp.text == json.dumps(text) + + with pytest.raises(ValueError): + await async_json_response(text, body=text) + + +async def test_async_json_large_response(): + cuttoff_length = 1024 + text = 'ja' * cuttoff_length + resp = await async_json_response(text, executor_body_size=cuttoff_length) + assert resp.text == json.dumps(text) + + +async def test_async_json_coro_response(): + async def dumps(data): + return json.dumps(data) + + text = 'jaysawn' + resp = await async_json_response(text, dumps=dumps) + assert resp.text == json.dumps(text) + + class TestJSONResponse: def test_content_type_is_application_json_by_default(self):