diff --git a/tests/unit/test_stubs.py b/tests/unit/test_stubs.py index e9e5aadd..1eddeff0 100644 --- a/tests/unit/test_stubs.py +++ b/tests/unit/test_stubs.py @@ -1,9 +1,12 @@ import contextlib +import http.client as httplib +from io import BytesIO +from tempfile import NamedTemporaryFile from unittest import mock from pytest import mark -from vcr import mode +from vcr import mode, use_cassette from vcr.cassette import Cassette from vcr.stubs import VCRHTTPSConnection @@ -21,3 +24,52 @@ def testing_connect(*args): vcr_connection.cassette = Cassette("test", record_mode=mode.ALL) vcr_connection.real_connection.connect() assert vcr_connection.real_connection.sock is not None + + def test_body_consumed_once_stream(self, tmpdir, httpbin): + self._test_body_consumed_once( + tmpdir, + httpbin, + BytesIO(b"1234567890"), + BytesIO(b"9876543210"), + BytesIO(b"9876543210"), + ) + + def test_body_consumed_once_iterator(self, tmpdir, httpbin): + self._test_body_consumed_once( + tmpdir, + httpbin, + iter([b"1234567890"]), + iter([b"9876543210"]), + iter([b"9876543210"]), + ) + + # data2 and data3 should serve the same data, potentially as iterators + def _test_body_consumed_once( + self, + tmpdir, + httpbin, + data1, + data2, + data3, + ): + with NamedTemporaryFile(dir=tmpdir, suffix=".yml") as f: + testpath = f.name + # NOTE: ``use_cassette`` is not okay with the file existing + # already. So we using ``.close()`` to not only + # close but also delete the empty file, before we start. + f.close() + host, port = httpbin.host, httpbin.port + match_on = ["method", "uri", "body"] + with use_cassette(testpath, match_on=match_on): + conn1 = httplib.HTTPConnection(host, port) + conn1.request("POST", "/anything", body=data1) + conn1.getresponse() + conn2 = httplib.HTTPConnection(host, port) + conn2.request("POST", "/anything", body=data2) + conn2.getresponse() + with use_cassette(testpath, match_on=match_on) as cass: + conn3 = httplib.HTTPConnection(host, port) + conn3.request("POST", "/anything", body=data3) + conn3.getresponse() + assert cass.play_counts[0] == 0 + assert cass.play_counts[1] == 1 diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py new file mode 100644 index 00000000..3c8b0b52 --- /dev/null +++ b/tests/unit/test_util.py @@ -0,0 +1,33 @@ +from io import BytesIO, StringIO + +import pytest + +from vcr import request +from vcr.util import read_body + + +@pytest.mark.parametrize( + "input_, expected_output", + [ + (BytesIO(b"Stream"), b"Stream"), + (StringIO("Stream"), b"Stream"), + (iter(["StringIter"]), b"StringIter"), + (iter(["String", "Iter"]), b"StringIter"), + (iter([b"BytesIter"]), b"BytesIter"), + (iter([b"Bytes", b"Iter"]), b"BytesIter"), + (iter([70, 111, 111]), b"Foo"), + (iter([]), b""), + ("String", b"String"), + (b"Bytes", b"Bytes"), + ], +) +def test_read_body(input_, expected_output): + r = request.Request("POST", "http://host.com/", input_, {}) + assert read_body(r) == expected_output + + +def test_unsupported_read_body(): + r = request.Request("POST", "http://host.com/", iter([[]]), {}) + with pytest.raises(ValueError) as excinfo: + assert read_body(r) + assert excinfo.value.args == ("Body type not supported",) diff --git a/vcr/request.py b/vcr/request.py index 130f19c7..d633d5c2 100644 --- a/vcr/request.py +++ b/vcr/request.py @@ -3,7 +3,7 @@ from io import BytesIO from urllib.parse import parse_qsl, urlparse -from .util import CaseInsensitiveDict +from .util import CaseInsensitiveDict, _is_nonsequence_iterator log = logging.getLogger(__name__) @@ -17,8 +17,11 @@ def __init__(self, method, uri, body, headers): self.method = method self.uri = uri self._was_file = hasattr(body, "read") + self._was_iter = _is_nonsequence_iterator(body) if self._was_file: self.body = body.read() + elif self._was_iter: + self.body = list(body) else: self.body = body self.headers = headers @@ -36,7 +39,11 @@ def headers(self, value): @property def body(self): - return BytesIO(self._body) if self._was_file else self._body + if self._was_file: + return BytesIO(self._body) + if self._was_iter: + return iter(self._body) + return self._body @body.setter def body(self, value): diff --git a/vcr/util.py b/vcr/util.py index 09c6d1dd..8b40d863 100644 --- a/vcr/util.py +++ b/vcr/util.py @@ -89,9 +89,28 @@ def composed(incoming): return composed +def _is_nonsequence_iterator(obj): + return hasattr(obj, "__iter__") and not isinstance( + obj, + (bytearray, bytes, dict, list, str), + ) + + def read_body(request): if hasattr(request.body, "read"): return request.body.read() + if _is_nonsequence_iterator(request.body): + body = list(request.body) + if body: + if isinstance(body[0], str): + return "".join(body).encode("utf-8") + elif isinstance(body[0], (bytes, bytearray)): + return b"".join(body) + elif isinstance(body[0], int): + return bytes(body) + else: + raise ValueError(f"Body type {type(body[0])} not supported") + return b"" return request.body