diff --git a/tests/test_client_errors.py b/tests/test_client_errors.py index 2706d22..9c6b83a 100644 --- a/tests/test_client_errors.py +++ b/tests/test_client_errors.py @@ -1,4 +1,9 @@ import itertools +import mmap +import os +import random +import tempfile +from http.client import CannotSendRequest import pytest @@ -25,23 +30,32 @@ class ATestException(Exception): class MockConnection: - def __init__(self, fail_request_after=None, fail_getresponse_after=None): + def __init__(self, fail_request_after=None, fail_getresponse_after=None, exception_class=ATestException): self.times_closed = 0 self.times_request = 0 self.times_got_response = 0 self.fail_request_after = fail_request_after self.fail_getresponse_after = fail_getresponse_after + self.exception_class = exception_class def request(self, method, path, body, headers): self.times_request += 1 + data = b"" + if hasattr(body, "read"): + data = body.read() if self.fail_request_after is not None and self.times_request > self.fail_request_after: - raise ATestException(f"Failing request number {self.times_request}") + raise self.exception_class(f"Failing request number {self.times_request}") + if method == "PUT": + if not os.path.exists(path): + return + with open(path, "wb") as file: + file.write(data) def getresponse(self): self.times_got_response += 1 if self.fail_getresponse_after is not None and self.times_got_response > self.fail_getresponse_after: - raise ATestException(f"Failing response number {self.times_got_response}") + raise self.exception_class(f"Failing response number {self.times_got_response}") return MockResponse() def raise_for_status(self, expected_statuses=None): @@ -52,9 +66,10 @@ def close(self): class MockTransport(v3io.dataplane.transport.httpclient.Transport): - def __init__(self, *args, connection_options=None, **kwargs): + def __init__(self, *args, connection_options=None, reset_after_create_connections=False, **kwargs): self.mock_connections = [] self.connection_options = connection_options or {} + self.reset_after_create_connections = reset_after_create_connections super().__init__(*args, **kwargs) def _create_connection(self, host, ssl_context): @@ -62,6 +77,34 @@ def _create_connection(self, host, ssl_context): self.mock_connections.append(conn) return conn + def _create_connections(self, num_connections, host, ssl_context): + super()._create_connections(num_connections=num_connections, host=host, ssl_context=ssl_context) + if self.reset_after_create_connections: + self.connection_options = {} + + +def test_first_connection_failure(): + connection_options = {"fail_request_after": 0, "exception_class": CannotSendRequest} + client = v3io.dataplane.Client() + + mock_transport = MockTransport( + client._logger, connection_options=connection_options, reset_after_create_connections=True + ) + client._transport = mock_transport + size = 1024 + data = random.Random(0).randbytes(size) + with mmap.mmap(-1, size) as mmap_obj, tempfile.NamedTemporaryFile(mode="w+b", delete=False) as temp_file: + mmap_obj.write(data) + mmap_obj.seek(0) + components = temp_file.name.split("/") + container = components[1] # Index 0 will be an empty string due to leading '/' + path = "/" + "/".join(components[2:]) + client.put_object( + container=container, path=path, body=mmap_obj, raise_for_status=v3io.dataplane.RaiseForStatus.never + ) + temp_file.seek(0) + assert temp_file.read() == data, "Binary data read back differs from original data" + def test_connection_creation_and_close(): client = v3io.dataplane.Client() diff --git a/v3io/dataplane/transport/httpclient.py b/v3io/dataplane/transport/httpclient.py index 4353de7..63924d4 100644 --- a/v3io/dataplane/transport/httpclient.py +++ b/v3io/dataplane/transport/httpclient.py @@ -154,7 +154,10 @@ def _send_request_on_connection(self, request, connection): self.log( "Tx", connection=connection, method=request.method, path=path, headers=request.headers, body=request.body ) - + starting_offset = 0 + is_body_seekable = request.body and hasattr(request.body, "seek") and hasattr(request.body, "tell") + if is_body_seekable: + starting_offset = request.body.tell() try: try: connection.request(request.method, path, request.body, request.headers) @@ -166,6 +169,11 @@ def _send_request_on_connection(self, request, connection): connection=connection, ) connection.close() + if is_body_seekable: + # If the first connection fails, the pointer of the body might move at the size + # of the first connection blocksize. + # We need to reset the position of the pointer in order to send the whole file. + request.body.seek(starting_offset) connection = self._create_connection(self._host, self._ssl_context) request.transport.connection_used = connection connection.request(request.method, path, request.body, request.headers)