Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improved chunk validation #159

Closed
wants to merge 21 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: typing fixes
  • Loading branch information
adubovik committed Sep 16, 2024

Verified

This commit was signed with the committer’s verified signature.
commit f8f16b56decc793df804929de25ecec3dc28842a
14 changes: 7 additions & 7 deletions aidial_sdk/chat_completion/chunks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TypedDict

from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

@@ -8,16 +8,16 @@
from aidial_sdk.utils.json import remove_nones


# Pre-commit definition (removed side of this diff hunk); replaced below
# by a TypedDict in the "fix: typing fixes" commit.
class DefaultChunk(BaseModel):
    """Response-level default fields merged into emitted chunks.

    Every field is optional and defaults to None until set.
    """

    response_id: Optional[str] = None
    model: Optional[str] = None
    created: Optional[int] = None
    object: Optional[str] = None
# Response-level default fields merged into every emitted chunk.
# Functional TypedDict syntax; total=False makes every key optional,
# exactly as in the class-based form (it also sidesteps the `object`
# key shadowing the builtin name in a class body).
DefaultChunk = TypedDict(
    "DefaultChunk",
    {
        "response_id": str,
        "model": str,
        "created": int,
        "object": str,
    },
    total=False,
)


class BaseChunk(ABC):
    """Abstract base for chunks that serialize themselves to a dict.

    Per the call sites in response.py (`item.to_dict(self._default_chunk)`),
    `overrides` carries the response-level default fields to merge in.
    """

    @abstractmethod
    def to_dict(self, overrides: DefaultChunk) -> Dict[str, Any]:
        """Return this chunk as a plain dict, applying `overrides` defaults."""
        # The diff span contained both the old signature (`overrides: dict`)
        # and this post-commit one; only the typed post-commit signature
        # belongs in the reconstructed class.
        pass


26 changes: 10 additions & 16 deletions aidial_sdk/chat_completion/response.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import asyncio
from time import time
from typing import (
Any,
AsyncGenerator,
Callable,
Coroutine,
List,
Union,
)
from typing import Any, AsyncGenerator, Callable, Coroutine, List, Union
from uuid import uuid4

from aidial_sdk.chat_completion.choice import Choice
@@ -58,7 +51,6 @@ def __init__(self, request: Request):

self._default_chunk = DefaultChunk(
response_id=str(uuid4()),
model=None,
created=int(time()),
object=(
"chat.completion.chunk" if request.stream else "chat.completion"
@@ -71,7 +63,7 @@ def get_block_response(self) -> dict:
async def _generate_stream(
self, first_chunk: BaseChunk
) -> AsyncGenerator[Any, None]:
chunk = first_chunk.to_dict()
chunk = first_chunk.to_dict(self._default_chunk)

if self.request.stream:
yield format_chunk(chunk)
@@ -123,18 +115,20 @@ async def _generate_stream(

if isinstance(item, EndChoiceChunk):
if item.choice_index == (self.request.n or 1) - 1:
last_end_choice_chunk = item.to_dict()
last_end_choice_chunk = item.to_dict(self._default_chunk)
self._queue.task_done()
continue

if isinstance(
item,
(UsageChunk, UsagePerModelChunk, DiscardedMessagesChunk),
):
usage_chunk = merge(usage_chunk, item.to_dict())
usage_chunk = merge(
usage_chunk, item.to_dict(self._default_chunk)
)

elif isinstance(item, BaseChunk):
chunk = item.to_dict()
chunk = item.to_dict(self._default_chunk)

if self.request.stream:
yield format_chunk(chunk)
@@ -283,7 +277,7 @@ def send_chunk(self, chunk: Union[BaseChunk, EndMarker]):
)

if isinstance(chunk, BaseChunk):
self._snapshot.add_delta(chunk.to_dict(self._default_chunk.dict()))
self._snapshot.add_delta(chunk.to_dict(self._default_chunk))

self._queue.put_nowait(chunk)

@@ -304,12 +298,12 @@ def set_model(self, model: str):
'Trying to set "model" after start of generation'
)

self._model = model
self._default_chunk["model"] = model

def set_response_id(self, response_id: str):
    """Set the "response_id" field of the response-level default chunk.

    Rejected once generation has started, since the defaults are merged
    into chunks at emission time.

    Args:
        response_id: identifier to report in emitted chunks.

    Raises:
        the error built by the project helper `runtime_error` when
        generation has already started.
    """
    if self._generation_started:
        raise runtime_error(
            'Trying to set "response_id" after start of generation',
        )

    # Post-commit behavior: the id lives directly in the TypedDict of
    # per-response defaults (the old `self._response_id` attribute,
    # present on the removed side of the diff, is gone).
    self._default_chunk["response_id"] = response_id
Loading