Skip to content

Commit

Permalink
feat(client): add retries_taken to raw response class (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] authored and stainless-bot committed Aug 5, 2024
1 parent 3833a6f commit d34ae60
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/modern_treasury/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@ def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=options.get_max_retries(self.max_retries) - retries,
)

def _retry_request(
Expand Down Expand Up @@ -1091,6 +1092,7 @@ def _process_response(
response: httpx.Response,
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
retries_taken: int = 0,
) -> ResponseT:
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
return cast(
Expand All @@ -1102,6 +1104,7 @@ def _process_response(
stream=stream,
stream_cls=stream_cls,
options=options,
retries_taken=retries_taken,
),
)

Expand All @@ -1121,6 +1124,7 @@ def _process_response(
stream=stream,
stream_cls=stream_cls,
options=options,
retries_taken=retries_taken,
),
)

Expand All @@ -1134,6 +1138,7 @@ def _process_response(
stream=stream,
stream_cls=stream_cls,
options=options,
retries_taken=retries_taken,
)
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
return cast(ResponseT, api_response)
Expand Down Expand Up @@ -1624,6 +1629,7 @@ async def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
retries_taken=options.get_max_retries(self.max_retries) - retries,
)

async def _retry_request(
Expand Down Expand Up @@ -1663,6 +1669,7 @@ async def _process_response(
response: httpx.Response,
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
retries_taken: int = 0,
) -> ResponseT:
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
return cast(
Expand All @@ -1674,6 +1681,7 @@ async def _process_response(
stream=stream,
stream_cls=stream_cls,
options=options,
retries_taken=retries_taken,
),
)

Expand All @@ -1693,6 +1701,7 @@ async def _process_response(
stream=stream,
stream_cls=stream_cls,
options=options,
retries_taken=retries_taken,
),
)

Expand All @@ -1706,6 +1715,7 @@ async def _process_response(
stream=stream,
stream_cls=stream_cls,
options=options,
retries_taken=retries_taken,
)
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
return cast(ResponseT, api_response)
Expand Down
18 changes: 17 additions & 1 deletion src/modern_treasury/_legacy_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@
import logging
import datetime
import functools
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Iterator,
AsyncIterator,
cast,
overload,
)
from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin

import anyio
Expand Down Expand Up @@ -53,6 +64,9 @@ class LegacyAPIResponse(Generic[R]):

http_response: httpx.Response

retries_taken: int
"""The number of retries made. If no retries happened this will be `0`"""

def __init__(
self,
*,
Expand All @@ -62,6 +76,7 @@ def __init__(
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
options: FinalRequestOptions,
retries_taken: int = 0,
) -> None:
self._cast_to = cast_to
self._client = client
Expand All @@ -70,6 +85,7 @@ def __init__(
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
self.retries_taken = retries_taken

@overload
def parse(self, *, to: type[_T]) -> _T:
Expand Down
5 changes: 5 additions & 0 deletions src/modern_treasury/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class BaseAPIResponse(Generic[R]):

http_response: httpx.Response

retries_taken: int
"""The number of retries made. If no retries happened this will be `0`"""

def __init__(
self,
*,
Expand All @@ -64,6 +67,7 @@ def __init__(
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
options: FinalRequestOptions,
retries_taken: int = 0,
) -> None:
self._cast_to = cast_to
self._client = client
Expand All @@ -72,6 +76,7 @@ def __init__(
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
self.retries_taken = retries_taken

@property
def headers(self) -> httpx.Headers:
Expand Down
90 changes: 90 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,49 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non

assert _get_open_connections(self.client) == 0

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("modern_treasury._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retries_taken(self, client: ModernTreasury, failures_before_success: int, respx_mock: MockRouter) -> None:
client = client.with_options(max_retries=4)

nb_retries = 0

def retry_handler(_request: httpx.Request) -> httpx.Response:
nonlocal nb_retries
if nb_retries < failures_before_success:
nb_retries += 1
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/api/counterparties").mock(side_effect=retry_handler)

response = client.counterparties.with_raw_response.create(name="name")

assert response.retries_taken == failures_before_success

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("modern_treasury._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retries_taken_new_response_class(
self, client: ModernTreasury, failures_before_success: int, respx_mock: MockRouter
) -> None:
client = client.with_options(max_retries=4)

nb_retries = 0

def retry_handler(_request: httpx.Request) -> httpx.Response:
nonlocal nb_retries
if nb_retries < failures_before_success:
nb_retries += 1
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/api/counterparties").mock(side_effect=retry_handler)

with client.counterparties.with_streaming_response.create(name="name") as response:
assert response.retries_taken == failures_before_success


class TestAsyncModernTreasury:
client = AsyncModernTreasury(
Expand Down Expand Up @@ -1862,3 +1905,50 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter)
)

assert _get_open_connections(self.client) == 0

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("modern_treasury._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken(
self, async_client: AsyncModernTreasury, failures_before_success: int, respx_mock: MockRouter
) -> None:
client = async_client.with_options(max_retries=4)

nb_retries = 0

def retry_handler(_request: httpx.Request) -> httpx.Response:
nonlocal nb_retries
if nb_retries < failures_before_success:
nb_retries += 1
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/api/counterparties").mock(side_effect=retry_handler)

response = await client.counterparties.with_raw_response.create(name="name")

assert response.retries_taken == failures_before_success

@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("modern_treasury._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken_new_response_class(
self, async_client: AsyncModernTreasury, failures_before_success: int, respx_mock: MockRouter
) -> None:
client = async_client.with_options(max_retries=4)

nb_retries = 0

def retry_handler(_request: httpx.Request) -> httpx.Response:
nonlocal nb_retries
if nb_retries < failures_before_success:
nb_retries += 1
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/api/counterparties").mock(side_effect=retry_handler)

async with client.counterparties.with_streaming_response.create(name="name") as response:
assert response.retries_taken == failures_before_success

0 comments on commit d34ae60

Please sign in to comment.