From bcbc2d335e93fc5cf4ada8a1b5404b5134486e56 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Feb 2024 21:00:49 +0000 Subject: [PATCH] ci: auto fixes from pre-commit.ci For more information, see https://pre-commit.ci --- .../src/openllm_client/__init__.pyi | 24 +++-- openllm-client/src/openllm_client/_http.py | 28 ++++-- openllm-client/src/openllm_client/_schemas.py | 23 ++++- openllm-client/src/openllm_client/_shim.py | 85 ++++++++++++---- openllm-client/src/openllm_client/_stream.py | 8 +- .../src/openllm_client/_typing_compat.py | 4 +- openllm-python/src/openllm/_llm.py | 98 ++++++++++++++----- openllm-python/src/openllm/_runners.py | 28 ++++-- openllm-python/src/openllm/bundle/_package.py | 2 +- openllm-python/src/openllm_cli/_sdk.py | 4 +- .../src/openllm_cli/extension/dive_bentos.py | 4 +- .../extension/get_containerfile.py | 4 +- .../src/openllm_cli/extension/get_prompt.py | 32 ++++-- .../src/openllm_cli/extension/list_models.py | 9 +- .../src/openllm_cli/extension/playground.py | 17 +++- openllm-python/tests/configuration_test.py | 11 ++- openllm-python/tests/conftest.py | 14 ++- openllm-python/tests/strategies_test.py | 8 +- 18 files changed, 310 insertions(+), 93 deletions(-) diff --git a/openllm-client/src/openllm_client/__init__.pyi b/openllm-client/src/openllm_client/__init__.pyi index 3b5ecfb30..bec8fc114 100644 --- a/openllm-client/src/openllm_client/__init__.pyi +++ b/openllm-client/src/openllm_client/__init__.pyi @@ -15,11 +15,17 @@ class HTTPClient: address: str helpers: _Helpers @overload - def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @property def is_ready(self) -> bool: ... def health(self) -> bool: ... @@ -60,11 +66,17 @@ class AsyncHTTPClient: address: str helpers: _AsyncHelpers @overload - def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @property def is_ready(self) -> bool: ... async def health(self) -> bool: ... diff --git a/openllm-client/src/openllm_client/_http.py b/openllm-client/src/openllm_client/_http.py index 9923468e0..ed802a67b 100644 --- a/openllm-client/src/openllm_client/_http.py +++ b/openllm-client/src/openllm_client/_http.py @@ -70,10 +70,14 @@ def query(self, prompt, **attrs): return self.generate(prompt, **attrs) def health(self): - response = self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}) + response = self._get( + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + ) return response.status_code == 200 - def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response: + def generate( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> Response: if timeout is None: timeout = self._timeout if verify is None: @@ -96,7 +100,9 @@ def generate_stream( for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs): yield StreamingResponse.from_response_chunk(response_chunk) - def generate_iterator(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.Iterator[Response]: + def generate_iterator( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> t.Iterator[Response]: if timeout is None: timeout = self._timeout if verify is None: @@ -146,7 +152,9 @@ def _build_auth_headers(self) -> t.Dict[str, str]: @property async def _metadata(self) -> t.Awaitable[Metadata]: if self.__metadata is None: - self.__metadata = await self._post(f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}) + self.__metadata = await self._post( + f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries} + ) return self.__metadata @property @@ -159,10 +167,14 @@ async def query(self, prompt, **attrs): return await self.generate(prompt, **attrs) async def health(self): - response = await self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}) + response = await self._get( + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + ) return response.status_code == 200 - async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response: + async def generate( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> Response: if timeout is None: timeout = self._timeout if verify is None: @@ -183,7 +195,9 @@ async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, async def generate_stream( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> t.AsyncGenerator[StreamingResponse, t.Any]: - async for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs): + async for response_chunk in self.generate_iterator( + prompt, llm_config, stop, adapter_name, timeout, verify, **attrs + ): yield StreamingResponse.from_response_chunk(response_chunk) async def generate_iterator( diff --git a/openllm-client/src/openllm_client/_schemas.py b/openllm-client/src/openllm_client/_schemas.py index 037c92b3f..1723d583f 100644 --- a/openllm-client/src/openllm_client/_schemas.py +++ b/openllm-client/src/openllm_client/_schemas.py @@ -17,7 +17,7 @@ from ._shim import AsyncClient, Client -__all__ = ['Response', 'CompletionChunk', 'Metadata', 'StreamingResponse', 'Helpers'] +__all__ = ['CompletionChunk', 'Helpers', 'Metadata', 'Response', 'StreamingResponse'] @attr.define @@ -42,7 +42,11 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metada raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None try: return cls( - model_id=data['model_id'], timeout=data['timeout'], model_name=data['model_name'], backend=data['backend'], configuration=configuration + model_id=data['model_id'], + timeout=data['timeout'], + model_name=data['model_name'], + backend=data['backend'], + configuration=configuration, ) except Exception as e: raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None @@ -61,7 +65,10 @@ class StreamingResponse(_SchemaMixin): @classmethod def from_response_chunk(cls, response: Response) -> StreamingResponse: return cls( - request_id=response.request_id, index=response.outputs[0].index, text=response.outputs[0].text, token_ids=response.outputs[0].token_ids[0] + request_id=response.request_id, + index=response.outputs[0].index, + text=response.outputs[0].text, + token_ids=response.outputs[0].token_ids[0], ) @@ -88,11 +95,17 @@ def async_client(self): return self._async_client def messages(self, messages, add_generation_prompt=False): - return self.client._post('/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt)) + return self.client._post( + '/v1/helpers/messages', + response_cls=str, + json=dict(messages=messages, add_generation_prompt=add_generation_prompt), + ) async def async_messages(self, messages, add_generation_prompt=False): return await self.async_client._post( - '/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt) + '/v1/helpers/messages', + response_cls=str, + json=dict(messages=messages, add_generation_prompt=add_generation_prompt), ) @classmethod diff --git a/openllm-client/src/openllm_client/_shim.py b/openllm-client/src/openllm_client/_shim.py index 04a9a730a..4c7470e7f 100644 --- a/openllm-client/src/openllm_client/_shim.py +++ b/openllm-client/src/openllm_client/_shim.py @@ -140,7 +140,9 @@ def parse(self): data = self._raw_response.json() try: - return self._client._process_response_data(data=data, response_cls=self._response_cls, raw_response=self._raw_response) + return self._client._process_response_data( + data=data, response_cls=self._response_cls, raw_response=self._raw_response + ) except Exception as exc: raise ValueError(exc) from None # validation error here @@ -271,10 +273,16 @@ def _build_headers(self, options: RequestOptions) -> httpx.Headers: def _build_request(self, options: RequestOptions) -> httpx.Request: return self._inner.build_request( - method=options.method, headers=self._build_headers(options), url=self._prepare_url(options.url), json=options.json, params=options.params + method=options.method, + headers=self._build_headers(options), + url=self._prepare_url(options.url), + json=options.json, + params=options.params, ) - def _calculate_retry_timeout(self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None) -> float: + def _calculate_retry_timeout( + self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None + ) -> float: max_retries = options.get_max_retries(self._max_retries) # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After try: @@ -315,7 +323,9 @@ def _should_retry(self, response: httpx.Response) -> bool: return True return False - def _process_response_data(self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response) -> Response: + def _process_response_data( + self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response + ) -> Response: return converter.structure(data, response_cls) def _process_response( @@ -328,13 +338,24 @@ def _process_response( stream_cls: type[_Stream] | type[_AsyncStream] | None, ) -> Response: return APIResponse( - raw_response=raw_response, client=self, response_cls=response_cls, stream=stream, stream_cls=stream_cls, options=options + raw_response=raw_response, + client=self, + response_cls=response_cls, + stream=stream, + stream_cls=stream_cls, + options=options, ).parse() @attr.define(init=False) class Client(BaseClient[httpx.Client, Stream[t.Any]]): - def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES): + def __init__( + self, + base_url: str | httpx.URL, + version: str, + timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, + max_retries: int = MAX_RETRIES, + ): super().__init__( base_url=base_url, version=version, @@ -366,7 +387,13 @@ def request( stream: bool = False, stream_cls: type[_Stream] | None = None, ) -> Response | _Stream: - return self._request(response_cls=response_cls, options=options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls) + return self._request( + response_cls=response_cls, + options=options, + remaining_retries=remaining_retries, + stream=stream, + stream_cls=stream_cls, + ) def _request( self, @@ -385,7 +412,9 @@ def _request( response.raise_for_status() except httpx.HTTPStatusError as exc: if retries > 0 and self._should_retry(exc.response): - return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls) + return self._retry_request( + response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls + ) # If the response is streamed then we need to explicitly read the completed response exc.response.read() raise ValueError(exc.message) from None @@ -398,7 +427,9 @@ def _request( return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) raise ValueError(request) from None # connection error - return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls) + return self._process_response( + response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls + ) def _retry_request( self, @@ -428,7 +459,9 @@ def _get( ) -> Response | _Stream: if options is None: options = {} - return self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls) + return self.request( + response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls + ) def _post( self, @@ -442,12 +475,20 @@ def _post( ) -> Response | _Stream: if options is None: options = {} - return self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls) + return self.request( + response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls + ) @attr.define(init=False) class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): - def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES): + def __init__( + self, + base_url: str | httpx.URL, + version: str, + timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, + max_retries: int = MAX_RETRIES, + ): super().__init__( base_url=base_url, version=version, @@ -486,7 +527,9 @@ async def request( stream: bool = False, stream_cls: type[_AsyncStream] | None = None, ) -> Response | _AsyncStream: - return await self._request(response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls) + return await self._request( + response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls + ) async def _request( self, @@ -506,7 +549,9 @@ async def _request( response.raise_for_status() except httpx.HTTPStatusError as exc: if retries > 0 and self._should_retry(exc.response): - return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls) + return self._retry_request( + response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls + ) # If the response is streamed then we need to explicitly read the completed response await exc.response.aread() raise ValueError(exc.message) from None @@ -526,7 +571,9 @@ async def _request( return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) raise ValueError(request) from err # connection error - return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls) + return self._process_response( + response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls + ) async def _retry_request( self, @@ -555,7 +602,9 @@ async def _get( ) -> Response | _AsyncStream: if options is None: options = {} - return await self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls) + return await self.request( + response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls + ) async def _post( self, @@ -569,4 +618,6 @@ async def _post( ) -> Response | _AsyncStream: if options is None: options = {} - return await self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls) + return await self.request( + response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls + ) diff --git a/openllm-client/src/openllm_client/_stream.py b/openllm-client/src/openllm_client/_stream.py index a5103207a..e81a7fb07 100644 --- a/openllm-client/src/openllm_client/_stream.py +++ b/openllm-client/src/openllm_client/_stream.py @@ -38,7 +38,9 @@ def _stream(self) -> t.Iterator[Response]: if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response) + yield self._client._process_response_data( + data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + ) @attr.define(auto_attribs=True) @@ -69,7 +71,9 @@ async def _stream(self) -> t.AsyncGenerator[Response, None]: if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response) + yield self._client._process_response_data( + data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + ) @attr.define diff --git a/openllm-client/src/openllm_client/_typing_compat.py b/openllm-client/src/openllm_client/_typing_compat.py index 48bd0a855..15d86f8a3 100644 --- a/openllm-client/src/openllm_client/_typing_compat.py +++ b/openllm-client/src/openllm_client/_typing_compat.py @@ -11,5 +11,7 @@ overload as overload, ) -Platform = Annotated[LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str] +Platform = Annotated[ + LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str +] Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str] diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index f7df0c3b2..db8da204a 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -48,7 +48,9 @@ @attr.define(slots=False, repr=False, init=False) class LLM(t.Generic[M, T]): - async def generate(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs): + async def generate( + self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs + ): if adapter_name is not None and self.__llm_backend__ != 'pt': raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.') config = self.config.model_construct_env(**attrs) @@ -63,10 +65,15 @@ async def generate(self, prompt, prompt_token_ids=None, stop=None, stop_token_id raise RuntimeError('No result is returned.') return final_result.with_options( prompt=prompt, - outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs], + outputs=[ + output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) + for output in final_result.outputs + ], ) - async def generate_iterator(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs): + async def generate_iterator( + self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs + ): from bentoml._internal.runner.runner_handle import DummyRunnerHandle if adapter_name is not None and self.__llm_backend__ != 'pt': @@ -131,7 +138,9 @@ async def generate_iterator(self, prompt, prompt_token_ids=None, stop=None, stop # The below are mainly for internal implementation that you don't have to worry about. _model_id: str _revision: t.Optional[str] # - _quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]] + _quantization_config: t.Optional[ + t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig] + ] _quantise: t.Optional[LiteralQuantise] _model_decls: t.Tuple[t.Any, ...] __model_attrs: t.Dict[str, t.Any] # @@ -147,7 +156,9 @@ async def generate_iterator(self, prompt, prompt_token_ids=None, stop=None, stop __llm_torch_dtype__: 'torch.dtype' = None __llm_config__: t.Optional[LLMConfig] = None __llm_backend__: LiteralBackend = None - __llm_quantization_config__: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]] = None + __llm_quantization_config__: t.Optional[ + t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig] + ] = None __llm_runner__: t.Optional[Runner[M, T]] = None __llm_model__: t.Optional[M] = None __llm_tokenizer__: t.Optional[T] = None @@ -178,7 +189,9 @@ def __init__( torch_dtype = attrs.pop('torch_dtype', None) # backward compatible if torch_dtype is not None: warnings.warn( - 'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.', DeprecationWarning, stacklevel=3 + 'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.', + DeprecationWarning, + stacklevel=3, ) dtype = torch_dtype _local = False @@ -234,19 +247,27 @@ def __init__( class _Quantise: @staticmethod - def pt(llm: LLM, quantise=None): return quantise + def pt(llm: LLM, quantise=None): + return quantise + @staticmethod - def vllm(llm: LLM, quantise=None): return quantise + def vllm(llm: LLM, quantise=None): + return quantise @apply(lambda val: tuple(str.lower(i) if i else i for i in val)) def _make_tag_components(self, model_id: str, model_version: str | None, backend: str) -> tuple[str, str | None]: model_id, *maybe_revision = model_id.rsplit(':') if len(maybe_revision) > 0: if model_version is not None: - logger.warning("revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version) + logger.warning( + "revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version + ) model_version = maybe_revision[0] if validate_is_path(model_id): - model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id)) + model_id, model_version = ( + resolve_filepath(model_id), + first_not_none(model_version, default=generate_hash_from_file(model_id)), + ) return f'{backend}-{normalise_model_name(model_id)}', model_version @functools.cached_property @@ -255,9 +276,11 @@ def _has_gpus(self): from cuda import cuda err, *_ = cuda.cuInit(0) - if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to initialise CUDA runtime binding.') + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Failed to initialise CUDA runtime binding.') err, _ = cuda.cuDeviceGetCount() - if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to get CUDA device count.') + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Failed to get CUDA device count.') return True except (ImportError, RuntimeError): return False @@ -269,7 +292,9 @@ def _torch_dtype(self): _map = _torch_dtype_mapping() if not isinstance(self.__llm_torch_dtype__, torch.dtype): try: - hf_config = transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code) + hf_config = transformers.AutoConfig.from_pretrained( + self.bentomodel.path, trust_remote_code=self.trust_remote_code + ) except OpenLLMException: hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code) config_dtype = getattr(hf_config, 'torch_dtype', None) @@ -300,7 +325,9 @@ def _tokenizer_attrs(self): return {**self.import_kwargs[1], **self.__tokenizer_attrs} def _cascade_backend(self) -> LiteralBackend: - logger.warning('It is recommended to specify the backend explicitly. Cascading backend might lead to unexpected behaviour.') + logger.warning( + 'It is recommended to specify the backend explicitly. Cascading backend might lead to unexpected behaviour.' + ) if self._has_gpus and is_vllm_available(): return 'vllm' else: @@ -330,7 +357,10 @@ def __repr__(self) -> str: @property def import_kwargs(self): - return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'} + return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, { + 'padding_side': 'left', + 'truncation_side': 'left', + } @property def trust_remote_code(self): @@ -363,7 +393,9 @@ def quantization_config(self): if self._quantization_config is not None: self.__llm_quantization_config__ = self._quantization_config elif self._quantise is not None: - self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(self, self._quantise, **self._model_attrs) + self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config( + self, self._quantise, **self._model_attrs + ) else: raise ValueError("Either 'quantization_config' or 'quantise' must be specified.") return self.__llm_quantization_config__ @@ -418,7 +450,11 @@ def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs): model = get_peft_model( prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), - self.config['fine_tune_strategies'].get(adapter_type, self.config.make_fine_tune_config(adapter_type)).train().with_config(**attrs).build(), + self.config['fine_tune_strategies'] + .get(adapter_type, self.config.make_fine_tune_config(adapter_type)) + .train() + .with_config(**attrs) + .build(), ) if DEBUG: model.print_trainable_parameters() @@ -438,7 +474,10 @@ def adapter_map(self): if self.__llm_adapter_map__ is None: _map: ResolvedAdapterMap = {k: {} for k in self._adapter_map} for adapter_type, adapter_tuple in self._adapter_map.items(): - base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type)) + base = first_not_none( + self.config['fine_tune_strategies'].get(adapter_type), + default=self.config.make_fine_tune_config(adapter_type), + ) for adapter in adapter_tuple: _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id) self.__llm_adapter_map__ = _map @@ -453,7 +492,9 @@ def model(self): import torch loaded_in_kbit = ( - getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False) + getattr(model, 'is_loaded_in_8bit', False) + or getattr(model, 'is_loaded_in_4bit', False) + or getattr(model, 'is_quantized', False) ) if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit: try: @@ -474,7 +515,6 @@ def _architecture_mappings(self): @property def config(self): - import transformers if self._local: config_file = os.path.join(self.model_id, CONFIG_FILE_NAME) else: @@ -482,7 +522,9 @@ def config(self): config_file = self.bentomodel.path_of(CONFIG_FILE_NAME) except OpenLLMException as err: if not is_transformers_available(): - raise MissingDependencyError("Requires 'transformers' to be available. Do 'pip install transformers'") from err + raise MissingDependencyError( + "Requires 'transformers' to be available. Do 'pip install transformers'" + ) from err from transformers.utils import cached_file try: @@ -499,7 +541,9 @@ def config(self): if 'architectures' in loaded_config: for architecture in loaded_config['architectures']: if architecture in self._architecture_mappings: - self.__llm_config__ = openllm_core.AutoConfig.for_model(self._architecture_mappings[architecture]).model_construct_env(**self._model_attrs) + self.__llm_config__ = openllm_core.AutoConfig.for_model( + self._architecture_mappings[architecture] + ).model_construct_env(**self._model_attrs) break else: raise ValueError(f"Failed to find architecture from 'config.json' (config_json_path={config_file})") @@ -520,12 +564,18 @@ def _torch_dtype_mapping() -> dict[str, torch.dtype]: def normalise_model_name(name: str) -> str: - return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--')) + return ( + os.path.basename(resolve_filepath(name)) + if validate_is_path(name) + else inflection.dasherize(name.replace('/', '--')) + ) def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap: if not is_peft_available(): - raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'") + raise RuntimeError( + "LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'" + ) from huggingface_hub import hf_hub_download resolved: AdapterMap = {} diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py index 52961a8ed..1750ae2ae 100644 --- a/openllm-python/src/openllm/_runners.py +++ b/openllm-python/src/openllm/_runners.py @@ -46,7 +46,10 @@ def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]: ( 'runner_methods', { - method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None} + method.name: { + 'batchable': method.config.batchable, + 'batch_dim': method.config.batch_dim if method.config.batchable else None, + } for method in _.runner_methods }, ), @@ -83,7 +86,9 @@ def __init__(self, llm): if dev >= 2: num_gpus = min(dev // 2 * 2, dev) quantise = llm.quantise if llm.quantise and llm.quantise in {'gptq', 'awq', 'squeezellm'} else None - dtype = torch.float16 if quantise == 'gptq' else llm._torch_dtype # NOTE: quantise GPTQ doesn't support bfloat16 yet. + dtype = ( + torch.float16 if quantise == 'gptq' else llm._torch_dtype + ) # NOTE: quantise GPTQ doesn't support bfloat16 yet. try: self.model = vllm.AsyncLLMEngine.from_engine_args( vllm.AsyncEngineArgs( @@ -102,7 +107,9 @@ def __init__(self, llm): ) except Exception as err: traceback.print_exc() - raise openllm.exceptions.OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err + raise openllm.exceptions.OpenLLMException( + f'Failed to initialise vLLMEngine due to the following error:\n{err}' + ) from err @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): @@ -159,7 +166,9 @@ async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapt if config['logprobs']: # FIXME: logprobs is not supported raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.') encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0] - start_ids = torch.as_tensor([[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device) + start_ids = torch.as_tensor( + [[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device + ) else: start_ids = torch.as_tensor([prompt_token_ids], device=self.device) @@ -187,7 +196,9 @@ async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapt ) logits = self.model.lm_head(out[0]) else: - out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True) + out = self.model( + input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True + ) logits = out.logits past_key_values = out.past_key_values if logits_processor: @@ -231,7 +242,12 @@ async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapt tmp_output_ids, rfind_start = output_token_ids[input_len:], 0 # XXX: Move this to API server - text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True) + text = self.tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) if len(stop) > 0: for it in stop: diff --git a/openllm-python/src/openllm/bundle/_package.py b/openllm-python/src/openllm/bundle/_package.py index 784b495a6..86e889875 100644 --- a/openllm-python/src/openllm/bundle/_package.py +++ b/openllm-python/src/openllm/bundle/_package.py @@ -101,7 +101,7 @@ def create_bento( 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle', **{ - f'{package.replace("-","_")}_version': importlib.metadata.version(package) + f'{package.replace("-", "_")}_version': importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'} }, }) diff --git a/openllm-python/src/openllm_cli/_sdk.py b/openllm-python/src/openllm_cli/_sdk.py index 1c9c04de2..a5167cbe4 100644 --- a/openllm-python/src/openllm_cli/_sdk.py +++ b/openllm-python/src/openllm_cli/_sdk.py @@ -81,7 +81,7 @@ def _start( if adapter_map: args.extend( list( - itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()]) + itertools.chain.from_iterable([['--adapter-id', f"{k}{':' + v if v else ''}"] for k, v in adapter_map.items()]) ) ) if additional_args: @@ -173,7 +173,7 @@ def _build( if overwrite: args.append('--overwrite') if adapter_map: - args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) + args.extend([f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()]) if model_version: args.extend(['--model-version', model_version]) if bento_version: diff --git a/openllm-python/src/openllm_cli/extension/dive_bentos.py b/openllm-python/src/openllm_cli/extension/dive_bentos.py index db488004d..541d07bf2 100644 --- a/openllm-python/src/openllm_cli/extension/dive_bentos.py +++ b/openllm-python/src/openllm_cli/extension/dive_bentos.py @@ -21,7 +21,9 @@ @machine_option @click.pass_context @inject -def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None: +def cli( + ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store] +) -> str | None: """Dive into a BentoLLM. This is synonymous to cd $(b get : -o path).""" try: bentomodel = _bento_store.get(bento) diff --git a/openllm-python/src/openllm_cli/extension/get_containerfile.py b/openllm-python/src/openllm_cli/extension/get_containerfile.py index 886054144..507988292 100644 --- a/openllm-python/src/openllm_cli/extension/get_containerfile.py +++ b/openllm-python/src/openllm_cli/extension/get_containerfile.py @@ -17,7 +17,9 @@ from bentoml._internal.bento import BentoStore -@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.') +@click.command( + 'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.' +) @click.argument('bento', type=str, shell_complete=bento_complete_envvar) @click.pass_context @inject diff --git a/openllm-python/src/openllm_cli/extension/get_prompt.py b/openllm-python/src/openllm_cli/extension/get_prompt.py index b679577fe..0e64c2304 100644 --- a/openllm-python/src/openllm_cli/extension/get_prompt.py +++ b/openllm-python/src/openllm_cli/extension/get_prompt.py @@ -22,7 +22,9 @@ def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping raise ValueError('Positional arguments are not supported') return super().vformat(format_string, args, kwargs) - def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None: + def check_unused_args( + self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] + ) -> None: extras = set(kwargs).difference(used_args) if extras: raise KeyError(f'Extra params passed: {extras}') @@ -56,7 +58,9 @@ def format(self, **attrs: t.Any) -> str: try: return self.template.format(**prompt_variables) except KeyError as e: - raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None + raise RuntimeError( + f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template." + ) from None @click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS) @@ -124,15 +128,21 @@ def cli( if prompt_template_file and chat_template_file: ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.') - acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()) + acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set( + inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys() + ) if model_id in acceptable: - logger.warning('Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n') + logger.warning( + 'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n' + ) config = openllm.AutoConfig.for_model(model_id) template = prompt_template_file.read() if prompt_template_file is not None else config.template system_message = system_message or config.system_message try: - formatted = PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized) + formatted = ( + PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized) + ) except RuntimeError as err: logger.debug('Exception caught while formatting prompt: %s', err) ctx.fail(str(err)) @@ -149,15 +159,21 @@ def cli( for architecture in config.architectures: if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE(): system_message = ( - openllm.AutoConfig.infer_class_from_name(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]) + openllm.AutoConfig.infer_class_from_name( + openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture] + ) .model_construct_env() .system_message ) break else: - ctx.fail(f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message') + ctx.fail( + f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message' + ) messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}] - formatted = tokenizer.apply_chat_template(messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False) + formatted = tokenizer.apply_chat_template( + messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False + ) termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white') ctx.exit(0) diff --git a/openllm-python/src/openllm_cli/extension/list_models.py b/openllm-python/src/openllm_cli/extension/list_models.py index eb18ce0d6..6eb49e079 100644 --- a/openllm-python/src/openllm_cli/extension/list_models.py +++ b/openllm-python/src/openllm_cli/extension/list_models.py @@ -33,12 +33,17 @@ def cli(model_name: str | None) -> DictStrAny: } if model_name is not None: ids_in_local_store = { - k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)] + k: [ + i + for i in v + if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name) + ] for k, v in ids_in_local_store.items() } ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v} local_models = { - k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items() + k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] + for k, val in ids_in_local_store.items() } termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white') return local_models diff --git a/openllm-python/src/openllm_cli/extension/playground.py b/openllm-python/src/openllm_cli/extension/playground.py index f8e5b4da4..fcbc128b2 100644 --- a/openllm-python/src/openllm_cli/extension/playground.py +++ b/openllm-python/src/openllm_cli/extension/playground.py @@ -32,7 +32,14 @@ def load_notebook_metadata() -> DictStrAny: @click.command('playground', context_settings=termui.CONTEXT_SETTINGS) @click.argument('output-dir', default=None, required=False) -@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server') +@click.option( + '--port', + envvar='JUPYTER_PORT', + show_envvar=True, + show_default=True, + default=8888, + help='Default port for Jupyter server', +) @click.pass_context def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: """OpenLLM Playground. @@ -53,7 +60,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: > This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"' """ if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available(): - raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'") + raise RuntimeError( + "Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'" + ) metadata = load_notebook_metadata() _temp_dir = False if output_dir is None: @@ -65,7 +74,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue') for module in pkgutil.iter_modules(playground.__path__): if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')): - logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module') + logger.debug( + 'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module' + ) continue if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue diff --git a/openllm-python/tests/configuration_test.py b/openllm-python/tests/configuration_test.py index 90069b797..fafa40984 100644 --- a/openllm-python/tests/configuration_test.py +++ b/openllm-python/tests/configuration_test.py @@ -66,8 +66,15 @@ def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings): st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), ) -def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float): - cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),)) +def test_complex_struct_dump( + gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float +): + cl_ = make_llm_config( + 'ComplexLLM', + gen_settings, + fields=(('field1', 'float', field1),), + generation_fields=(('temperature', temperature),), + ) sent = cl_() assert sent.model_dump()['field1'] == field1 assert sent.model_dump()['generation_config']['temperature'] == temperature diff --git a/openllm-python/tests/conftest.py b/openllm-python/tests/conftest.py index 1efd9e4d1..e49b2656c 100644 --- a/openllm-python/tests/conftest.py +++ b/openllm-python/tests/conftest.py @@ -10,8 +10,14 @@ if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralBackend -_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'} -_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'} +_MODELING_MAPPING = { + 'flan_t5': 'google/flan-t5-small', + 'opt': 'facebook/opt-125m', + 'baichuan': 'baichuan-inc/Baichuan-7B', +} +_PROMPT_MAPPING = { + 'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?' +} def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.Any, t.Any]], None, None]: @@ -25,7 +31,9 @@ def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.An def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: if os.getenv('GITHUB_ACTIONS') is None: if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames: - metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]) + metafunc.parametrize( + 'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])] + ) def pytest_sessionfinish(session: pytest.Session, exitstatus: int): diff --git a/openllm-python/tests/strategies_test.py b/openllm-python/tests/strategies_test.py index f801ed81d..6b95ac0df 100644 --- a/openllm-python/tests/strategies_test.py +++ b/openllm-python/tests/strategies_test.py @@ -73,9 +73,13 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch): mcls.setenv('CUDA_VISIBLE_DEVICES', '') assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests - assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match('Input list should be all string type.') + assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match( + 'Input list should be all string type.' + ) assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.') - assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID') + assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match( + 'Failed to parse available GPUs UUID' + ) def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):