add thread helper
thehesiod committed Aug 22, 2018
1 parent d203426 commit c81e3e0
Showing 2 changed files with 95 additions and 16 deletions.
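
For orientation, a minimal usage sketch of the helper this commit adds. The handler, route, and 64 KiB threshold are illustrative assumptions; `async_json_response`, `executor_body_size`, and `zlib_thread_size` come from the diff below, and the import assumes a build that includes this patch:

```python
from aiohttp import web
from aiohttp.web import async_json_response  # added by this commit (assumed export)


async def handler(request):
    # Hypothetical large payload. With executor_body_size set, dumps() is run
    # in the default thread pool when len(data) exceeds the threshold, and the
    # same value is forwarded to Response as zlib_thread_size so compressing a
    # large body (when compression is enabled) is offloaded as well.
    payload = list(range(100000))
    return await async_json_response(payload, executor_body_size=64 * 1024)


app = web.Application()
app.router.add_get('/big', handler)

if __name__ == '__main__':
    web.run_app(app)
```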
68 changes: 53 additions & 15 deletions aiohttp/web_response.py
@@ -1,3 +1,4 @@
+import asyncio
 import collections
 import datetime
 import enum
@@ -16,7 +17,8 @@
 from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11
 
 
-__all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response')
+__all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response',
+           'async_json_response')
 
 
 class ContentCoding(enum.Enum):
@@ -271,23 +273,23 @@ def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
         ctype = self._content_type
         self.headers[CONTENT_TYPE] = ctype
 
-    def _do_start_compression(self, coding):
+    async def _do_start_compression(self, coding):
         if coding != ContentCoding.identity:
             self.headers[hdrs.CONTENT_ENCODING] = coding.value
             self._payload_writer.enable_compression(coding.value)
             # Compressed payload may have different content length,
             # remove the header
             self._headers.popall(hdrs.CONTENT_LENGTH, None)
 
-    def _start_compression(self, request):
+    async def _start_compression(self, request):
         if self._compression_force:
-            self._do_start_compression(self._compression_force)
+            await self._do_start_compression(self._compression_force)
         else:
             accept_encoding = request.headers.get(
                 hdrs.ACCEPT_ENCODING, '').lower()
             for coding in ContentCoding:
                 if coding.value in accept_encoding:
-                    self._do_start_compression(coding)
+                    await self._do_start_compression(coding)
                     return
 
     async def prepare(self, request):
@@ -326,7 +328,7 @@ async def _start(self, request,
             headers.add(SET_COOKIE, value)
 
         if self._compression:
-            self._start_compression(request)
+            await self._start_compression(request)
 
         if self._chunked:
             if version != HttpVersion11:
@@ -437,7 +439,7 @@ class Response(StreamResponse):
 
     def __init__(self, *, body=None, status=200,
                  reason=None, text=None, headers=None, content_type=None,
-                 charset=None):
+                 charset=None, zlib_thread_size=None):
        if body is not None and text is not None:
            raise ValueError("body and text are not allowed together")
 
@@ -489,6 +491,7 @@ def __init__(self, *, body=None, status=200,
         self.body = body
 
         self._compressed_body = None
+        self._zlib_thread_size = zlib_thread_size
 
     @property
     def body(self):
@@ -575,7 +578,9 @@ def content_length(self):
     def content_length(self, value):
         raise RuntimeError("Content length is set automatically")
 
-    async def write_eof(self):
+    async def write_eof(self, data=b''):
+        assert not data
+
         if self._eof_sent:
             return
         if self._compressed_body is not None:
@@ -594,27 +599,37 @@ async def write_eof(self):
         else:
             await super().write_eof()
 
-    async def _start(self, request):
+    async def _start(self, request, *args, **kwargs):
         if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
             if not self._body_payload:
                 if self._body is not None:
                     self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
                 else:
                     self._headers[hdrs.CONTENT_LENGTH] = '0'
 
-        return await super()._start(request)
+        return await super()._start(request, *args, **kwargs)
 
-    def _do_start_compression(self, coding):
+    def _compress_body(self, zlib_mode):
+        compressobj = zlib.compressobj(wbits=zlib_mode)
+        self._compressed_body = \
+            compressobj.compress(self._body) + compressobj.flush()
+
+    async def _do_start_compression(self, coding):
         if self._body_payload or self._chunked:
-            return super()._do_start_compression(coding)
+            return await super()._do_start_compression(coding)
         if coding != ContentCoding.identity:
             # Instead of using _payload_writer.enable_compression,
             # compress the whole body
             zlib_mode = (16 + zlib.MAX_WBITS
                          if coding.value == 'gzip' else -zlib.MAX_WBITS)
-            compressobj = zlib.compressobj(wbits=zlib_mode)
-            self._compressed_body = compressobj.compress(self._body) +\
-                compressobj.flush()
+
+            if self._zlib_thread_size is not None and \
+                    len(self._body) > self._zlib_thread_size:
+                await asyncio.get_event_loop().run_in_executor(
+                    None, self._compress_body, zlib_mode)
+            else:
+                self._compress_body(zlib_mode)
+
             self._headers[hdrs.CONTENT_ENCODING] = coding.value
             self._headers[hdrs.CONTENT_LENGTH] = \
                 str(len(self._compressed_body))
@@ -632,3 +647,26 @@ def json_response(data=sentinel, *, text=None, body=None, status=200,
             text = dumps(data)
     return Response(text=text, body=body, status=status, reason=reason,
                     headers=headers, content_type=content_type)
+
+
+async def async_json_response(data=sentinel, *, text=None, body=None,
+                              status=200, reason=None, headers=None,
+                              content_type='application/json',
+                              dumps=json.dumps, executor_body_size=None):
+    if data is not sentinel:
+        if text or body:
+            raise ValueError(
+                "only one of data, text, or body should be specified"
+            )
+        else:
+            if asyncio.iscoroutinefunction(dumps):
+                text = await dumps(data)
+            elif executor_body_size is not None and \
+                    len(data) > executor_body_size:
+                loop = asyncio.get_event_loop()
+                text = await loop.run_in_executor(None, dumps, data)
+            else:
+                text = dumps(data)
+    return Response(text=text, body=body, status=status, reason=reason,
+                    headers=headers, content_type=content_type,
+                    zlib_thread_size=executor_body_size)
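
As a standalone illustration of the pattern the new `_compress_body` / `_do_start_compression` pair relies on, here is a minimal sketch. This is not aiohttp internals; the function names and the 1024-byte threshold are assumptions for the example:

```python
import asyncio
import zlib


def gzip_compress(body):
    # Same zlib setup the patch uses for coding.value == 'gzip':
    # wbits=16 + MAX_WBITS produces a gzip (RFC 1952) stream.
    compressobj = zlib.compressobj(wbits=16 + zlib.MAX_WBITS)
    return compressobj.compress(body) + compressobj.flush()


async def compress_body(body, thread_size=None):
    # Mirrors the patched logic: small bodies compress inline, while bodies
    # over the threshold are handed to the default executor so the event
    # loop is not blocked by CPU-bound zlib work.
    if thread_size is not None and len(body) > thread_size:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, gzip_compress, body)
    return gzip_compress(body)


async def main():
    body = b'answer' * 1024
    compressed = await compress_body(body, thread_size=1024)
    assert zlib.decompress(compressed, 16 + zlib.MAX_WBITS) == body


if __name__ == '__main__':
    asyncio.get_event_loop().run_until_complete(main())
```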
43 changes: 42 additions & 1 deletion tests/test_web_response.py
@@ -1,5 +1,6 @@
 import collections
 import datetime
+import gzip
 import json
 import re
 from unittest import mock
@@ -10,7 +11,8 @@
 from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs, signals
 from aiohttp.payload import BytesPayload
 from aiohttp.test_utils import make_mocked_coro, make_mocked_request
-from aiohttp.web import ContentCoding, Response, StreamResponse, json_response
+from aiohttp.web import ContentCoding, Response, StreamResponse,\
+    json_response, async_json_response
 
 
 def make_request(method, path, headers=CIMultiDict(),
@@ -408,6 +410,17 @@ async def test_force_compression_no_accept_gzip():
     assert 'gzip' == resp.headers.get(hdrs.CONTENT_ENCODING)
 
 
+async def test_change_content_threaded_compression_enabled():
+    req = make_request('GET', '/')
+    body_thread_size = 1024
+    body = b'answer' * body_thread_size
+    resp = Response(body=body,
+                    zlib_thread_size=body_thread_size)
+    resp.enable_compression(ContentCoding.gzip)
+    await resp.prepare(req)
+    assert gzip.decompress(resp._compressed_body) == body
+
+
 async def test_change_content_length_if_compression_enabled():
     req = make_request('GET', '/')
     resp = Response(body=b'answer')
@@ -1095,6 +1108,34 @@ def test_response_with_content_length_header_without_body():
     assert resp.content_length == 123
 
 
+async def test_async_json_small_response():
+    text = 'jaysawn'
+    resp = await async_json_response(text=json.dumps(text))
+    assert resp.text == json.dumps(text)
+
+    resp = await async_json_response(text)
+    assert resp.text == json.dumps(text)
+
+    with pytest.raises(ValueError):
+        await async_json_response(text, body=text)
+
+
+async def test_async_json_large_response():
+    cutoff_length = 1024
+    text = 'ja' * cutoff_length
+    resp = await async_json_response(text, executor_body_size=cutoff_length)
+    assert resp.text == json.dumps(text)
+
+
+async def test_async_json_coro_response():
+    async def dumps(data):
+        return json.dumps(data)
+
+    text = 'jaysawn'
+    resp = await async_json_response(text, dumps=dumps)
+    assert resp.text == json.dumps(text)
+
+
 class TestJSONResponse:
 
     def test_content_type_is_application_json_by_default(self):
