Skip to content

Commit

Permalink
Merge pull request #159 from justmobilize/remove-response-close-read-all
Browse files Browse the repository at this point in the history
Don't read all on response.close()
  • Loading branch information
dhalbert authored Mar 12, 2024
2 parents 815b326 + 0a9bb61 commit 9544d1f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 23 deletions.
32 changes: 10 additions & 22 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,11 @@ class Response:

encoding = None

def __init__(
self, sock: SocketType, session: "Session", fast_close: bool = False
) -> None:
def __init__(self, sock: SocketType, session: "Session") -> None:
self.socket = sock
self.encoding = "utf-8"
self._cached = None
self._headers = {}
self._fast_close = fast_close

# _start_index and _receive_buffer are used when parsing headers.
# _receive_buffer will grow by 32 bytes everytime it is too small.
Expand Down Expand Up @@ -230,27 +227,16 @@ def _throw_away(self, nbytes: int) -> None:
to_read -= self._recv_into(buf, to_read)

def close(self) -> None:
"""Drain the remaining ESP socket buffers. We assume we already got what we wanted."""
"""Close out the socket. If we have a session free it instead."""
if not self.socket:
return
# Make sure we've read all of our response.
if self._cached is None and not self._fast_close:
if self._remaining and self._remaining > 0:
self._throw_away(self._remaining)
elif self._chunked:
while True:
chunk_header = bytes(self._readto(b"\r\n")).split(b";", 1)[0]
if not chunk_header:
break
chunk_size = int(bytes(chunk_header), 16)
if chunk_size == 0:
break
self._throw_away(chunk_size + 2)

if self._session:
# pylint: disable=protected-access
self._session._connection_manager.free_socket(self.socket)
else:
self.socket.close()

self.socket = None

def _parse_headers(self) -> None:
Expand Down Expand Up @@ -365,13 +351,11 @@ def __init__(
socket_pool: SocketpoolModuleType,
ssl_context: Optional[SSLContextType] = None,
session_id: Optional[str] = None,
fast_close: Optional[bool] = False,
) -> None:
self._connection_manager = get_connection_manager(socket_pool)
self._ssl_context = ssl_context
self._session_id = session_id
self._last_response = None
self._fast_close = fast_close

@staticmethod
def _check_headers(headers: Dict[str, str]):
Expand All @@ -389,7 +373,6 @@ def _check_headers(headers: Dict[str, str]):
def _send(socket: SocketType, data: bytes):
total_sent = 0
while total_sent < len(data):
# ESP32SPI sockets raise a RuntimeError when unable to send.
try:
sent = socket.send(data[total_sent:])
except OSError as exc:
Expand All @@ -399,6 +382,7 @@ def _send(socket: SocketType, data: bytes):
# Some worse error.
raise
except RuntimeError as exc:
# ESP32SPI sockets raise a RuntimeError when unable to send.
raise OSError(errno.EIO) from exc
if sent is None:
sent = len(data)
Expand Down Expand Up @@ -566,7 +550,7 @@ def request(
if not socket:
raise OutOfRetries("Repeated socket failures") from last_exc

resp = Response(socket, self, fast_close=self._fast_close) # our response
resp = Response(socket, self) # our response
if allow_redirects:
if "location" in resp.headers and 300 <= resp.status_code <= 399:
# a naive handler for redirects
Expand Down Expand Up @@ -594,6 +578,10 @@ def request(
self._last_response = resp
return resp

def options(self, url: str, **kw) -> Response:
"""Send HTTP OPTIONS request"""
return self.request("OPTIONS", url, **kw)

def head(self, url: str, **kw) -> Response:
"""Send HTTP HEAD request"""
return self.request("HEAD", url, **kw)
Expand Down
1 change: 1 addition & 0 deletions tests/method_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"DELETE",
"GET",
"HEAD",
"OPTIONS",
"PATCH",
"POST",
"PUT",
Expand Down
9 changes: 8 additions & 1 deletion tests/reuse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ def test_get_twice(pool, requests_ssl):


def test_get_twice_after_second(pool, requests_ssl):
sock = mocket.Mocket(mocket.MOCK_RESPONSE + mocket.MOCK_RESPONSE)
sock = mocket.Mocket(
b"H"
b"TTP/1.0 200 OK\r\nContent-Length: "
b"70\r\n\r\nHTTP/1.0 2"
b"H"
b"TTP/1.0 200 OK\r\nContent-Length: "
b"70\r\n\r\nHTTP/1.0 2"
)
pool.socket.return_value = sock

response = requests_ssl.get("https://" + mocket.MOCK_ENDPOINT_1)
Expand Down

0 comments on commit 9544d1f

Please sign in to comment.