diff --git a/backend/main.py b/backend/main.py index 987203d..f2ddcaa 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,9 +12,11 @@ from jose import JWTError, jwt from langchain_core.messages.ai import AIMessage, AIMessageChunk +from backend.config import RagConfig from backend.database import Database from backend.logger import get_logger from backend.model import Message +from backend.rag_components.chat_message_history import get_conversation_buffer_memory from backend.rag_components.rag import RAG from backend.user_management import ( ALGORITHM, @@ -89,7 +91,13 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current } rag = RAG(config=Path(__file__).parent / "config.yaml", logger=logger, context=context) response = rag.generate_response(message) - return StreamingResponse(stream_response(message.chat_id, response), media_type="text/event-stream") + response_stream = stream_response( + rag=rag, + chat_id=message.chat_id, + question=message.content, + response=response + ) + return StreamingResponse(response_stream, media_type="text/event-stream") @app.post("/chat/regenerate") @@ -130,7 +138,7 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)) -> return {"chat_id": chat_id, "messages": [message.model_dump() for message in messages]} -async def stream_response(chat_id: str, response): +async def stream_response(rag: RAG, chat_id: str, question, response): full_response = "" response_id = str(uuid4()) try: @@ -154,6 +162,7 @@ async def stream_response(chat_id: str, response): yield full_response.encode("utf-8") finally: await log_response_to_db(chat_id, full_response) + await memorize_response(rag.config, chat_id, question, full_response) async def log_response_to_db(chat_id: str, full_response: str): response_id = str(uuid4()) @@ -163,6 +172,10 @@ async def log_response_to_db(chat_id: str, full_response: str): (response_id, datetime.now().isoformat(), chat_id, "assistant", full_response), ) +async def memorize_response(rag_config: RagConfig, chat_id: str, question: str, answer: str): + memory = get_conversation_buffer_memory(rag_config, chat_id) + memory.save_context({"question": question}, {"answer": answer}) + ############################################ ### Feedback ### diff --git a/docs/recipe_vector_stores_configs.md b/docs/recipe_vector_stores_configs.md index e58c321..bb79a31 100644 --- a/docs/recipe_vector_stores_configs.md +++ b/docs/recipe_vector_stores_configs.md @@ -5,7 +5,7 @@ As we need a backend SQL database to store conversation history and other info, [See the recipes for database configs here](recipe_databases_configs.md) ```shell -pip install pgvector +pip install psycopg2-binary pgvector ``` ```yaml