Commit

Merge pull request #32 from peterkeppert/event-streaming
Switch to astream_events to support token-level streaming with Gemini
JoshuaC215 authored Sep 19, 2024
2 parents 0c5d006 + 21ca545 commit 24f86d2
Showing 4 changed files with 48 additions and 56 deletions.
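
Before the per-file diffs, here is a minimal, illustrative sketch of the astream_events(version="v2") pattern this commit moves the service to. The stream_demo helper and its inputs are hypothetical; the event filters mirror the ones used in src/service/service.py below.

from langgraph.graph.graph import CompiledGraph


async def stream_demo(agent: CompiledGraph, inputs: dict) -> None:
    # astream_events(version="v2") yields typed events from every runnable inside the
    # graph, including per-token chat model chunks, which is how the commit gets
    # token-level streaming with Gemini.
    async for event in agent.astream_events(inputs, version="v2"):
        kind = event["event"]
        if kind == "on_chat_model_stream":
            # Token chunk emitted by any chat model invoked inside the graph.
            print(event["data"]["chunk"].content, end="", flush=True)
        elif kind == "on_chain_end" and any(
            t.startswith("graph:step:") for t in event.get("tags", [])
        ):
            # A graph node finished; event["data"]["output"] holds its state update.
            print(f"\n[node {event['name']} finished]")
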
12 changes: 6 additions & 6 deletions compose.yaml
@@ -9,13 +9,13 @@ services:
- .env
develop:
watch:
- path: agent/
- path: src/agent/
action: sync+restart
target: /app/agent/
- path: schema/
- path: src/schema/
action: sync+restart
target: /app/schema/
- path: service/
- path: src/service/
action: sync+restart
target: /app/service/

@@ -31,12 +31,12 @@ services:
- AGENT_URL=http://agent_service
develop:
watch:
- path: client/
- path: src/client/
action: sync+restart
target: /app/client/
- path: schema/
- path: src/schema/
action: sync+restart
target: /app/schema/
- path: streamlit_app.py
- path: src/streamlit_app.py
action: sync+restart
target: /app/streamlit_app.py
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -13,7 +13,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]

requires-python = ">=3.9, <=3.12.3"
requires-python = ">=3.10, <=3.12.3"

# NOTE: FastAPI < 0.100.0 and Pydantic v1 is required until langchain has full pydantic v2 compatibility
# https://python.langchain.com/v0.1/docs/guides/development/pydantic_compatibility/
@@ -51,7 +51,7 @@ dev = [

[tool.ruff]
line-length = 100
target-version = "py39"
target-version = "py310"

[tool.pytest_env]
OPENAI_API_KEY = "sk-fake-openai-key"
4 changes: 3 additions & 1 deletion src/agent/llama_guard.py
@@ -80,7 +80,9 @@ def __init__(self):
print("GROQ_API_KEY not set, skipping LlamaGuard")
self.model = None
return
self.model = ChatGroq(model="llama-guard-3-8b", temperature=0.0)
self.model = ChatGroq(model="llama-guard-3-8b", temperature=0.0).with_config(
tags=["llama_guard"],
)
self.prompt = PromptTemplate.from_template(llama_guard_instructions)

def _compile_prompt(self, role: str, messages: List[AnyMessage]) -> str:
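
The tag added above matters because tags set with .with_config(tags=[...]) propagate into the events emitted by astream_events, which is what lets the service (next file) avoid streaming LlamaGuard's own tokens. A small sketch of that behavior follows; the model id mirrors the diff, everything else is illustrative.

from langchain_groq import ChatGroq

guard_model = ChatGroq(model="llama-guard-3-8b", temperature=0.0).with_config(
    tags=["llama_guard"],
)


async def stream_without_guard_tokens(prompt: str) -> None:
    async for event in guard_model.astream_events(prompt, version="v2"):
        if (
            event["event"] == "on_chat_model_stream"
            and "llama_guard" not in event.get("tags", [])
        ):
            # Chunks from this model always carry the "llama_guard" tag, so nothing is
            # printed here; this is exactly the filtering that service.py relies on.
            print(event["data"]["chunk"].content, end="")
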
84 changes: 37 additions & 47 deletions src/service/service.py
@@ -1,12 +1,12 @@
import asyncio
from contextlib import asynccontextmanager
import json
import os
import warnings
from typing import AsyncGenerator, Dict, Any, Tuple
from uuid import uuid4
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core._api import LangChainBetaWarning
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.graph.graph import CompiledGraph
@@ -15,16 +15,7 @@
from agent import research_assistant
from schema import ChatMessage, Feedback, UserInput, StreamInput


class TokenQueueStreamingHandler(AsyncCallbackHandler):
"""LangChain callback handler for streaming LLM tokens to an asyncio queue."""

def __init__(self, queue: asyncio.Queue):
self.queue = queue

async def on_llm_new_token(self, token: str, **kwargs) -> None:
if token:
await self.queue.put(token)
warnings.filterwarnings("ignore", category=LangChainBetaWarning)


@asynccontextmanager
@@ -93,47 +84,46 @@ async def message_generator(user_input: StreamInput) -> AsyncGenerator[str, None]:
agent: CompiledGraph = app.state.agent
kwargs, run_id = _parse_input(user_input)

# Use an asyncio queue to process both messages and tokens in
# chronological order, so we can easily yield them to the client.
output_queue = asyncio.Queue(maxsize=10)
if user_input.stream_tokens:
kwargs["config"]["callbacks"] = [TokenQueueStreamingHandler(queue=output_queue)]

# Pass the agent's stream of messages to the queue in a separate task, so
# we can yield the messages to the client in the main thread.
async def run_agent_stream():
async for s in agent.astream(**kwargs, stream_mode="updates"):
await output_queue.put(s)
await output_queue.put(None)

stream_task = asyncio.create_task(run_agent_stream())

# Process the queue and yield messages over the SSE stream.
while s := await output_queue.get():
if isinstance(s, str):
# str is an LLM token
yield f"data: {json.dumps({'type': 'token', 'content': s})}\n\n"
# Process streamed events from the graph and yield messages over the SSE stream.
async for event in agent.astream_events(**kwargs, version="v2"):
if not event:
continue

# Otherwise, s should be a dict of state updates for each node in the graph.
# s could have updates for multiple nodes, so check each for messages.
new_messages = []
for _, state in s.items():
if "messages" in state:
new_messages.extend(state["messages"])
for message in new_messages:
try:
chat_message = ChatMessage.from_langchain(message)
chat_message.run_id = str(run_id)
except Exception as e:
yield f"data: {json.dumps({'type': 'error', 'content': f'Error parsing message: {e}'})}\n\n"
continue
# LangGraph re-sends the input message, which feels weird, so drop it
# Yield messages written to the graph state after node execution finishes.
if (
event["event"] == "on_chain_end"
# on_chain_end gets called a bunch of times in a graph execution
# This filters out everything except for "graph node finished"
and any(t.startswith("graph:step:") for t in event.get("tags", []))
and "messages" in event["data"]["output"]
):
new_messages = event["data"]["output"]["messages"]
for message in new_messages:
try:
chat_message = ChatMessage.from_langchain(message)
chat_message.run_id = str(run_id)
except Exception as e:
yield f"data: {json.dumps({'type': 'error', 'content': f'Error parsing message: {e}'})}\n\n"
continue
# LangGraph re-sends the input message, which feels weird, so drop it
if chat_message.type == "human" and chat_message.content == user_input.message:
continue
yield f"data: {json.dumps({'type': 'message', 'content': chat_message.dict()})}\n\n"

await stream_task
# Yield tokens streamed from LLMs.
if (
event["event"] == "on_chat_model_stream"
and user_input.stream_tokens
and "llama_guard" not in event.get("tags", [])
):
content = event["data"]["chunk"].content
if content:
# Empty content in the context of OpenAI or Anthropic usually means
# that the model is asking for a tool to be invoked.
# So we only print non-empty content.
yield f"data: {json.dumps({'type': 'token', 'content': content})}\n\n"
continue

yield "data: [DONE]\n\n"


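
For completeness, here is a hypothetical consumer of the SSE stream that message_generator yields. It is not the repository's client; the /stream path and the payload fields are assumptions based on the StreamInput fields used above.

import json

import httpx


async def consume_stream(base_url: str, message: str) -> None:
    payload = {"message": message, "stream_tokens": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/stream", json=payload) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank keep-alive lines between SSE events
                data = line.removeprefix("data: ")
                if data == "[DONE]":
                    break
                parsed = json.loads(data)
                if parsed["type"] == "token":
                    print(parsed["content"], end="", flush=True)  # streamed LLM token
                elif parsed["type"] == "message":
                    print(parsed["content"])  # completed ChatMessage as a dict
                elif parsed["type"] == "error":
                    raise RuntimeError(parsed["content"])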
