From 7e9a381243a26987343344694d8a7cfc28df5cbd Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 20 Feb 2017 15:26:51 -0800 Subject: [PATCH] allow to use payload objects on server side --- aiohttp/multipart.py | 1 + aiohttp/payload.py | 14 ++- aiohttp/web_response.py | 67 ++++++++++---- tests/test_web_functional.py | 170 +++++++++++++++++++++++++++++++++++ tests/test_web_response.py | 2 +- 5 files changed, 236 insertions(+), 18 deletions(-) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index ed89dd2470e..933f47e6686 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -163,6 +163,7 @@ def content_disposition_header(disptype, quote_fields=True, **params): if not disptype or not (TOKEN > set(disptype)): raise ValueError('bad content disposition type {!r}' ''.format(disptype)) + value = disptype if params: lparams = [] diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 5a7fba37d95..2247b327ca3 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -11,7 +11,7 @@ from .multipart import content_disposition_header from .streams import DEFAULT_LIMIT, DataQueue, EofStream, StreamReader -__all__ = ('PAYLOAD_REGISTRY', 'Payload', +__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'Payload', 'BytesPayload', 'StringPayload', 'StreamReaderPayload', 'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload', 'TextIOPayload', 'StringIOPayload') @@ -21,6 +21,10 @@ class LookupError(Exception): pass +def get_payload(data, *args, **kwargs): + return PAYLOAD_REGISTRY.get(data, *args, **kwargs) + + class PayloadRegistry: """Payload registry. @@ -31,6 +35,8 @@ def __init__(self): self._registry = [] def get(self, data, *args, **kwargs): + if isinstance(data, Payload): + return data for ctor, type in self._registry: if isinstance(data, type): return ctor(data, *args, **kwargs) @@ -144,6 +150,8 @@ def write(self, writer): yield from writer.write(chunk) chunk = self._value.read(DEFAULT_LIMIT) + self._value.close() + class StringIOPayload(IOBasePayload): @@ -154,6 +162,8 @@ def write(self, writer): yield from writer.write(chunk.encode(self._encoding)) chunk = self._value.read(DEFAULT_LIMIT) + self._value.close() + class TextIOPayload(Payload): @@ -165,6 +175,8 @@ def write(self, writer): yield from writer.write(chunk.encode(encoding)) chunk = self._value.read(DEFAULT_LIMIT) + self._value.close() + class BytesIOPayload(IOBasePayload): diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 405dee1f4d7..eff07f3df37 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -9,7 +9,7 @@ from multidict import CIMultiDict, CIMultiDictProxy -from . import hdrs +from . import hdrs, payload from .helpers import HeadersMixin, SimpleCookie, sentinel from .http import (RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11, PayloadWriter) @@ -478,20 +478,53 @@ def __init__(self, *, body=None, status=200, headers[hdrs.CONTENT_TYPE] = content_type super().__init__(status=status, reason=reason, headers=headers) + if text is not None: self.text = text else: - self._body = body + self.body = body @property def body(self): return self._body @body.setter - def body(self, body): - assert body is None or isinstance(body, (bytes, bytearray)), \ - "body argument must be bytes (%r)" % type(body) - self._body = body + def body(self, body, + CONTENT_TYPE=hdrs.CONTENT_TYPE, + CONTENT_LENGTH=hdrs.CONTENT_LENGTH): + if body is None: + self._body = None + self._body_payload = False + elif isinstance(body, (bytes, bytearray)): + self._body = body + self._body_payload = False + else: + try: + self._body = body = payload.PAYLOAD_REGISTRY.get(body) + except payload.LookupError: + raise ValueError('Unsupported body type %r' % type(body)) + + self._body_payload = True + + headers = self._headers + + # enable chunked encoding if needed + if not self._chunked and CONTENT_LENGTH not in headers: + size = body.size + if size is None: + self._chunked = True + elif CONTENT_LENGTH not in headers: + headers[CONTENT_LENGTH] = str(size) + + # set content-type + if CONTENT_TYPE not in headers: + headers[CONTENT_TYPE] = body.content_type + + # copy payload headers + if body.headers: + for (key, value) in body.headers.items(): + if key not in headers: + headers[key] = value @property def text(self): @@ -531,18 +564,20 @@ def content_length(self, value): @asyncio.coroutine def write_eof(self): body = self._body - if (body is not None and - (self._req._method == hdrs.METH_HEAD or - self._status in [204, 304])): - body = b'' - - if body is None: - body = b'' - - yield from super().write_eof(body) + if body is not None: + if (self._req._method == hdrs.METH_HEAD or + self._status in [204, 304]): + yield from super().write_eof() + elif self._body_payload: + yield from body.write(self._payload_writer) + yield from super().write_eof() + else: + yield from super().write_eof(body) + else: + yield from super().write_eof() def _start(self, request): - if not self._chunked: + if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: if self._body is not None: self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body)) else: diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index ae2fa2f4fba..67b274cb7f1 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1,4 +1,5 @@ import asyncio +import io import json import pathlib import zlib @@ -8,6 +9,7 @@ from multidict import MultiDict from yarl import URL +import aiohttp from aiohttp import FormData, HttpVersion10, HttpVersion11, multipart, web try: @@ -16,6 +18,16 @@ ssl = False +@pytest.fixture +def here(): + return pathlib.Path(__file__).parent + + +@pytest.fixture +def fname(here): + return here / 'sample.key' + + @asyncio.coroutine def test_simple_get(loop, test_client): @@ -739,6 +751,164 @@ def handler(request): assert 200 == resp.status +@asyncio.coroutine +def test_response_with_streamer(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + data_size = len(data) + + @aiohttp.streamer + def stream(writer, f_name): + with f_name.open('rb') as f: + data = f.read(100) + while data: + yield from writer.write(data) + data = f.read(100) + + @asyncio.coroutine + def handler(request): + headers = {'Content-Length': str(data_size)} + return web.Response(body=stream(fname), headers=headers) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Length') == str(len(resp_data)) + + +@asyncio.coroutine +def test_response_with_streamer_no_params(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + data_size = len(data) + + @aiohttp.streamer + def stream(writer): + with fname.open('rb') as f: + data = f.read(100) + while data: + yield from writer.write(data) + data = f.read(100) + + @asyncio.coroutine + def handler(request): + headers = {'Content-Length': str(data_size)} + return web.Response(body=stream, headers=headers) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Length') == str(len(resp_data)) + + +@asyncio.coroutine +def test_response_with_file(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + @asyncio.coroutine + def handler(request): + return web.Response(body=fname.open('rb')) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Type') in ( + 'application/octet-stream', 'application/pgp-keys') + assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert (resp.headers.get('Content-Disposition') == + 'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key') + + +@asyncio.coroutine +def test_response_with_file_ctype(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + @asyncio.coroutine + def handler(request): + return web.Response( + body=fname.open('rb'), headers={'content-type': 'text/binary'}) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Type') == 'text/binary' + assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert (resp.headers.get('Content-Disposition') == + 'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key') + + +@asyncio.coroutine +def test_response_with_payload_disp(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + @asyncio.coroutine + def handler(request): + pl = aiohttp.get_payload(fname.open('rb')) + pl.set_content_disposition('inline', filename='test.txt') + return web.Response( + body=pl, headers={'content-type': 'text/binary'}) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Type') == 'text/binary' + assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert (resp.headers.get('Content-Disposition') == + 'inline; filename="test.txt"; filename*=utf-8\'\'test.txt') + + +@asyncio.coroutine +def test_response_with_payload_stringio(loop, test_client, fname): + + @asyncio.coroutine + def handler(request): + return web.Response(body=io.StringIO('test')) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == b'test' + + @asyncio.coroutine def test_response_with_precompressed_body_gzip(loop, test_client): diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 5f14f5cc8ba..222edcdee5e 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -802,7 +802,7 @@ def test_ctor_both_charset_param_and_header(): def test_assign_nonbyteish_body(): resp = Response(body=b'data') - with pytest.raises(AssertionError): + with pytest.raises(ValueError): resp.body = 123 assert b'data' == resp.body assert 4 == resp.content_length