Commit e5cda39

upd: use the RAG object

AlexisVLRT committed Dec 21, 2023
1 parent 4451414 commit e5cda39
Showing 10 changed files with 40 additions and 139 deletions.
.gitignore (3 changes: 2 additions & 1 deletion)

@@ -139,4 +139,5 @@ secrets/*
 data/
 
 *.sqlite
-*.sqlite3
+*.sqlite3
+vector_database/
backend/config.yaml (3 changes: 1 addition & 2 deletions)

@@ -8,8 +8,6 @@ llm_model_config:
   model_source: AzureChatOpenAI
   deployment_name: gpt4v
   temperature: 0.1
-  streaming: true
-  verbose: true
 
 embedding_provider_config:
   openai_api_type: azure
@@ -24,6 +22,7 @@
 vector_store_provider:
   model_source: Chroma
   persist_directory: vector_database/
+  documents_to_retreive: 10
 
 chat_message_history_config:
   source: ChatMessageHistory
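A hedged sketch of how the new documents_to_retreive value (spelling kept as in the repo) could feed retrieval. get_config, get_embedding_model, and get_vector_store are the repo's own helpers touched in this commit; the as_retriever wiring is an assumption, not code from this commit.

from pathlib import Path

from backend.rag_components.config_renderer import get_config
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.vector_store import get_vector_store

# Render the Jinja2-templated config and build the store (helpers from this repo).
config = get_config(Path("backend"))
embeddings = get_embedding_model(config)
vector_store = get_vector_store(embeddings, config)

# Assumption: a LangChain-style retriever capped at the configured document count.
retriever = vector_store.as_retriever(
    search_kwargs={"k": config["vector_store_provider"]["documents_to_retreive"]}
)
docs = retriever.get_relevant_documents("example query")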
backend/config_renderer.py (23 changes: 0 additions & 23 deletions)

This file was deleted.

backend/document_store.py (48 changes: 0 additions & 48 deletions)

This file was deleted.

backend/main.py (37 changes: 7 additions & 30 deletions)

@@ -8,10 +8,8 @@
 from jose import JWTError, jwt
 
 
-from backend.config_renderer import get_config
-from backend.document_store import StorageBackend
-from backend.model import Doc, Message
+from backend.model import Message
+from backend.rag_components.main import RAG
 from backend.user_management import (
     ALGORITHM,
     SECRET_KEY,
@@ -23,11 +21,6 @@
     user_exists,
 )
 from database.database import Database
-from backend.config_renderer import get_config
-from backend.rag_components.chat_message_history import get_conversation_buffer_memory
-from backend.rag_components.embedding import get_embedding_model
-from backend.rag_components.vector_store import get_vector_store
-from backend.chatbot import get_answer_chain, get_response_stream
 
 app = FastAPI()
 
@@ -48,10 +41,10 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
     )
     try:
         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
-        email: str = payload.get("email")  # 'sub' is commonly used to store user identity
+        email: str = payload.get("email")
         if email is None:
             raise credentials_exception
-        # Here you should fetch the user from the database by user_id
+
         user = get_user(email)
         if user is None:
             raise credentials_exception
@@ -164,15 +157,10 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
         (message.id, message.timestamp, message.chat_id, message.sender, message.content),
     )
 
-    config = get_config()
-    embeddings = get_embedding_model(config)
-    vector_store = get_vector_store(embeddings, config)
-    memory = get_conversation_buffer_memory(config, message.chat_id)
-    answer_chain, callback_handler = get_answer_chain(config, vector_store, memory)
-
-    response_stream = get_response_stream(answer_chain, callback_handler, message.content)
+    rag = RAG()
+    response = rag.generate_response(message)
 
-    return StreamingResponse(streamed_llm_response(message.chat_id, response_stream), media_type="text/event-stream")
+    return StreamingResponse(streamed_llm_response(message.chat_id, response), media_type="text/event-stream")
 
 
 @app.post("/chat/regenerate")
@@ -220,17 +208,6 @@ async def feedback_thumbs_down(
     )
 
 
-############################################
-### Other ###
-############################################
-
-
-@app.post("/index/documents")
-async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend) -> None:
-    """Index documents in a specified storage backend."""
-    document_store.store_documents(chunks, bucket, storage_backend)
-
-
 if __name__ == "__main__":
     import uvicorn
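For context, a hypothetical client for the streaming endpoint above. The /chat route path and the JSON field names are assumptions (the route decorator sits above the shown hunk); only the text/event-stream media type comes from the diff.

import requests

# Hypothetical URL and payload; align with the actual route decorator and Message model.
with requests.post(
    "http://localhost:8000/chat",
    json={
        "id": "msg-1",
        "timestamp": "2023-12-21T12:00:00",
        "chat_id": "chat-1",
        "sender": "user",
        "content": "Hello",
    },
    headers={"Authorization": "Bearer <token>"},
    stream=True,
) as response:
    # Print server-sent chunks as they arrive instead of waiting for the full answer.
    for line in response.iter_lines(decode_unicode=True):
        if line:
            print(line)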
backend/model.py (15 changes: 0 additions & 15 deletions)

@@ -1,7 +1,3 @@
-from datetime import datetime
-from uuid import uuid4
-
-from langchain.docstore.document import Document
 from pydantic import BaseModel
 
 
@@ -11,14 +7,3 @@ class Message(BaseModel):
     chat_id: str
     sender: str
     content: str
-
-
-class Doc(BaseModel):
-    """Represents a document with content and associated metadata."""
-
-    content: str
-    metadata: dict
-
-    def to_langchain_document(self) -> Document:
-        """Converts the current Doc instance into a langchain Document."""
-        return Document(page_content=self.content, metadata=self.metadata)
File renamed without changes.
backend/rag_components/config_renderer.py (14 changes: 14 additions & 0 deletions)

@@ -0,0 +1,14 @@
+import os
+from pathlib import Path
+
+import yaml
+from dotenv import load_dotenv
+from jinja2 import Environment, FileSystemLoader
+
+
+def get_config(config_file_path: Path) -> dict:
+    load_dotenv()
+    env = Environment(loader=FileSystemLoader(config_file_path))
+    template = env.get_template("config.yaml")
+    config = template.render(os.environ)
+    return yaml.safe_load(config)
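Not part of the diff: a minimal usage sketch of the new get_config. It assumes the argument is the directory holding a Jinja2-templated config.yaml (as RAG.__init__ below passes backend/), with placeholders resolved from the environment or a .env file loaded by python-dotenv.

from pathlib import Path

from backend.rag_components.config_renderer import get_config

# Assumes ./backend contains config.yaml with placeholders like {{ OPENAI_API_KEY }}.
config = get_config(Path("backend"))
print(config["vector_store_provider"]["model_source"])  # "Chroma" in this repo's config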
backend/rag_components/document_loader.py (15 changes: 0 additions & 15 deletions)

@@ -50,18 +50,3 @@ def get_loaders() -> List[str]:
         if inspect.isclass(obj):
             loaders.append(obj.__name__)
     return loaders
-
-
-if __name__ == "__main__":
-    from pathlib import Path
-
-    from backend.config_renderer import get_config
-    from frontend.lib.chat import Message
-
-    config = get_config()
-    data_to_store = Path(f"{Path(__file__).parent.parent.parent}/data/billionaires_csv.csv")
-    prompt = "Quelles sont les 5 plus grandes fortunes de France ?"
-    chat_id = "test"
-    input_query = Message("user", prompt, chat_id)
-    response = generate_response(data_to_store, config, input_query)
-    print(response)
backend/rag_components/main.py (21 changes: 16 additions & 5 deletions)

@@ -3,23 +3,34 @@
 
 from langchain.docstore.document import Document
 from langchain.vectorstores.utils import filter_complex_metadata
+from backend.rag_components.chain import get_answer_chain, get_response_stream
 
-from backend.config_renderer import get_config
+from backend.rag_components.config_renderer import get_config
+from backend.model import Message
+from backend.rag_components.chat_message_history import get_conversation_buffer_memory
 from backend.rag_components.document_loader import get_documents
 from backend.rag_components.embedding import get_embedding_model
 from backend.rag_components.llm import get_llm_model
 from backend.rag_components.vector_store import get_vector_store
 
 
 class RAG:
-    def __init__(self):
-        self.config = get_config()
+    def __init__(self, config_file_path: Path = None):
+        if config_file_path is None:
+            config_file_path = Path(__file__).parents[1]
+
+        self.config = get_config(config_file_path)
         self.llm = get_llm_model(self.config)
         self.embeddings = get_embedding_model(self.config)
         self.vector_store = get_vector_store(self.embeddings, self.config)
 
-    def generate_response():
-        pass
+    def generate_response(self, message: Message):
+        embeddings = get_embedding_model(self.config)
+        vector_store = get_vector_store(embeddings, self.config)
+        memory = get_conversation_buffer_memory(self.config, message.chat_id)
+        answer_chain, callback_handler = get_answer_chain(self.config, vector_store, memory)
+        response_stream = get_response_stream(answer_chain, callback_handler, message.content)
+        return response_stream
 
     def load_documents(self, documents: List[Document]):
         # TODO: improve the robustness of load_document
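For reference, a hedged sketch of driving the new RAG object the way the chat_prompt change in backend/main.py now does. The Message field values are made-up examples (id and timestamp follow their use in main.py), and it assumes the stream returned by get_response_stream yields text chunks.

from backend.model import Message
from backend.rag_components.main import RAG

rag = RAG()  # with no argument, the config is loaded from backend/config.yaml

# Example values only; field types are assumptions based on usage in main.py.
message = Message(
    id="msg-1",
    timestamp="2023-12-21T12:00:00",
    chat_id="chat-1",
    sender="user",
    content="What are the five largest fortunes in France?",
)

# Assumption: generate_response returns an iterable stream of text chunks.
for chunk in rag.generate_response(message):
    print(chunk, end="", flush=True)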
