diff --git a/compose.yaml b/compose.yaml
index fdbe9c1..c26bfe9 100644
--- a/compose.yaml
+++ b/compose.yaml
@@ -9,13 +9,13 @@ services:
       - .env
     develop:
       watch:
-        - path: agent/
+        - path: src/agent/
           action: sync+restart
           target: /app/agent/
-        - path: schema/
+        - path: src/schema/
           action: sync+restart
           target: /app/schema/
-        - path: service/
+        - path: src/service/
           action: sync+restart
           target: /app/service/
 
@@ -31,12 +31,12 @@ services:
       - AGENT_URL=http://agent_service
     develop:
       watch:
-        - path: client/
+        - path: src/client/
           action: sync+restart
           target: /app/client/
-        - path: schema/
+        - path: src/schema/
           action: sync+restart
           target: /app/schema/
-        - path: streamlit_app.py
+        - path: src/streamlit_app.py
           action: sync+restart
           target: /app/streamlit_app.py
diff --git a/pyproject.toml b/pyproject.toml
index bcc9e8d..57a1734 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ classifiers = [
     "Programming Language :: Python :: 3.12",
 ]
 
-requires-python = ">=3.9, <=3.12.3"
+requires-python = ">=3.10, <=3.12.3"
 
 # NOTE: FastAPI < 0.100.0 and Pydantic v1 is required until langchain has full pydantic v2 compatibility
 # https://python.langchain.com/v0.1/docs/guides/development/pydantic_compatibility/
@@ -51,7 +51,7 @@ dev = [
 
 [tool.ruff]
 line-length = 100
-target-version = "py39"
+target-version = "py310"
 
 [tool.pytest_env]
 OPENAI_API_KEY = "sk-fake-openai-key"
diff --git a/src/agent/llama_guard.py b/src/agent/llama_guard.py
index c185a56..34c95e8 100644
--- a/src/agent/llama_guard.py
+++ b/src/agent/llama_guard.py
@@ -80,7 +80,9 @@ def __init__(self):
             print("GROQ_API_KEY not set, skipping LlamaGuard")
             self.model = None
             return
-        self.model = ChatGroq(model="llama-guard-3-8b", temperature=0.0)
+        self.model = ChatGroq(model="llama-guard-3-8b", temperature=0.0).with_config(
+            tags=["llama_guard"],
+        )
         self.prompt = PromptTemplate.from_template(llama_guard_instructions)
 
     def _compile_prompt(self, role: str, messages: List[AnyMessage]) -> str:
diff --git a/src/service/service.py b/src/service/service.py
index 74f7e11..610605d 100644
--- a/src/service/service.py
+++ b/src/service/service.py
@@ -1,12 +1,12 @@
-import asyncio
 from contextlib import asynccontextmanager
 import json
 import os
+import warnings
 from typing import AsyncGenerator, Dict, Any, Tuple
 from uuid import uuid4
 
 from fastapi import FastAPI, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse
-from langchain_core.callbacks import AsyncCallbackHandler
+from langchain_core._api import LangChainBetaWarning
 from langchain_core.runnables import RunnableConfig
 from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
 from langgraph.graph.graph import CompiledGraph
@@ -15,16 +15,7 @@
 from agent import research_assistant
 from schema import ChatMessage, Feedback, UserInput, StreamInput
 
-
-class TokenQueueStreamingHandler(AsyncCallbackHandler):
-    """LangChain callback handler for streaming LLM tokens to an asyncio queue."""
-
-    def __init__(self, queue: asyncio.Queue):
-        self.queue = queue
-
-    async def on_llm_new_token(self, token: str, **kwargs) -> None:
-        if token:
-            await self.queue.put(token)
+warnings.filterwarnings("ignore", category=LangChainBetaWarning)
 
 
 @asynccontextmanager
@@ -93,47 +84,46 @@ async def message_generator(user_input: StreamInput) -> AsyncGenerator[str, None
     agent: CompiledGraph = app.state.agent
     kwargs, run_id = _parse_input(user_input)
 
-    # Use an asyncio queue to process both messages and tokens in
-    # chronological order, so we can easily yield them to the client.
-    output_queue = asyncio.Queue(maxsize=10)
-    if user_input.stream_tokens:
-        kwargs["config"]["callbacks"] = [TokenQueueStreamingHandler(queue=output_queue)]
-
-    # Pass the agent's stream of messages to the queue in a separate task, so
-    # we can yield the messages to the client in the main thread.
-    async def run_agent_stream():
-        async for s in agent.astream(**kwargs, stream_mode="updates"):
-            await output_queue.put(s)
-        await output_queue.put(None)
-
-    stream_task = asyncio.create_task(run_agent_stream())
-
-    # Process the queue and yield messages over the SSE stream.
-    while s := await output_queue.get():
-        if isinstance(s, str):
-            # str is an LLM token
-            yield f"data: {json.dumps({'type': 'token', 'content': s})}\n\n"
+    # Process streamed events from the graph and yield messages over the SSE stream.
+    async for event in agent.astream_events(**kwargs, version="v2"):
+        if not event:
             continue
-        # Otherwise, s should be a dict of state updates for each node in the graph.
-        # s could have updates for multiple nodes, so check each for messages.
-        new_messages = []
-        for _, state in s.items():
-            if "messages" in state:
-                new_messages.extend(state["messages"])
-        for message in new_messages:
-            try:
-                chat_message = ChatMessage.from_langchain(message)
-                chat_message.run_id = str(run_id)
-            except Exception as e:
-                yield f"data: {json.dumps({'type': 'error', 'content': f'Error parsing message: {e}'})}\n\n"
-                continue
-            # LangGraph re-sends the input message, which feels weird, so drop it
+
+        # Yield messages written to the graph state after node execution finishes.
+        if (
+            event["event"] == "on_chain_end"
+            # on_chain_end gets called a bunch of times in a graph execution
+            # This filters out everything except for "graph node finished"
+            and any(t.startswith("graph:step:") for t in event.get("tags", []))
+            and "messages" in event["data"]["output"]
+        ):
+            new_messages = event["data"]["output"]["messages"]
+            for message in new_messages:
+                try:
+                    chat_message = ChatMessage.from_langchain(message)
+                    chat_message.run_id = str(run_id)
+                except Exception as e:
+                    yield f"data: {json.dumps({'type': 'error', 'content': f'Error parsing message: {e}'})}\n\n"
+                    continue
+                # LangGraph re-sends the input message, which feels weird, so drop it
                 if chat_message.type == "human" and chat_message.content == user_input.message:
                     continue
                 yield f"data: {json.dumps({'type': 'message', 'content': chat_message.dict()})}\n\n"
 
-    await stream_task
+        # Yield tokens streamed from LLMs.
+        if (
+            event["event"] == "on_chat_model_stream"
+            and user_input.stream_tokens
+            and "llama_guard" not in event.get("tags", [])
+        ):
+            content = event["data"]["chunk"].content
+            if content:
+                # Empty content in the context of OpenAI or Anthropic usually means
+                # that the model is asking for a tool to be invoked.
+                # So we only print non-empty content.
+                yield f"data: {json.dumps({'type': 'token', 'content': content})}\n\n"
+            continue
+
     yield "data: [DONE]\n\n"