Skip to content

Commit

Permalink
allow to use payload objects on server side
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Feb 21, 2017
1 parent ca1e650 commit 7e9a381
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 18 deletions.
1 change: 1 addition & 0 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def content_disposition_header(disptype, quote_fields=True, **params):
if not disptype or not (TOKEN > set(disptype)):
raise ValueError('bad content disposition type {!r}'
''.format(disptype))

value = disptype
if params:
lparams = []
Expand Down
14 changes: 13 additions & 1 deletion aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .multipart import content_disposition_header
from .streams import DEFAULT_LIMIT, DataQueue, EofStream, StreamReader

__all__ = ('PAYLOAD_REGISTRY', 'Payload',
__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'Payload',
'BytesPayload', 'StringPayload', 'StreamReaderPayload',
'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload',
'TextIOPayload', 'StringIOPayload')
Expand All @@ -21,6 +21,10 @@ class LookupError(Exception):
pass


def get_payload(data, *args, **kwargs):
return PAYLOAD_REGISTRY.get(data, *args, **kwargs)


class PayloadRegistry:
"""Payload registry.
Expand All @@ -31,6 +35,8 @@ def __init__(self):
self._registry = []

def get(self, data, *args, **kwargs):
if isinstance(data, Payload):
return data
for ctor, type in self._registry:
if isinstance(data, type):
return ctor(data, *args, **kwargs)
Expand Down Expand Up @@ -144,6 +150,8 @@ def write(self, writer):
yield from writer.write(chunk)
chunk = self._value.read(DEFAULT_LIMIT)

self._value.close()


class StringIOPayload(IOBasePayload):

Expand All @@ -154,6 +162,8 @@ def write(self, writer):
yield from writer.write(chunk.encode(self._encoding))
chunk = self._value.read(DEFAULT_LIMIT)

self._value.close()


class TextIOPayload(Payload):

Expand All @@ -165,6 +175,8 @@ def write(self, writer):
yield from writer.write(chunk.encode(encoding))
chunk = self._value.read(DEFAULT_LIMIT)

self._value.close()


class BytesIOPayload(IOBasePayload):

Expand Down
67 changes: 51 additions & 16 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from multidict import CIMultiDict, CIMultiDictProxy

from . import hdrs
from . import hdrs, payload
from .helpers import HeadersMixin, SimpleCookie, sentinel
from .http import (RESPONSES, SERVER_SOFTWARE, HttpVersion10,
HttpVersion11, PayloadWriter)
Expand Down Expand Up @@ -478,20 +478,53 @@ def __init__(self, *, body=None, status=200,
headers[hdrs.CONTENT_TYPE] = content_type

super().__init__(status=status, reason=reason, headers=headers)

if text is not None:
self.text = text
else:
self._body = body
self.body = body

@property
def body(self):
return self._body

@body.setter
def body(self, body):
assert body is None or isinstance(body, (bytes, bytearray)), \
"body argument must be bytes (%r)" % type(body)
self._body = body
def body(self, body,
CONTENT_TYPE=hdrs.CONTENT_TYPE,
CONTENT_LENGTH=hdrs.CONTENT_LENGTH):
if body is None:
self._body = None
self._body_payload = False
elif isinstance(body, (bytes, bytearray)):
self._body = body
self._body_payload = False
else:
try:
self._body = body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
raise ValueError('Unsupported body type %r' % type(body))

self._body_payload = True

headers = self._headers

# enable chunked encoding if needed
if not self._chunked and CONTENT_LENGTH not in headers:
size = body.size
if size is None:
self._chunked = True
elif CONTENT_LENGTH not in headers:
headers[CONTENT_LENGTH] = str(size)

# set content-type
if CONTENT_TYPE not in headers:
headers[CONTENT_TYPE] = body.content_type

# copy payload headers
if body.headers:
for (key, value) in body.headers.items():
if key not in headers:
headers[key] = value

@property
def text(self):
Expand Down Expand Up @@ -531,18 +564,20 @@ def content_length(self, value):
@asyncio.coroutine
def write_eof(self):
body = self._body
if (body is not None and
(self._req._method == hdrs.METH_HEAD or
self._status in [204, 304])):
body = b''

if body is None:
body = b''

yield from super().write_eof(body)
if body is not None:
if (self._req._method == hdrs.METH_HEAD or
self._status in [204, 304]):
yield from super().write_eof()
elif self._body_payload:
yield from body.write(self._payload_writer)
yield from super().write_eof()
else:
yield from super().write_eof(body)
else:
yield from super().write_eof()

def _start(self, request):
if not self._chunked:
if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
if self._body is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
else:
Expand Down
170 changes: 170 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import io
import json
import pathlib
import zlib
Expand All @@ -8,6 +9,7 @@
from multidict import MultiDict
from yarl import URL

import aiohttp
from aiohttp import FormData, HttpVersion10, HttpVersion11, multipart, web

try:
Expand All @@ -16,6 +18,16 @@
ssl = False


@pytest.fixture
def here():
return pathlib.Path(__file__).parent


@pytest.fixture
def fname(here):
return here / 'sample.key'


@asyncio.coroutine
def test_simple_get(loop, test_client):

Expand Down Expand Up @@ -739,6 +751,164 @@ def handler(request):
assert 200 == resp.status


@asyncio.coroutine
def test_response_with_streamer(loop, test_client, fname):

with fname.open('rb') as f:
data = f.read()

data_size = len(data)

@aiohttp.streamer
def stream(writer, f_name):
with f_name.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
data = f.read(100)

@asyncio.coroutine
def handler(request):
headers = {'Content-Length': str(data_size)}
return web.Response(body=stream(fname), headers=headers)

app = web.Application(loop=loop)
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
resp_data = yield from resp.read()
assert resp_data == data
assert resp.headers.get('Content-Length') == str(len(resp_data))


@asyncio.coroutine
def test_response_with_streamer_no_params(loop, test_client, fname):

with fname.open('rb') as f:
data = f.read()

data_size = len(data)

@aiohttp.streamer
def stream(writer):
with fname.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
data = f.read(100)

@asyncio.coroutine
def handler(request):
headers = {'Content-Length': str(data_size)}
return web.Response(body=stream, headers=headers)

app = web.Application(loop=loop)
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
resp_data = yield from resp.read()
assert resp_data == data
assert resp.headers.get('Content-Length') == str(len(resp_data))


@asyncio.coroutine
def test_response_with_file(loop, test_client, fname):

with fname.open('rb') as f:
data = f.read()

@asyncio.coroutine
def handler(request):
return web.Response(body=fname.open('rb'))

app = web.Application(loop=loop)
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
resp_data = yield from resp.read()
assert resp_data == data
assert resp.headers.get('Content-Type') in (
'application/octet-stream', 'application/pgp-keys')
assert resp.headers.get('Content-Length') == str(len(resp_data))
assert (resp.headers.get('Content-Disposition') ==
'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key')


@asyncio.coroutine
def test_response_with_file_ctype(loop, test_client, fname):

with fname.open('rb') as f:
data = f.read()

@asyncio.coroutine
def handler(request):
return web.Response(
body=fname.open('rb'), headers={'content-type': 'text/binary'})

app = web.Application(loop=loop)
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
resp_data = yield from resp.read()
assert resp_data == data
assert resp.headers.get('Content-Type') == 'text/binary'
assert resp.headers.get('Content-Length') == str(len(resp_data))
assert (resp.headers.get('Content-Disposition') ==
'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key')


@asyncio.coroutine
def test_response_with_payload_disp(loop, test_client, fname):

with fname.open('rb') as f:
data = f.read()

@asyncio.coroutine
def handler(request):
pl = aiohttp.get_payload(fname.open('rb'))
pl.set_content_disposition('inline', filename='test.txt')
return web.Response(
body=pl, headers={'content-type': 'text/binary'})

app = web.Application(loop=loop)
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
resp_data = yield from resp.read()
assert resp_data == data
assert resp.headers.get('Content-Type') == 'text/binary'
assert resp.headers.get('Content-Length') == str(len(resp_data))
assert (resp.headers.get('Content-Disposition') ==
'inline; filename="test.txt"; filename*=utf-8\'\'test.txt')


@asyncio.coroutine
def test_response_with_payload_stringio(loop, test_client, fname):

@asyncio.coroutine
def handler(request):
return web.Response(body=io.StringIO('test'))

app = web.Application(loop=loop)
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
resp_data = yield from resp.read()
assert resp_data == b'test'


@asyncio.coroutine
def test_response_with_precompressed_body_gzip(loop, test_client):

Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def test_ctor_both_charset_param_and_header():
def test_assign_nonbyteish_body():
resp = Response(body=b'data')

with pytest.raises(AssertionError):
with pytest.raises(ValueError):
resp.body = 123
assert b'data' == resp.body
assert 4 == resp.content_length
Expand Down

0 comments on commit 7e9a381

Please sign in to comment.