Skip to content

Commit

Permalink
Support Response(content=<bytes iterator>) (#1265)
Browse files Browse the repository at this point in the history
* Support Response(content=<bytes iterator>)

* Update test for merged master
  • Loading branch information
tomchristie authored Sep 11, 2020
1 parent 4bd08be commit 5ee6135
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 83 deletions.
35 changes: 18 additions & 17 deletions httpx/_content_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import httpcore

from ._exceptions import StreamConsumed
from ._types import FileContent, FileTypes, RequestData, RequestFiles
from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent
from ._utils import (
format_form_param,
guess_content_type,
Expand Down Expand Up @@ -72,11 +72,8 @@ class IteratorStream(ContentStream):
Request content encoded as plain bytes, using an byte iterator.
"""

def __init__(
self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None
) -> None:
def __init__(self, iterator: typing.Iterator[bytes]) -> None:
self.iterator = iterator
self.close_func = close_func
self.is_stream_consumed = False

def can_replay(self) -> bool:
Expand All @@ -95,21 +92,14 @@ def __iter__(self) -> typing.Iterator[bytes]:
def __aiter__(self) -> typing.AsyncIterator[bytes]:
raise RuntimeError("Attempted to call a async iterator on an sync stream.")

def close(self) -> None:
if self.close_func is not None:
self.close_func()


class AsyncIteratorStream(ContentStream):
"""
Request content encoded as plain bytes, using an async byte iterator.
"""

def __init__(
self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None
) -> None:
def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
self.aiterator = aiterator
self.close_func = close_func
self.is_stream_consumed = False

def can_replay(self) -> bool:
Expand All @@ -128,10 +118,6 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async for part in self.aiterator:
yield part

async def aclose(self) -> None:
if self.close_func is not None:
await self.close_func()


class JSONStream(ContentStream):
"""
Expand Down Expand Up @@ -402,3 +388,18 @@ def encode(
return IteratorStream(iterator=data)

raise TypeError(f"Unexpected type for 'data', {type(data)!r}")


def encode_response(content: ResponseContent = None) -> ContentStream:
if content is None:
return ByteStream(b"")
elif isinstance(content, bytes):
return ByteStream(body=content)
elif hasattr(content, "__aiter__"):
content = typing.cast(typing.AsyncIterator[bytes], content)
return AsyncIteratorStream(aiterator=content)
elif hasattr(content, "__iter__"):
content = typing.cast(typing.Iterator[bytes], content)
return IteratorStream(iterator=content)

raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
11 changes: 7 additions & 4 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import rfc3986
import rfc3986.exceptions

from ._content_streams import ByteStream, ContentStream, encode
from ._content_streams import ByteStream, ContentStream, encode, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ContentDecoder,
Expand Down Expand Up @@ -44,6 +44,7 @@
QueryParamTypes,
RequestData,
RequestFiles,
ResponseContent,
URLTypes,
)
from ._utils import (
Expand Down Expand Up @@ -674,7 +675,7 @@ def __init__(
http_version: str = None,
headers: HeaderTypes = None,
stream: ContentStream = None,
content: bytes = None,
content: ResponseContent = None,
history: typing.List["Response"] = None,
elapsed_func: typing.Callable = None,
):
Expand All @@ -694,8 +695,10 @@ def __init__(
if stream is not None:
self._raw_stream = stream
else:
self._raw_stream = ByteStream(body=content or b"")
self.read()
self._raw_stream = encode_response(content)
if content is None or isinstance(content, bytes):
# Load the response body, except for streaming content.
self.read()

self._num_bytes_downloaded = 0

Expand Down
2 changes: 2 additions & 0 deletions httpx/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
None,
]

ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]

RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]]

FileContent = Union[IO[str], IO[bytes], str, bytes]
Expand Down
63 changes: 11 additions & 52 deletions tests/models/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest

import httpx
from httpx._content_streams import AsyncIteratorStream, IteratorStream


def streaming_body():
Expand Down Expand Up @@ -215,10 +214,9 @@ async def test_aread():


def test_iter_raw():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

raw = b""
Expand All @@ -228,12 +226,7 @@ def test_iter_raw():


def test_iter_raw_increments_updates_counter():
stream = IteratorStream(iterator=streaming_body())

response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=streaming_body())

num_downloaded = response.num_bytes_downloaded
for part in response.iter_raw():
Expand All @@ -243,11 +236,7 @@ def test_iter_raw_increments_updates_counter():

@pytest.mark.asyncio
async def test_aiter_raw():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=async_streaming_body())

raw = b""
async for part in response.aiter_raw():
Expand All @@ -257,12 +246,7 @@ async def test_aiter_raw():

@pytest.mark.asyncio
async def test_aiter_raw_increments_updates_counter():
stream = AsyncIteratorStream(aiterator=async_streaming_body())

response = httpx.Response(
200,
stream=stream,
)
response = httpx.Response(200, content=async_streaming_body())

num_downloaded = response.num_bytes_downloaded
async for part in response.aiter_raw():
Expand Down Expand Up @@ -346,10 +330,9 @@ async def test_aiter_lines():


def test_sync_streaming_response():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

assert response.status_code == 200
Expand All @@ -364,10 +347,9 @@ def test_sync_streaming_response():

@pytest.mark.asyncio
async def test_async_streaming_response():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

assert response.status_code == 200
Expand All @@ -381,10 +363,9 @@ async def test_async_streaming_response():


def test_cannot_read_after_stream_consumed():
stream = IteratorStream(iterator=streaming_body())
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

content = b""
Expand All @@ -397,10 +378,9 @@ def test_cannot_read_after_stream_consumed():

@pytest.mark.asyncio
async def test_cannot_aread_after_stream_consumed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

content = b""
Expand All @@ -412,54 +392,33 @@ async def test_cannot_aread_after_stream_consumed():


def test_cannot_read_after_response_closed():
is_closed = False

def close_func():
nonlocal is_closed
is_closed = True

stream = IteratorStream(iterator=streaming_body(), close_func=close_func)
response = httpx.Response(
200,
stream=stream,
content=streaming_body(),
)

response.close()
assert is_closed

with pytest.raises(httpx.ResponseClosed):
response.read()


@pytest.mark.asyncio
async def test_cannot_aread_after_response_closed():
is_closed = False

async def close_func():
nonlocal is_closed
is_closed = True

stream = AsyncIteratorStream(
aiterator=async_streaming_body(), close_func=close_func
)
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

await response.aclose()
assert is_closed

with pytest.raises(httpx.ResponseClosed):
await response.aread()


@pytest.mark.asyncio
async def test_elapsed_not_available_until_closed():
stream = AsyncIteratorStream(aiterator=async_streaming_body())
response = httpx.Response(
200,
stream=stream,
content=async_streaming_body(),
)

with pytest.raises(RuntimeError):
Expand Down
71 changes: 70 additions & 1 deletion tests/test_content_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from httpx import StreamConsumed
from httpx._content_streams import ContentStream, encode
from httpx._content_streams import ContentStream, encode, encode_response


@pytest.mark.asyncio
Expand Down Expand Up @@ -251,3 +251,72 @@ async def test_multipart_multiple_files_single_input_content():
b"--+++--\r\n",
]
)


@pytest.mark.asyncio
async def test_response_empty_content():
stream = encode_response()
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])

assert stream.can_replay()
assert stream.get_headers() == {}
assert sync_content == b""
assert async_content == b""


@pytest.mark.asyncio
async def test_response_bytes_content():
stream = encode_response(content=b"Hello, world!")
sync_content = b"".join([part for part in stream])
async_content = b"".join([part async for part in stream])

assert stream.can_replay()
assert stream.get_headers() == {"Content-Length": "13"}
assert sync_content == b"Hello, world!"
assert async_content == b"Hello, world!"


@pytest.mark.asyncio
async def test_response_iterator_content():
def hello_world():
yield b"Hello, "
yield b"world!"

stream = encode_response(content=hello_world())
content = b"".join([part for part in stream])

assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"

with pytest.raises(RuntimeError):
[part async for part in stream]

with pytest.raises(StreamConsumed):
[part for part in stream]


@pytest.mark.asyncio
async def test_response_aiterator_content():
async def hello_world():
yield b"Hello, "
yield b"world!"

stream = encode_response(content=hello_world())
content = b"".join([part async for part in stream])

assert not stream.can_replay()
assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
assert content == b"Hello, world!"

with pytest.raises(RuntimeError):
[part for part in stream]

with pytest.raises(StreamConsumed):
[part async for part in stream]


def test_response_invalid_argument():
with pytest.raises(TypeError):
encode_response(123) # type: ignore
Loading

0 comments on commit 5ee6135

Please sign in to comment.