Skip to content

Commit

Permalink
feat: add message type literal to usage stats (#2297)
Browse files Browse the repository at this point in the history
Co-authored-by: Caren Thomas <caren@caren-mac.local>
  • Loading branch information
carenthomas and Caren Thomas authored Dec 21, 2024
1 parent 803833e commit b9b77fd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions letta/client/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
yield ToolCallMessage(**chunk_data)
elif "tool_return" in chunk_data:
yield ToolReturnMessage(**chunk_data)
elif "usage" in chunk_data:
yield LettaUsageStatistics(**chunk_data["usage"])
elif "step_count" in chunk_data:
yield LettaUsageStatistics(**chunk_data)
else:
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

Expand Down
3 changes: 2 additions & 1 deletion letta/schemas/usage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
from pydantic import BaseModel, Field


Expand All @@ -11,7 +12,7 @@ class LettaUsageStatistics(BaseModel):
total_tokens (int): The total number of tokens processed by the agent.
step_count (int): The number of steps taken by the agent.
"""

message_type: Literal["usage_statistics"] = "usage_statistics"
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
Expand Down
2 changes: 1 addition & 1 deletion letta/server/rest_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def sse_async_generator(
# Double-check the type
if not isinstance(usage, LettaUsageStatistics):
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
yield sse_formatter({"usage": usage.model_dump()})
yield sse_formatter(usage.model_dump())

except ContextWindowExceededError as e:
log_error_to_sentry(e)
Expand Down

0 comments on commit b9b77fd

Please sign in to comment.