Skip to content

Commit

Permalink
add message type literal to usage stats
Browse files Browse the repository at this point in the history
  • Loading branch information
Caren Thomas committed Dec 20, 2024
1 parent 108a460 commit db96c9d
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions letta/client/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
yield ToolCallMessage(**chunk_data)
elif "tool_return" in chunk_data:
yield ToolReturnMessage(**chunk_data)
elif "usage" in chunk_data:
yield LettaUsageStatistics(**chunk_data["usage"])
elif "step_count" in chunk_data:
yield LettaUsageStatistics(**chunk_data)
else:
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

Expand Down
3 changes: 2 additions & 1 deletion letta/schemas/usage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
from pydantic import BaseModel, Field


Expand All @@ -11,7 +12,7 @@ class LettaUsageStatistics(BaseModel):
total_tokens (int): The total number of tokens processed by the agent.
step_count (int): The number of steps taken by the agent.
"""

message_type: Literal["usage_statistics"] = "usage_statistics"
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
Expand Down
2 changes: 1 addition & 1 deletion letta/server/rest_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def sse_async_generator(
# Double-check the type
if not isinstance(usage, LettaUsageStatistics):
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
yield sse_formatter({"usage": usage.model_dump()})
yield sse_formatter(usage.model_dump())

except ContextWindowExceededError as e:
log_error_to_sentry(e)
Expand Down

0 comments on commit db96c9d

Please sign in to comment.