Skip to content

Commit

Permalink
chore(internal): update base client (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Sep 6, 2023
1 parent 1e582ba commit f8969c8
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/modern_treasury/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
overload,
)
from functools import lru_cache
from typing_extensions import Literal, get_origin
from typing_extensions import Literal, get_args, get_origin

import anyio
import httpx
Expand Down Expand Up @@ -458,6 +458,14 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o
serialized[key] = value
return serialized

def _extract_stream_chunk_type(self, stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])

def _process_response(
self,
*,
Expand Down Expand Up @@ -793,7 +801,10 @@ def _request(
raise APIConnectionError(request=request) from err

if stream:
stream_cls = stream_cls or cast("type[_StreamT] | None", self._default_stream_cls)
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_StreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)
Expand Down Expand Up @@ -1156,7 +1167,10 @@ async def _request(
raise APIConnectionError(request=request) from err

if stream:
stream_cls = stream_cls or cast("type[_AsyncStreamT] | None", self._default_stream_cls)
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_AsyncStreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)
Expand Down

0 comments on commit f8969c8

Please sign in to comment.