From 06e297ee57640f9f58e6303e0205329b760de9e9 Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Fri, 3 May 2024 15:52:17 +0200
Subject: [PATCH] feat(brain): Add ProxyBrain integration

---
 .../modules/brain/integrations/Proxy/Brain.py | 97 +++++++++++++++++++
 .../brain/integrations/Proxy/__init__.py      |  0
 .../chat/controller/chat/brainful_chat.py     |  2 +
 3 files changed, 99 insertions(+)
 create mode 100644 backend/modules/brain/integrations/Proxy/Brain.py
 create mode 100644 backend/modules/brain/integrations/Proxy/__init__.py

diff --git a/backend/modules/brain/integrations/Proxy/Brain.py b/backend/modules/brain/integrations/Proxy/Brain.py
new file mode 100644
index 000000000000..6e089579c35d
--- /dev/null
+++ b/backend/modules/brain/integrations/Proxy/Brain.py
@@ -0,0 +1,97 @@
+import json
+from typing import AsyncIterable
+from uuid import UUID
+
+from langchain_community.chat_models import ChatLiteLLM
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from logger import get_logger
+from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
+from modules.chat.dto.chats import ChatQuestion
+from modules.chat.dto.outputs import GetChatHistoryOutput
+from modules.chat.service.chat_service import ChatService
+
+logger = get_logger(__name__)
+
+chat_service = ChatService()
+
+
+class ProxyBrain(KnowledgeBrainQA):
+    """This is the Proxy brain class.
+
+    Args:
+        KnowledgeBrainQA (_type_): A brain that stores the knowledge internally
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+
+    def get_chain(self):
+
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                (
+                    "system",
+                    "You are Quivr. You are an assistant. {custom_personality}",
+                ),
+                MessagesPlaceholder(variable_name="chat_history"),
+                ("human", "{question}"),
+            ]
+        )
+
+        chain = prompt | ChatLiteLLM(model=self.model, max_tokens=self.max_tokens)
+
+        return chain
+
+    async def generate_stream(
+        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
+    ) -> AsyncIterable:
+        conversational_qa_chain = self.get_chain()
+        transformed_history, streamed_chat_history = (
+            self.initialize_streamed_chat_history(chat_id, question)
+        )
+        response_tokens = []
+
+        async for chunk in conversational_qa_chain.astream(
+            {
+                "question": question.question,
+                "chat_history": transformed_history,
+                "custom_personality": (
+                    self.prompt_to_use.content if self.prompt_to_use else None
+                ),
+            }
+        ):
+            response_tokens.append(chunk.content)
+            streamed_chat_history.assistant = chunk.content
+            yield f"data: {json.dumps(streamed_chat_history.dict())}"
+
+        self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
+
+    def generate_answer(
+        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
+    ) -> GetChatHistoryOutput:
+        conversational_qa_chain = self.get_chain()
+        transformed_history, streamed_chat_history = (
+            self.initialize_streamed_chat_history(chat_id, question)
+        )
+        model_response = conversational_qa_chain.invoke(
+            {
+                "question": question.question,
+                "chat_history": transformed_history,
+                "custom_personality": (
+                    self.prompt_to_use.content if self.prompt_to_use else None
+                ),
+            }
+        )
+
+        answer = model_response.content
+
+        return self.save_non_streaming_answer(
+            chat_id=chat_id,
+            question=question,
+            answer=answer,
+        )
diff --git a/backend/modules/brain/integrations/Proxy/__init__.py b/backend/modules/brain/integrations/Proxy/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/backend/modules/chat/controller/chat/brainful_chat.py b/backend/modules/chat/controller/chat/brainful_chat.py
index d35608ea73f4..e7982d7827af 100644
--- a/backend/modules/chat/controller/chat/brainful_chat.py
+++ b/backend/modules/chat/controller/chat/brainful_chat.py
@@ -4,6 +4,7 @@
 from modules.brain.integrations.Big.Brain import BigBrain
 from modules.brain.integrations.GPT4.Brain import GPT4Brain
 from modules.brain.integrations.Notion.Brain import NotionBrain
+from modules.brain.integrations.Proxy.Brain import ProxyBrain
 from modules.brain.integrations.SQL.Brain import SQLBrain
 from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
 from modules.brain.service.api_brain_definition_service import ApiBrainDefinitionService
@@ -41,6 +42,7 @@
     "sql": SQLBrain,
     "big": BigBrain,
     "doc": KnowledgeBrainQA,
+    "proxy": ProxyBrain,
 }

 brain_service = BrainService()
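
Note on the pattern: ProxyBrain follows the same shape as the other integration brains. get_chain() builds a prompt-plus-ChatLiteLLM runnable, generate_stream() iterates chain.astream() and yields SSE-style "data: ..." frames, and registering the class under the "proxy" key in brainful_chat.py is what lets the chat controller dispatch to it by brain type. The snippet below is a minimal standalone sketch of that streaming pattern outside the Quivr codebase; the default model name and the build_chain/stream_chat helpers are illustrative assumptions, not part of this patch.

# Minimal sketch of the prompt | ChatLiteLLM streaming pattern used by ProxyBrain.
# Assumes langchain-core and langchain-community are installed; the model name
# and helper names are illustrative, not taken from the Quivr codebase.
import asyncio
import json

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder


def build_chain(model: str = "gpt-3.5-turbo", max_tokens: int = 256):
    # Mirrors ProxyBrain.get_chain(): a system prompt with a personality slot,
    # prior chat history, and the current question, piped into ChatLiteLLM.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are Quivr. You are an assistant. {custom_personality}"),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )
    return prompt | ChatLiteLLM(model=model, max_tokens=max_tokens)


async def stream_chat(question: str) -> None:
    # Mirrors ProxyBrain.generate_stream(): iterate the async stream and emit
    # an SSE-style "data: ..." frame per chunk.
    chain = build_chain()
    async for chunk in chain.astream(
        {
            "question": question,
            "chat_history": [],
            "custom_personality": None,
        }
    ):
        print(f"data: {json.dumps({'assistant': chunk.content})}")


if __name__ == "__main__":
    asyncio.run(stream_chat("What does the ProxyBrain integration do?"))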