Skip to content

Commit

Permalink
feat(client): support accessing raw response objects (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Oct 27, 2023
1 parent ed9dd52 commit a8cc529
Show file tree
Hide file tree
Showing 90 changed files with 5,151 additions and 216 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,26 @@ if response.my_field is None:
print('Got json like {"my_field": null}.')
```

### Accessing raw response data (e.g. headers)

The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call.

```py
from modern_treasury import ModernTreasury

client = ModernTreasury()
response = client.external_accounts.with_raw_response.create(
counterparty_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
)

print(response.headers.get('X-My-Header'))

external_account = response.parse() # get the object that `external_accounts.create()` would have returned
print(external_account.id)
```

These methods return an [`APIResponse`](https://github.com/Modern-Treasury/modern-treasury-python/src/modern_treasury/_response.py) object.

### Configuring the HTTP client

You can directly override the [httpx client](https://www.python-httpx.org/api/#client) to customize it for your use case, including:
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ format = { chain = [

typecheck = { chain = [
"typecheck:pyright",
"typecheck:verify-types",
"typecheck:mypy"
]}
"typecheck:pyright" = "pyright"
Expand Down
217 changes: 85 additions & 132 deletions src/modern_treasury/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
overload,
)
from functools import lru_cache
from typing_extensions import Literal, get_args, override, get_origin
from typing_extensions import Literal, override

import anyio
import httpx
Expand All @@ -49,11 +49,11 @@
ModelT,
Headers,
Timeout,
NoneType,
NotGiven,
ResponseT,
Transport,
AnyMapping,
PostParser,
ProxiesTypes,
RequestFiles,
AsyncTransport,
Expand All @@ -63,20 +63,16 @@
)
from ._utils import is_dict, is_given, is_mapping
from ._compat import model_copy, model_dump
from ._models import (
BaseModel,
GenericModel,
FinalRequestOptions,
validate_type,
construct_type,
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import APIResponse
from ._constants import (
DEFAULT_LIMITS,
DEFAULT_TIMEOUT,
DEFAULT_MAX_RETRIES,
RAW_RESPONSE_HEADER,
)
from ._streaming import Stream, AsyncStream
from ._exceptions import (
APIStatusError,
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
)
from ._exceptions import APIStatusError, APITimeoutError, APIConnectionError

log: logging.Logger = logging.getLogger(__name__)

Expand All @@ -101,19 +97,6 @@
HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)


# default timeout is 1 minute
DEFAULT_TIMEOUT = Timeout(timeout=60.0, connect=5.0)
DEFAULT_MAX_RETRIES = 2
DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20)


class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `modern_treasury._streaming` for reference",
)


class PageInfo:
"""Stores the necesary information to build the request to retrieve the next page.
Expand Down Expand Up @@ -182,6 +165,7 @@ def _params_from_url(self, url: URL) -> httpx.QueryParams:

def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
options = model_copy(self._options)
options._strip_raw_response_header()

if not isinstance(info.params, NotGiven):
options.params = {**options.params, **info.params}
Expand Down Expand Up @@ -260,13 +244,17 @@ def __await__(self) -> Generator[Any, None, AsyncPageT]:
return self._get_page().__await__()

async def _get_page(self) -> AsyncPageT:
page = await self._client.request(self._page_cls, self._options)
page._set_private_attributes( # pyright: ignore[reportPrivateUsage]
model=self._model,
options=self._options,
client=self._client,
)
return page
def _parser(resp: AsyncPageT) -> AsyncPageT:
resp._set_private_attributes(
model=self._model,
options=self._options,
client=self._client,
)
return resp

self._options.post_parser = _parser

return await self._client.request(self._page_cls, self._options)

async def __aiter__(self) -> AsyncIterator[ModelT]:
# https://github.com/microsoft/pyright/issues/3464
Expand Down Expand Up @@ -317,9 +305,10 @@ async def get_next_page(self: AsyncPageT) -> AsyncPageT:


_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])


class BaseClient(Generic[_HttpxClientT]):
class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
_client: _HttpxClientT
_version: str
_base_url: URL
Expand All @@ -330,6 +319,7 @@ class BaseClient(Generic[_HttpxClientT]):
_transport: Transport | AsyncTransport | None
_strict_response_validation: bool
_idempotency_header: str | None
_default_stream_cls: type[_DefaultStreamT] | None = None

def __init__(
self,
Expand Down Expand Up @@ -504,80 +494,28 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o
serialized[key] = value
return serialized

def _extract_stream_chunk_type(self, stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])

def _process_response(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions, # noqa: ARG002
options: FinalRequestOptions,
response: httpx.Response,
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
) -> ResponseT:
if cast_to is NoneType:
return cast(ResponseT, None)

if cast_to == str:
return cast(ResponseT, response.text)

origin = get_origin(cast_to) or cast_to

if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_to != httpx.Response:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(ResponseT, response)

# The check here is necessary as we are subverting the the type system
# with casts as the relationship between TypeVars and Types are very strict
# which means we must return *exactly* what was input or transform it in a
# way that retains the TypeVar state. As we cannot do that in this function
# then we have to resort to using `cast`. At the time of writing, we know this
# to be safe as we have handled all the types that could be bound to the
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if (
cast_to is not UnknownResponse
and not origin is list
and not origin is dict
and not origin is Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
)

# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type").split(";")
if content_type != "application/json":
if self._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
body=response.text,
)

# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
api_response = APIResponse(
raw=response,
client=self,
cast_to=cast_to,
stream=stream,
stream_cls=stream_cls,
options=options,
)

data = response.json()
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
return cast(ResponseT, api_response)

try:
return self._process_response_data(data=data, cast_to=cast_to, response=response)
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, body=data) from err
return api_response.parse()

def _process_response_data(
self,
Expand Down Expand Up @@ -734,7 +672,7 @@ def _idempotency_key(self) -> str:
return f"stainless-python-retry-{uuid.uuid4()}"


class SyncAPIClient(BaseClient[httpx.Client]):
class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
_client: httpx.Client
_has_custom_http_client: bool
_default_stream_cls: type[Stream[Any]] | None = None
Expand Down Expand Up @@ -930,23 +868,32 @@ def _request(
raise self._make_status_error_from_response(err.response) from None
except httpx.TimeoutException as err:
if retries > 0:
return self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
return self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
)
raise APITimeoutError(request=request) from err
except Exception as err:
if retries > 0:
return self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
return self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
)
raise APIConnectionError(request=request) from err

if stream:
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_StreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)

return self._process_response(cast_to=cast_to, options=options, response=response)
return self._process_response(
cast_to=cast_to,
options=options,
response=response,
stream=stream,
stream_cls=stream_cls,
)

def _retry_request(
self,
Expand Down Expand Up @@ -980,13 +927,17 @@ def _request_api_list(
page: Type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
resp = self.request(page, options, stream=False)
resp._set_private_attributes( # pyright: ignore[reportPrivateUsage]
client=self,
model=model,
options=options,
)
return resp
def _parser(resp: SyncPageT) -> SyncPageT:
resp._set_private_attributes(
client=self,
model=model,
options=options,
)
return resp

options.post_parser = _parser

return self.request(page, options, stream=False)

@overload
def get(
Expand Down Expand Up @@ -1144,7 +1095,7 @@ def get_api_list(
return self._request_api_list(model, page, opts)


class AsyncAPIClient(BaseClient[httpx.AsyncClient]):
class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
_client: httpx.AsyncClient
_has_custom_http_client: bool
_default_stream_cls: type[AsyncStream[Any]] | None = None
Expand Down Expand Up @@ -1354,16 +1305,13 @@ async def _request(
return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
raise APIConnectionError(request=request) from err

if stream:
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_AsyncStreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)

return self._process_response(cast_to=cast_to, options=options, response=response)
return self._process_response(
cast_to=cast_to,
options=options,
response=response,
stream=stream,
stream_cls=stream_cls,
)

async def _retry_request(
self,
Expand Down Expand Up @@ -1560,6 +1508,7 @@ def make_request_options(
extra_body: Body | None = None,
idempotency_key: str | None = None,
timeout: float | None | NotGiven = NOT_GIVEN,
post_parser: PostParser | NotGiven = NOT_GIVEN,
) -> RequestOptions:
"""Create a dict of type RequestOptions without keys of NotGiven values."""
options: RequestOptions = {}
Expand All @@ -1581,6 +1530,10 @@ def make_request_options(
if idempotency_key is not None:
options["idempotency_key"] = idempotency_key

if is_given(post_parser):
# internal
options["post_parser"] = post_parser # type: ignore

return options


Expand Down
Loading

0 comments on commit a8cc529

Please sign in to comment.