Skip to content

Commit

Permalink
support coroutine as data provider
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Feb 21, 2017
1 parent d69838a commit 22462a4
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 3 deletions.
2 changes: 2 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .file_sender import FileSender # noqa
from .cookiejar import CookieJar # noqa
from .payload import * # noqa
from .payload_streamer import * # noqa
from .resolver import * # noqa

# deprecated #1657
Expand All @@ -27,6 +28,7 @@
helpers.__all__ + # noqa
streams.__all__ + # noqa
payload.__all__ + # noqa
payload_streamer.__all__ + # noqa
multipart.__all__ + # noqa
('hdrs', 'FileSender',
'HttpVersion', 'HttpVersion10', 'HttpVersion11',
Expand Down
5 changes: 5 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,11 @@ def update_body_from_data(self, data, skip_auto_headers):
return

if asyncio.iscoroutine(data):
warnings.warn(
'coroutine as data object is deprecated, '
'use aiohttp.streamer #1664',
DeprecationWarning, stacklevel=2)

self.body = data
if (hdrs.CONTENT_LENGTH not in self.headers and
self.chunked is None):
Expand Down
70 changes: 70 additions & 0 deletions aiohttp/payload_streamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
""" Payload implemenation for coroutines as data provider.
As a simple case, you can upload data from file::
@aiohttp.streamer
def file_sender(writer, file_name=None):
with open(file_name, 'rb') as f:
chunk = f.read(2**16)
while chunk:
yield from writer.write(chunk)
chunk = f.read(2**16)
Then you can use `file_sender` like this:
async with session.post('http://httpbin.org/post',
data=file_sender(file_name='hude_file')) as resp:
print(await resp.text())
..note:: Coroutine must accept `writer` as first argument
"""

import asyncio

from . import payload

__all__ = ('streamer',)


class _stream_wrapper:

def __init__(self, coro, args, kwargs):
self.coro = coro
self.args = args
self.kwargs = kwargs

@asyncio.coroutine
def __call__(self, writer):
yield from self.coro(writer, *self.args, **self.kwargs)


class streamer:

def __init__(self, coro):
self.coro = coro

def __call__(self, *args, **kwargs):
return _stream_wrapper(self.coro, args, kwargs)


class StreamWrapperPayload(payload.Payload):

@asyncio.coroutine
def write(self, writer):
yield from self._value(writer)


class StreamPayload(StreamWrapperPayload):

def __init__(self, value, *args, **kwargs):
super().__init__(value(), *args, **kwargs)

@asyncio.coroutine
def write(self, writer):
yield from self._value(writer)


payload.PAYLOAD_REGISTRY.register(StreamPayload, streamer)
payload.PAYLOAD_REGISTRY.register(StreamWrapperPayload, _stream_wrapper)
73 changes: 70 additions & 3 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,6 @@ def handler(request):
resp.close()


@pytest.mark.xfail
@asyncio.coroutine
def test_POST_STREAM_DATA(loop, test_client, fname):
@asyncio.coroutine
Expand All @@ -1390,7 +1389,75 @@ def handler(request):
content = yield from request.read()
with fname.open('rb') as f:
expected = f.read()
assert request.content_length == str(len(expected))
assert request.content_length == len(expected)
assert content == expected

return web.HTTPOk()

app = web.Application(loop=loop)
app.router.add_post('/', handler)
client = yield from test_client(app)

with fname.open('rb') as f:
data_size = len(f.read())

@aiohttp.streamer
def stream(writer, fname):
with fname.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
data = f.read(100)

resp = yield from client.post(
'/', data=stream(fname), headers={'Content-Length': str(data_size)})
assert 200 == resp.status
resp.close()


@asyncio.coroutine
def test_POST_STREAM_DATA_no_params(loop, test_client, fname):
@asyncio.coroutine
def handler(request):
assert request.content_type == 'application/octet-stream'
content = yield from request.read()
with fname.open('rb') as f:
expected = f.read()
assert request.content_length == len(expected)
assert content == expected

return web.HTTPOk()

app = web.Application(loop=loop)
app.router.add_post('/', handler)
client = yield from test_client(app)

with fname.open('rb') as f:
data_size = len(f.read())

@aiohttp.streamer
def stream(writer):
with fname.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
data = f.read(100)

resp = yield from client.post(
'/', data=stream, headers={'Content-Length': str(data_size)})
assert 200 == resp.status
resp.close()


@asyncio.coroutine
def test_POST_STREAM_DATA_coroutine_deprecated(loop, test_client, fname):
@asyncio.coroutine
def handler(request):
assert request.content_type == 'application/octet-stream'
content = yield from request.read()
with fname.open('rb') as f:
expected = f.read()
assert request.content_length == len(expected)
assert content == expected

return web.HTTPOk()
Expand All @@ -1399,7 +1466,7 @@ def handler(request):
app.router.add_post('/', handler)
client = yield from test_client(app)

with fname.open() as f:
with fname.open('rb') as f:
data = f.read()
fut = create_future(loop)

Expand Down

0 comments on commit 22462a4

Please sign in to comment.