From e5cda396ef7f5c684f03311e81b43b888f15ac7e Mon Sep 17 00:00:00 2001
From: Alexis VIALARET
Date: Thu, 21 Dec 2023 18:18:33 +0100
Subject: [PATCH] upd: use the RAG object

---
 .gitignore                                  |  3 +-
 backend/config.yaml                         |  3 +-
 backend/config_renderer.py                  | 23 ---------
 backend/document_store.py                   | 48 -------------------
 backend/main.py                             | 37 +++-----------
 backend/model.py                            | 15 ------
 .../{chatbot.py => rag_components/chain.py} |  0
 backend/rag_components/config_renderer.py   | 14 ++++++
 backend/rag_components/document_loader.py   | 15 ------
 backend/rag_components/main.py              | 21 ++++++--
 10 files changed, 40 insertions(+), 139 deletions(-)
 delete mode 100644 backend/config_renderer.py
 delete mode 100644 backend/document_store.py
 rename backend/{chatbot.py => rag_components/chain.py} (100%)
 create mode 100644 backend/rag_components/config_renderer.py

diff --git a/.gitignore b/.gitignore
index 44d5196..8472256 100644
--- a/.gitignore
+++ b/.gitignore
@@ -139,4 +139,5 @@
 secrets/*
 data/
 *.sqlite
-*.sqlite3
\ No newline at end of file
+*.sqlite3
+vector_database/
\ No newline at end of file
diff --git a/backend/config.yaml b/backend/config.yaml
index 9ca450d..966e906 100644
--- a/backend/config.yaml
+++ b/backend/config.yaml
@@ -8,8 +8,6 @@ llm_model_config:
   model_source: AzureChatOpenAI
   deployment_name: gpt4v
   temperature: 0.1
-  streaming: true
-  verbose: true
 
 embedding_provider_config:
   openai_api_type: azure
@@ -24,6 +22,7 @@ embedding_model_config:
 vector_store_provider:
   model_source: Chroma
   persist_directory: vector_database/
+  documents_to_retreive: 10
 
 chat_message_history_config:
   source: ChatMessageHistory
diff --git a/backend/config_renderer.py b/backend/config_renderer.py
deleted file mode 100644
index 8a0e568..0000000
--- a/backend/config_renderer.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-from pathlib import Path
-
-import yaml
-from dotenv import load_dotenv
-from jinja2 import Environment, FileSystemLoader
-
-
-def get_config() -> dict:
-    load_dotenv()
-    env = Environment(loader=FileSystemLoader(Path(__file__).parent))
-    template = env.get_template("config.yaml")
-    config = template.render(os.environ)
-    return yaml.safe_load(config)
-
-
-def load_models_config():
-    with open(Path(__file__).parent / "config.yaml", "r") as file:
-        return yaml.safe_load(file)
-
-
-if __name__ == "__main__":
-    print(get_config())
diff --git a/backend/document_store.py b/backend/document_store.py
deleted file mode 100644
index 106ec8e..0000000
--- a/backend/document_store.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from enum import Enum
-from pathlib import Path
-from typing import List
-
-import chromadb
-from langchain.docstore.document import Document
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.vectorstores import Chroma
-
-
-class StorageBackend(Enum):
-    """Enumeration of supported storage backends."""
-
-    LOCAL = "local"
-    MEMORY = "memory"
-    GCS = "gcs"
-    S3 = "s3"
-    AZURE = "az"
-
-
-def get_storage_root_path(bucket_name: str, storage_backend: StorageBackend) -> Path:
-    """Constructs the root path for the storage based on the bucket name and storage backend."""
-    return Path(f"{storage_backend.value}://{bucket_name}")
-
-
-def persist_to_bucket(bucket_path: str, store: Chroma) -> None:
-    """Persists the data in the given Chroma store to a bucket."""
-    store.persist("./db/chroma")
-    # TODO: Uplaod persisted file on disk to bucket_path gcs
-
-
-def store_documents(
-    docs: List[Document], bucket_path: str, storage_backend: StorageBackend
-) -> None:
-    """Stores a list of documents in a specified bucket using a given storage backend."""
-    langchain_documents = [doc.to_langchain_document() for doc in docs]
-    embeddings_model = OpenAIEmbeddings()
-    persistent_client = chromadb.PersistentClient()
-    collection = persistent_client.get_or_create_collection(
-        get_storage_root_path(bucket_path, storage_backend)
-    )
-    collection.add(documents=langchain_documents)
-    langchain_chroma = Chroma(
-        client=persistent_client,
-        collection_name=bucket_path,
-        embedding_function=embeddings_model.embed_documents,
-    )
-    print("There are", langchain_chroma._collection.count(), "in the collection")
diff --git a/backend/main.py b/backend/main.py
index a33dfe0..67a95ed 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -8,10 +8,8 @@
 from jose import JWTError, jwt
-
-from backend.config_renderer import get_config
-from backend.document_store import StorageBackend
-from backend.model import Doc, Message
+from backend.model import Message
+from backend.rag_components.main import RAG
 from backend.user_management import (
     ALGORITHM,
     SECRET_KEY,
@@ -23,11 +21,6 @@
     user_exists,
 )
 from database.database import Database
-from backend.config_renderer import get_config
-from backend.rag_components.chat_message_history import get_conversation_buffer_memory
-from backend.rag_components.embedding import get_embedding_model
-from backend.rag_components.vector_store import get_vector_store
-from backend.chatbot import get_answer_chain, get_response_stream
 
 app = FastAPI()
 
@@ -48,10 +41,10 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
     )
     try:
         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
-        email: str = payload.get("email")  # 'sub' is commonly used to store user identity
+        email: str = payload.get("email")
         if email is None:
             raise credentials_exception
-        # Here you should fetch the user from the database by user_id
+
         user = get_user(email)
         if user is None:
             raise credentials_exception
@@ -164,15 +157,10 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
         (message.id, message.timestamp, message.chat_id, message.sender, message.content),
     )
 
-    config = get_config()
-    embeddings = get_embedding_model(config)
-    vector_store = get_vector_store(embeddings, config)
-    memory = get_conversation_buffer_memory(config, message.chat_id)
-    answer_chain, callback_handler = get_answer_chain(config, vector_store, memory)
-
-    response_stream = get_response_stream(answer_chain, callback_handler, message.content)
+    rag = RAG()
+    response = rag.generate_response(message)
 
-    return StreamingResponse(streamed_llm_response(message.chat_id, response_stream), media_type="text/event-stream")
+    return StreamingResponse(streamed_llm_response(message.chat_id, response), media_type="text/event-stream")
 
 
 @app.post("/chat/regenerate")
@@ -220,17 +208,6 @@ async def feedback_thumbs_down(
     )
 
 
-############################################
-### Other ###
-############################################
-
-
-@app.post("/index/documents")
-async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend) -> None:
-    """Index documents in a specified storage backend."""
-    document_store.store_documents(chunks, bucket, storage_backend)
-
-
 if __name__ == "__main__":
     import uvicorn
diff --git a/backend/model.py b/backend/model.py
index 5e24a4f..3890f75 100644
--- a/backend/model.py
+++ b/backend/model.py
@@ -1,7 +1,3 @@
-from datetime import datetime
-from uuid import uuid4
-
-from langchain.docstore.document import Document
 from pydantic import BaseModel
 
 
@@ -11,14 +7,3 @@ class Message(BaseModel):
     chat_id: str
     sender: str
     content: str
-
-
-class Doc(BaseModel):
-    """Represents a document with content and associated metadata."""
-
-    content: str
-    metadata: dict
-
-    def to_langchain_document(self) -> Document:
-        """Converts the current Doc instance into a langchain Document."""
-        return Document(page_content=self.content, metadata=self.metadata)
diff --git a/backend/chatbot.py b/backend/rag_components/chain.py
similarity index 100%
rename from backend/chatbot.py
rename to backend/rag_components/chain.py
diff --git a/backend/rag_components/config_renderer.py b/backend/rag_components/config_renderer.py
new file mode 100644
index 0000000..c29c483
--- /dev/null
+++ b/backend/rag_components/config_renderer.py
@@ -0,0 +1,14 @@
+import os
+from pathlib import Path
+
+import yaml
+from dotenv import load_dotenv
+from jinja2 import Environment, FileSystemLoader
+
+
+def get_config(config_file_path: Path) -> dict:
+    load_dotenv()
+    env = Environment(loader=FileSystemLoader(config_file_path))
+    template = env.get_template("config.yaml")
+    config = template.render(os.environ)
+    return yaml.safe_load(config)
diff --git a/backend/rag_components/document_loader.py b/backend/rag_components/document_loader.py
index da4ae23..16917cb 100644
--- a/backend/rag_components/document_loader.py
+++ b/backend/rag_components/document_loader.py
@@ -50,18 +50,3 @@ def get_loaders() -> List[str]:
         if inspect.isclass(obj):
             loaders.append(obj.__name__)
     return loaders
-
-
-if __name__ == "__main__":
-    from pathlib import Path
-
-    from backend.config_renderer import get_config
-    from frontend.lib.chat import Message
-
-    config = get_config()
-    data_to_store = Path(f"{Path(__file__).parent.parent.parent}/data/billionaires_csv.csv")
-    prompt = "Quelles sont les 5 plus grandes fortunes de France ?"
- chat_id = "test" - input_query = Message("user", prompt, chat_id) - response = generate_response(data_to_store, config, input_query) - print(response) diff --git a/backend/rag_components/main.py b/backend/rag_components/main.py index db93076..9a6736b 100644 --- a/backend/rag_components/main.py +++ b/backend/rag_components/main.py @@ -3,8 +3,11 @@ from langchain.docstore.document import Document from langchain.vectorstores.utils import filter_complex_metadata +from backend.rag_components.chain import get_answer_chain, get_response_stream -from backend.config_renderer import get_config +from backend.rag_components.config_renderer import get_config +from backend.model import Message +from backend.rag_components.chat_message_history import get_conversation_buffer_memory from backend.rag_components.document_loader import get_documents from backend.rag_components.embedding import get_embedding_model from backend.rag_components.llm import get_llm_model @@ -12,14 +15,22 @@ class RAG: - def __init__(self): - self.config = get_config() + def __init__(self, config_file_path: Path = None): + if config_file_path is None: + config_file_path = Path(__file__).parents[1] + + self.config = get_config(config_file_path) self.llm = get_llm_model(self.config) self.embeddings = get_embedding_model(self.config) self.vector_store = get_vector_store(self.embeddings, self.config) - def generate_response(): - pass + def generate_response(self, message: Message): + embeddings = get_embedding_model(self.config) + vector_store = get_vector_store(embeddings, self.config) + memory = get_conversation_buffer_memory(self.config, message.chat_id) + answer_chain, callback_handler = get_answer_chain(self.config, vector_store, memory) + response_stream = get_response_stream(answer_chain, callback_handler, message.content) + return response_stream def load_documents(self, documents: List[Document]): # TODO améliorer la robustesse du load_document