ci: auto fixes from pre-commit.ci
For more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 21, 2024
1 parent 5961af4 commit bcbc2d3
Showing 18 changed files with 310 additions and 93 deletions.
24 changes: 18 additions & 6 deletions openllm-client/src/openllm_client/__init__.pyi
@@ -15,11 +15,17 @@ class HTTPClient:
address: str
helpers: _Helpers
@overload
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
def health(self) -> bool: ...
@@ -60,11 +66,17 @@ class AsyncHTTPClient:
address: str
helpers: _AsyncHelpers
@overload
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
async def health(self) -> bool: ...
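
For orientation, a minimal usage sketch of the two clients whose stubs are rewrapped above. The keyword names (address, timeout, verify, max_retries, api_version) come from the signatures in this diff; the URL and numeric values are illustrative placeholders.

# Minimal sketch based on the stub signatures above; the URL and numbers are illustrative.
from openllm_client import AsyncHTTPClient, HTTPClient

client = HTTPClient('http://localhost:3000', timeout=30, max_retries=3)
print(client.is_ready)   # property declared in the stub
print(client.health())   # True once the server reports ready

async def ready() -> bool:
    # The async client mirrors the sync constructor; health() must be awaited.
    async_client = AsyncHTTPClient('http://localhost:3000', timeout=30)
    return await async_client.health()
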
28 changes: 21 additions & 7 deletions openllm-client/src/openllm_client/_http.py
@@ -70,10 +70,14 @@ def query(self, prompt, **attrs):
return self.generate(prompt, **attrs)

def health(self):
response = self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries})
response = self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200

def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -96,7 +100,9 @@ def generate_stream(
for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
yield StreamingResponse.from_response_chunk(response_chunk)

def generate_iterator(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.Iterator[Response]:
def generate_iterator(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.Iterator[Response]:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -146,7 +152,9 @@ def _build_auth_headers(self) -> t.Dict[str, str]:
@property
async def _metadata(self) -> t.Awaitable[Metadata]:
if self.__metadata is None:
self.__metadata = await self._post(f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries})
self.__metadata = await self._post(
f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}
)
return self.__metadata

@property
@@ -159,10 +167,14 @@ async def query(self, prompt, **attrs):
return await self.generate(prompt, **attrs)

async def health(self):
response = await self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries})
response = await self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200

async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
async def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -183,7 +195,9 @@ async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None,
async def generate_stream(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.AsyncGenerator[StreamingResponse, t.Any]:
async for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
async for response_chunk in self.generate_iterator(
prompt, llm_config, stop, adapter_name, timeout, verify, **attrs
):
yield StreamingResponse.from_response_chunk(response_chunk)

async def generate_iterator(
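
To put the rewrapped signatures above in context, here is a hedged sketch of calling generate and generate_stream. The parameter names (prompt, llm_config, stop, adapter_name, timeout, verify) are taken from this diff; the prompt text and stop tokens are placeholders and error handling is omitted.

# Sketch only: parameter names come from the diff above, the values are placeholders.
from openllm_client import HTTPClient

client = HTTPClient('http://localhost:3000')

# Blocking call: returns a single Response once generation completes.
response = client.generate('Explain HTTP retries in one sentence.', stop=['\n'], timeout=60)
print(response.outputs[0].text)

# Streaming call: generate_stream wraps each Response chunk in a StreamingResponse.
for chunk in client.generate_stream('Write a haiku about CI bots.'):
    print(chunk.text, end='', flush=True)
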
23 changes: 18 additions & 5 deletions openllm-client/src/openllm_client/_schemas.py
@@ -17,7 +17,7 @@
from ._shim import AsyncClient, Client


__all__ = ['Response', 'CompletionChunk', 'Metadata', 'StreamingResponse', 'Helpers']
__all__ = ['CompletionChunk', 'Helpers', 'Metadata', 'Response', 'StreamingResponse']


@attr.define
@@ -42,7 +42,11 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metadata:
raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None
try:
return cls(
model_id=data['model_id'], timeout=data['timeout'], model_name=data['model_name'], backend=data['backend'], configuration=configuration
model_id=data['model_id'],
timeout=data['timeout'],
model_name=data['model_name'],
backend=data['backend'],
configuration=configuration,
)
except Exception as e:
raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None
@@ -61,7 +65,10 @@ class StreamingResponse(_SchemaMixin):
@classmethod
def from_response_chunk(cls, response: Response) -> StreamingResponse:
return cls(
request_id=response.request_id, index=response.outputs[0].index, text=response.outputs[0].text, token_ids=response.outputs[0].token_ids[0]
request_id=response.request_id,
index=response.outputs[0].index,
text=response.outputs[0].text,
token_ids=response.outputs[0].token_ids[0],
)


@@ -88,11 +95,17 @@ def async_client(self):
return self._async_client

def messages(self, messages, add_generation_prompt=False):
return self.client._post('/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt))
return self.client._post(
'/v1/helpers/messages',
response_cls=str,
json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
)

async def async_messages(self, messages, add_generation_prompt=False):
return await self.async_client._post(
'/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt)
'/v1/helpers/messages',
response_cls=str,
json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
)

@classmethod
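
The messages helper above posts a chat-style message list to /v1/helpers/messages and returns the rendered prompt as a string. A hedged usage sketch follows; the role/content dictionary shape is an assumption, only the endpoint and the add_generation_prompt flag appear in this diff.

# Sketch: the message dictionary shape is an assumption, not taken from this diff.
from openllm_client import HTTPClient

client = HTTPClient('http://localhost:3000')
prompt = client.helpers.messages(
    messages=[
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Summarise this commit.'},
    ],
    add_generation_prompt=True,  # ask the server-side chat template to append the assistant turn
)
print(prompt)
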
85 changes: 68 additions & 17 deletions openllm-client/src/openllm_client/_shim.py
@@ -140,7 +140,9 @@ def parse(self):

data = self._raw_response.json()
try:
return self._client._process_response_data(data=data, response_cls=self._response_cls, raw_response=self._raw_response)
return self._client._process_response_data(
data=data, response_cls=self._response_cls, raw_response=self._raw_response
)
except Exception as exc:
raise ValueError(exc) from None # validation error here

@@ -271,10 +273,16 @@ def _build_headers(self, options: RequestOptions) -> httpx.Headers:

def _build_request(self, options: RequestOptions) -> httpx.Request:
return self._inner.build_request(
method=options.method, headers=self._build_headers(options), url=self._prepare_url(options.url), json=options.json, params=options.params
method=options.method,
headers=self._build_headers(options),
url=self._prepare_url(options.url),
json=options.json,
params=options.params,
)

def _calculate_retry_timeout(self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None) -> float:
def _calculate_retry_timeout(
self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None
) -> float:
max_retries = options.get_max_retries(self._max_retries)
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
try:
@@ -315,7 +323,9 @@ def _should_retry(self, response: httpx.Response) -> bool:
return True
return False

def _process_response_data(self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response) -> Response:
def _process_response_data(
self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response
) -> Response:
return converter.structure(data, response_cls)

def _process_response(
@@ -328,13 +338,24 @@ def _process_response(
stream_cls: type[_Stream] | type[_AsyncStream] | None,
) -> Response:
return APIResponse(
raw_response=raw_response, client=self, response_cls=response_cls, stream=stream, stream_cls=stream_cls, options=options
raw_response=raw_response,
client=self,
response_cls=response_cls,
stream=stream,
stream_cls=stream_cls,
options=options,
).parse()


@attr.define(init=False)
class Client(BaseClient[httpx.Client, Stream[t.Any]]):
def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES):
def __init__(
self,
base_url: str | httpx.URL,
version: str,
timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
max_retries: int = MAX_RETRIES,
):
super().__init__(
base_url=base_url,
version=version,
@@ -366,7 +387,13 @@ def request(
stream: bool = False,
stream_cls: type[_Stream] | None = None,
) -> Response | _Stream:
return self._request(response_cls=response_cls, options=options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls)
return self._request(
response_cls=response_cls,
options=options,
remaining_retries=remaining_retries,
stream=stream,
stream_cls=stream_cls,
)

def _request(
self,
@@ -385,7 +412,9 @@ def _request(
response.raise_for_status()
except httpx.HTTPStatusError as exc:
if retries > 0 and self._should_retry(exc.response):
return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls)
return self._retry_request(
response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
)
# If the response is streamed then we need to explicitly read the completed response
exc.response.read()
raise ValueError(exc.message) from None
@@ -398,7 +427,9 @@ def _request(
return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
raise ValueError(request) from None # connection error

return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls)
return self._process_response(
response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
)

def _retry_request(
self,
@@ -428,7 +459,9 @@ def _get(
) -> Response | _Stream:
if options is None:
options = {}
return self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls)
return self.request(
response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
)

def _post(
self,
@@ -442,12 +475,20 @@ def _post(
) -> Response | _Stream:
if options is None:
options = {}
return self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls)
return self.request(
response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
)


@attr.define(init=False)
class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES):
def __init__(
self,
base_url: str | httpx.URL,
version: str,
timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
max_retries: int = MAX_RETRIES,
):
super().__init__(
base_url=base_url,
version=version,
@@ -486,7 +527,9 @@ async def request(
stream: bool = False,
stream_cls: type[_AsyncStream] | None = None,
) -> Response | _AsyncStream:
return await self._request(response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls)
return await self._request(
response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls
)

async def _request(
self,
@@ -506,7 +549,9 @@ async def _request(
response.raise_for_status()
except httpx.HTTPStatusError as exc:
if retries > 0 and self._should_retry(exc.response):
return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls)
return self._retry_request(
response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
)
# If the response is streamed then we need to explicitly read the completed response
await exc.response.aread()
raise ValueError(exc.message) from None
@@ -526,7 +571,9 @@ async def _request(
return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
raise ValueError(request) from err # connection error

return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls)
return self._process_response(
response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
)

async def _retry_request(
self,
@@ -555,7 +602,9 @@ async def _get(
) -> Response | _AsyncStream:
if options is None:
options = {}
return await self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls)
return await self.request(
response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
)

async def _post(
self,
@@ -569,4 +618,6 @@ async def _post(
) -> Response | _AsyncStream:
if options is None:
options = {}
return await self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls)
return await self.request(
response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
)
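
_calculate_retry_timeout above only has its signature rewrapped; its body is not part of this hunk. For readers unfamiliar with the pattern it names, here is a generic sketch of retry-delay logic that honours a Retry-After header (per the MDN link in the code) and otherwise backs off exponentially. The constants and the fallback behaviour are assumptions, not the library's actual implementation.

# Generic sketch, not openllm_client's implementation: prefer an explicit Retry-After
# header, otherwise fall back to capped exponential backoff with a little jitter.
import random
import typing as t

import httpx

def retry_delay(remaining_retries: int, max_retries: int, headers: t.Optional[httpx.Headers] = None) -> float:
    if headers is not None:
        retry_after = headers.get('retry-after')
        if retry_after is not None:
            try:
                return max(0.0, float(retry_after))  # Retry-After given in seconds
            except ValueError:
                pass  # Retry-After may also be an HTTP date; this sketch ignores that form
    attempt = max_retries - remaining_retries  # 0 on the first retry
    return min(8.0, 0.5 * 2**attempt) + random.uniform(0.0, 0.25)
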
8 changes: 6 additions & 2 deletions openllm-client/src/openllm_client/_stream.py
@@ -38,7 +38,9 @@ def _stream(self) -> t.Iterator[Response]:
if sse.data.startswith('[DONE]'):
break
if sse.event is None:
yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response)
yield self._client._process_response_data(
data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
)


@attr.define(auto_attribs=True)
@@ -69,7 +71,9 @@ async def _stream(self) -> t.AsyncGenerator[Response, None]:
if sse.data.startswith('[DONE]'):
break
if sse.event is None:
yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response)
yield self._client._process_response_data(
data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
)


@attr.define
4 changes: 3 additions & 1 deletion openllm-client/src/openllm_client/_typing_compat.py
@@ -11,5 +11,7 @@
overload as overload,
)

Platform = Annotated[LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str]
Platform = Annotated[
LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str
]
Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str]