Skip to content

Commit

Permalink
Extract openai predict logic into smaller methods (kserve#3716)
Browse files Browse the repository at this point in the history
* refactor into smaller methods

Signed-off-by: grandbora <grandbora@fb.com>

* address comments

Signed-off-by: grandbora <grandbora@fb.com>

* format

Signed-off-by: grandbora <grandbora@fb.com>

---------

Signed-off-by: grandbora <grandbora@fb.com>
  • Loading branch information
grandbora authored Jun 18, 2024
1 parent 212a77c commit 32d3e19
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions python/kserve/kserve/protocol/rest/openai/openai_proxy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ async def create_completion(
self, request: CompletionRequest
) -> Union[Completion, AsyncIterator[Completion]]:
self.preprocess_completion_request(request)
req = self._build_request(self._completions_endpoint, request)
if request.params.stream:
req = self._build_request(self._completions_endpoint, request)
r = await self._http_client.send(req, stream=True)
r.raise_for_status()
it = AsyncMappingIterator(
Expand All @@ -254,23 +254,28 @@ async def create_completion(
)
return it
else:
response = await self._http_client.send(req)
response.raise_for_status()
if self.skip_upstream_validation:
obj = response.json()
completion = Completion.model_construct(**obj)
else:
completion = Completion.model_validate_json(response.content)
completion = await self.generate_completion(request)
self.postprocess_completion(completion, request)
return completion

async def generate_completion(self, request: CompletionRequest) -> Completion:
    """Issue a non-streaming completion request and parse the reply.

    Builds the outgoing request for the completions endpoint, sends it
    through the proxy's HTTP client, and converts the response body into
    a ``Completion``.

    Args:
        request: The completion request to forward upstream.

    Returns:
        The ``Completion`` built from the upstream response body.

    Raises:
        Whatever ``raise_for_status()`` raises on the upstream response,
        plus any parsing error from building the ``Completion``.
    """
    upstream_request = self._build_request(self._completions_endpoint, request)
    upstream_response = await self._http_client.send(upstream_request)
    upstream_response.raise_for_status()

    if not self.skip_upstream_validation:
        # Default path: parse the raw body with model_validate_json.
        return Completion.model_validate_json(upstream_response.content)

    # Opt-out path: build the object straight from the decoded JSON via
    # model_construct (presumably bypassing pydantic validation -- confirm).
    payload = upstream_response.json()
    return Completion.model_construct(**payload)

@error_handler
async def create_chat_completion(
self, request: ChatCompletionRequest
) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
self.preprocess_chat_completion_request(request)
req = self._build_request(self._chat_completions_endpoint, request)
if request.params.stream:
req = self._build_request(self._chat_completions_endpoint, request)
r = await self._http_client.send(req, stream=True)
r.raise_for_status()
it = AsyncMappingIterator(
Expand All @@ -280,12 +285,19 @@ async def create_chat_completion(
)
return it
else:
response = await self._http_client.send(req)
response.raise_for_status()
if self.skip_upstream_validation:
obj = response.json()
chat_completion = ChatCompletion.model_construct(**obj)
else:
chat_completion = ChatCompletion.model_validate_json(response.content)
chat_completion = await self.generate_chat_completion(request)
self.postprocess_chat_completion(chat_completion, request)
return chat_completion

async def generate_chat_completion(
    self, request: ChatCompletionRequest
) -> ChatCompletion:
    """Issue a non-streaming chat completion request and parse the reply.

    Builds the outgoing request for the chat completions endpoint, sends
    it through the proxy's HTTP client, and converts the response body
    into a ``ChatCompletion``.

    Args:
        request: The chat completion request to forward upstream.

    Returns:
        The ``ChatCompletion`` built from the upstream response body.

    Raises:
        Whatever ``raise_for_status()`` raises on the upstream response,
        plus any parsing error from building the ``ChatCompletion``.
    """
    upstream_request = self._build_request(
        self._chat_completions_endpoint, request
    )
    upstream_response = await self._http_client.send(upstream_request)
    upstream_response.raise_for_status()

    if not self.skip_upstream_validation:
        # Default path: parse the raw body with model_validate_json.
        return ChatCompletion.model_validate_json(upstream_response.content)

    # Opt-out path: build the object straight from the decoded JSON via
    # model_construct (presumably bypassing pydantic validation -- confirm).
    payload = upstream_response.json()
    return ChatCompletion.model_construct(**payload)

0 comments on commit 32d3e19

Please sign in to comment.