From edf17491200b0f667e3024323ca2d4547df1d41c Mon Sep 17 00:00:00 2001
From: Nino Richter
Date: Tue, 10 Dec 2024 09:49:37 +0100
Subject: [PATCH 1/7] First commit for the new injestion logic

---
 app/api/admin_router.py                   |  48 ++++
 app/api/auth_router.py                    |  19 ++
 app/api/knowledge_router.py               | 106 +++++++++
 app/api/question_router.py                |  56 +----
 app/data/database_requests.py             |  19 ++
 app/data/knowledge_base_requests.py       |  41 ++++
 app/injestion/__init__.py                 |   0
 app/injestion/document_loader.py          |   5 +
 app/injestion/document_splitter.py        | 134 ++++++-----
 app/injestion/injestion_handler.py        | 128 +++++++++++
 app/injestion/vector_store_initializer.py |   9 +-
 app/main.py                               |   8 +-
 app/managers/weaviate_manager.py          | 256 +++++++++++++++------
 app/utils/dependencies.py                 |   4 +
 14 files changed, 640 insertions(+), 193 deletions(-)
 create mode 100644 app/api/admin_router.py
 create mode 100644 app/api/auth_router.py
 create mode 100644 app/api/knowledge_router.py
 create mode 100644 app/data/database_requests.py
 create mode 100644 app/data/knowledge_base_requests.py
 create mode 100644 app/injestion/__init__.py
 create mode 100644 app/injestion/injestion_handler.py

diff --git a/app/api/admin_router.py b/app/api/admin_router.py
new file mode 100644
index 0000000..09e815a
--- /dev/null
+++ b/app/api/admin_router.py
@@ -0,0 +1,48 @@
+import logging
+
+from fastapi import HTTPException, APIRouter, status, Response, Depends
+
+from app.utils.dependencies import request_handler, auth_handler, weaviate_manager, model
+from app.injestion.vector_store_initializer import initialize_vectorstores
+from app.utils.environment import config
+from app.data.user_requests import UserChat, UserRequest
+
+
+admin_router = APIRouter(prefix="/api/admin", tags=["settings", "admin"],
+                         dependencies=[Depends(auth_handler.verify_token)])
+
+
+# TODO: Remove
+@admin_router.get("/initSchema",
+                  status_code=status.HTTP_202_ACCEPTED, )
+async def initializeDb():
+    initialize_vectorstores(config.KNOWLEDGE_BASE_FOLDER, config.QA_FOLDER, weaviate_manager)
+    return
+
+
+# TODO: Remove
+@admin_router.post("/document")
+async def add_document(request: UserRequest):
+    question = request.message
+    classification = request.study_program
+    if not question or not classification:
+        raise HTTPException(status_code=400, detail="No question or classification provided")
+
+    logging.info(f"Received document: {question} with classification: {classification}")
+    try:
+        request_handler.add_document(question, classification)
+        return Response(status_code=status.HTTP_200_OK)
+
+    except Exception as e:
+        logging.error(f"Failed to add document: {e}")
+        raise HTTPException(status_code=500, detail="Failed to add document")
+
+
+@admin_router.get("/ping")
+async def ping():
+    logging.info(config.GPU_URL)
+    return {"answer": "Server running."}
+
+
+@admin_router.get("/hi")
+async def hi():
+    logging.info("hi")
+    return model.complete([{"role": "user", "content": "Hi"}])
\ No newline at end of file
diff --git a/app/api/auth_router.py b/app/api/auth_router.py
new file mode 100644
index 0000000..3e9268a
--- /dev/null
+++ b/app/api/auth_router.py
@@ -0,0 +1,19 @@
+import logging
+
+from fastapi import HTTPException, APIRouter, status, Response, Depends
+
+from app.managers.auth_handler import LoginRequest
+from app.utils.dependencies import auth_handler
+from app.utils.environment import config
+
+auth_router = APIRouter(prefix="/api", tags=["authorization"], dependencies=[Depends(auth_handler.verify_api_key)])
+
+
+@auth_router.post("/token")
+async def login(login_request: LoginRequest):
+    if config.WITHOUT_USER_LOGIN == "true" or (
+            login_request.username == config.EXPECTED_USERNAME and login_request.password == config.EXPECTED_PASSWORD):
+        token_data = {"sub": "angular_app"}
+        access_token = auth_handler.create_access_token(data=token_data)
+        return {"access_token": access_token, "token_type": "bearer"}
+    else:
+        raise HTTPException(status_code=401, detail="Invalid username or password")
\ No newline at end of file
diff --git a/app/api/knowledge_router.py b/app/api/knowledge_router.py
new file mode 100644
index 0000000..19cb019
--- /dev/null
+++ b/app/api/knowledge_router.py
@@ -0,0 +1,106 @@
+from fastapi import HTTPException, APIRouter, status, Response, Depends
+from app.data.knowledge_base_requests import AddWebsiteRequest, EditDocumentRequest, EditSampleQuestionRequest, EditWebsiteRequest, AddDocumentRequest, AddSampleQuestionRequest, RefreshContentRequest
+from app.utils.dependencies import injestion_handler
+from app.data.database_requests import DatabaseDocumentMetadata
+
+knowledge_router = APIRouter(prefix="/knowledge", tags=["knowledge"])
+
+
+@knowledge_router.post("/website/add")
+async def add_website(body: AddWebsiteRequest):
+    try:
+        injestion_handler.add_website(body)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.post("/website/{id}/refresh")
+async def refresh_website(id: int, body: RefreshContentRequest):
+    try:
+        injestion_handler.refresh_content(id=id, content=body.content)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.post("/website/{id}/update")
+async def update_website(id: int, body: EditWebsiteRequest):
+    try:
+        metadata: DatabaseDocumentMetadata = DatabaseDocumentMetadata(
+            study_programs=body.studyPrograms
+        )
+        injestion_handler.update_database_document(id=id, metadata=metadata)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.delete("/website/{id}/delete")
+async def delete_website(id: int):
+    try:
+        injestion_handler.delete_document(id=id)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+# === Document Endpoints ===
+
+@knowledge_router.post("/document/add")
+async def add_document(body: AddDocumentRequest):
+    try:
+        injestion_handler.add_document(body)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.post("/document/{id}/refresh")
+async def refresh_document(id: int, body: RefreshContentRequest):
+    try:
+        injestion_handler.refresh_content(id=id, content=body.content)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.post("/document/{id}/edit")
+async def edit_document(id: int, body: EditDocumentRequest):
+    try:
+        metadata: DatabaseDocumentMetadata = DatabaseDocumentMetadata(
+            study_programs=body.studyPrograms
+        )
+        injestion_handler.update_database_document(id=id, metadata=metadata)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.delete("/document/{id}/delete")
+async def delete_document(id: int):
+    try:
+        injestion_handler.delete_document(id=id)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+# === Sample Question Endpoints ===
+
+@knowledge_router.post("/sample-question/add")
+async def add_sample_question(body: AddSampleQuestionRequest):
+    try:
+        injestion_handler.add_sample_question(sample_question=body)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.post("/sample-question/{id}/edit")
+async def edit_sample_question(id: int, body: EditSampleQuestionRequest):
+    try:
+        injestion_handler.update_sample_question(kb_id=id, sample_question=body)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
+
+
+@knowledge_router.delete("/sample-question/{id}/delete")
+async def delete_sample_question(id: int):
+    try:
+        injestion_handler.delete_sample_question(id=id)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
\ No newline at end of file
diff --git a/app/api/question_router.py b/app/api/question_router.py
index 8adcb76..2b8d28d 100644
--- a/app/api/question_router.py
+++ b/app/api/question_router.py
@@ -1,28 +1,12 @@
 import logging
 
-from fastapi import HTTPException, APIRouter, status, Response, Depends
+from fastapi import HTTPException, APIRouter, Depends
 
 from app.data.user_requests import UserChat, UserRequest
-from app.injestion.vector_store_initializer import initialize_vectorstores
-from app.managers.auth_handler import LoginRequest
-from app.utils.dependencies import request_handler, auth_handler, weaviate_manager, model
+from app.utils.dependencies import request_handler, auth_handler
 from app.utils.environment import config
 
-auth_router = APIRouter(prefix="/api", tags=["authorization"], dependencies=[Depends(auth_handler.verify_api_key)])
 question_router = APIRouter(prefix="/api/v1/question", tags=["response"])
-admin_router = APIRouter(prefix="/api/admin", tags=["settings", "admin"],
-                         dependencies=[Depends(auth_handler.verify_token)])
-
-
-@auth_router.post("/token")
-async def login(login_request: LoginRequest):
-    if config.WITHOUT_USER_LOGIN == "true" or (
-            login_request.username == config.EXPECTED_USERNAME and login_request.password == config.EXPECTED_PASSWORD):
-        token_data = {"sub": "angular_app"}
-        access_token = auth_handler.create_access_token(data=token_data)
-        return {"access_token": access_token, "token_type": "bearer"}
-    else:
-        raise HTTPException(status_code=401, detail="Invalid username or password")
 
 
 @question_router.post("/ask", tags=["email"], dependencies=[Depends(auth_handler.verify_api_key)])
@@ -78,39 +62,3 @@ async def chat(request: UserChat):
     logging.info(f"Received messages.")
     answer = request_handler.handle_chat(messages, study_program=request.study_program)
     return {"answer": answer}
-
-
-@admin_router.get("/initSchema",
-                  status_code=status.HTTP_202_ACCEPTED, )
-async def initializeDb():
-    initialize_vectorstores(config.KNOWLEDGE_BASE_FOLDER, config.QA_FOLDER, weaviate_manager)
-    return
-
-
-@admin_router.post("/document")
-async def add_document(request: UserRequest):
-    question = request.message
-    classification = request.study_program
-    if not question or not classification:
-        raise HTTPException(status_code=400, detail="No question or classification provided")
-
-    logging.info(f"Received document: {question} with classification: {classification}")
-    try:
-        request_handler.add_document(question, classification)
-        return Response(status_code=status.HTTP_200_OK)
-
-    except Exception as e:
-        logging.error(f"Failed to add document: {e}")
-        raise HTTPException(status_code=500, detail="Failed to add document")
-
-
-@admin_router.get("/ping")
-async def ping():
-    logging.info(config.GPU_URL)
-    return {"answer": "Server running."}
-
-
-@admin_router.get("/hi")
-async def ping():
-    logging.info("hi")
-    return model.complete([{"role": "user", "content": "Hi"}])
\ No newline at end of file
diff --git a/app/data/database_requests.py b/app/data/database_requests.py
new file mode 100644
index 0000000..b2d6da3
--- /dev/null
+++ b/app/data/database_requests.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+from typing import List, Optional
+
+class DatabaseDocument(BaseModel):
+    id: int
+    link: Optional[str] = None
+    study_programs: List[str]
+    content: str
+
+class DatabaseDocumentMetadata(BaseModel):
+    link: Optional[str] = None
+    study_programs: List[str]
+
+class DatabaseSampleQuestion(BaseModel):
+    id: int
+    topic: str
+    question: str
+    answer: str
+    study_programs: List[str]
diff --git a/app/data/knowledge_base_requests.py b/app/data/knowledge_base_requests.py
new file mode 100644
index 0000000..d682975
--- /dev/null
+++ b/app/data/knowledge_base_requests.py
@@ -0,0 +1,41 @@
+from pydantic import BaseModel
+from typing import List
+
+
+class AddWebsiteRequest(BaseModel):
+    id: int
+    title: str
+    link: str
+    studyPrograms: List[str]
+    content: str
+    type: str
+
+class RefreshContentRequest(BaseModel):
+    content: str
+
+class EditWebsiteRequest(BaseModel):
+    title: str
+    studyPrograms: List[int]
+
+class AddDocumentRequest(BaseModel):
+    id: int
+    title: str
+    studyPrograms: List[int]
+    content: str
+
+class EditDocumentRequest(BaseModel):
+    title: str
+    studyPrograms: List[int]
+
+class AddSampleQuestionRequest(BaseModel):
+    id: int
+    question: str
+    answer: str
+    topic: str
+    studyPrograms: List[int]
+
+class EditSampleQuestionRequest(BaseModel):
+    question: str
+    answer: str
+    topic: str
+    studyPrograms: List[int]
\ No newline at end of file
diff --git a/app/injestion/__init__.py b/app/injestion/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/injestion/document_loader.py b/app/injestion/document_loader.py
index 5f8d236..d086cc0 100644
--- a/app/injestion/document_loader.py
+++ b/app/injestion/document_loader.py
@@ -9,6 +9,7 @@
 from app.data.user_requests import SampleQuestion, WebsiteContent
 
 
+# TODO: Remove
 def load_pdf_documents_from_folder(base_folder: str, study_program: str = "general") -> List[Document]:
     """
     Traverse the base folder and all its subfolders to find and load PDF files into Document objects.
@@ -52,6 +53,8 @@ def load_pdf_documents_from_folder(base_folder: str, study_program: str = "gener
     return documents
 
 
+
+# TODO: Remove
 def load_qa_pairs_from_folder(qa_folder: str) -> List[SampleQuestion]:
     """
     Reads JSON files from the qa_folder and extracts QA pairs.
@@ -107,6 +110,8 @@ def load_qa_pairs_from_folder(qa_folder: str) -> List[SampleQuestion]:
     logging.info(f"Loaded {len(qa_pairs)} QA pairs from folder: {qa_folder}")
     return qa_pairs
 
+
+# TODO: Remove
 def load_website_content_from_folder(base_folder: str) -> List[WebsiteContent]:
     """
     Traverse through the base folder and all subfolders to find and load JSON files
diff --git a/app/injestion/document_splitter.py b/app/injestion/document_splitter.py
index b04c1f7..5389a7c 100644
--- a/app/injestion/document_splitter.py
+++ b/app/injestion/document_splitter.py
@@ -7,68 +7,94 @@
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from app.data.user_requests import WebsiteContent
 
-def split_tum_documents(tum_documents: List[WebsiteContent], chunk_size: int = 1200, chunk_overlap: int = 200) -> List[WebsiteContent]:
-    """
-    Split TUM WebsiteContent documents using RecursiveCharacterTextSplitter into smaller chunks.
-    """
-    tum_chunks = []
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+class DocumentSplitter:
+    def split_cit_content(self, content: str):
+        result: List[str] = []
+        sections = content.split('----------------------------------------')
+        for section in sections:
+            section = section.strip()
+            result.append(section)
+        return result
 
-    for doc in tum_documents:
-        content = doc.content
+    def split_tum_content(self, content: str, chunk_size: int = 1200, chunk_overlap: int = 200):
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
         chunks = text_splitter.split_text(content)
-
-        # Create WebsiteContent chunks and preserve metadata
-        for chunk in chunks:
-            tum_chunks.append(WebsiteContent(
-                type=doc.type,
-                content=chunk,
-                link=doc.link,
-                study_program=doc.study_program
-            ))
-        logging.info(f"Split TUM document into {len(chunks)} chunks.")
-
-    logging.info(f"Total TUM chunks: {len(tum_chunks)}")
-    return tum_chunks
-
-def split_pdf_documents(pdf_documents: List[Document], chunk_size: int = 1200, chunk_overlap: int = 200) -> List[Document]:
-    """
-    Split PDF Document objects using RecursiveCharacterTextSplitter into smaller chunks.
-    """
-    pdf_chunks = []
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        return chunks
 
-    for doc in pdf_documents:
-        content = doc.page_content
+    def split_pdf_document(self, content: str, chunk_size: int = 1200, chunk_overlap: int = 200):
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
         chunks = text_splitter.split_text(content)
-
-        # Create Document chunks and preserve metadata
-        for chunk in chunks:
-            pdf_chunks.append(Document(page_content=chunk, metadata=doc.metadata))
-        logging.info(f"Split PDF document into {len(chunks)} chunks.")
+        return chunks
 
-    logging.info(f"Total PDF chunks: {len(pdf_chunks)}")
-    return pdf_chunks
-
-def split_cit_documents(cit_documents: List[WebsiteContent]) -> List[WebsiteContent]:
-    """
-    Split CIT WebsiteContent documents into smaller chunks based on a predefined separator.
- """ - cit_chunks = [] - for doc in cit_documents: - sections = doc.content.split('----------------------------------------') - for section in sections: - section = section.strip() - if section: - # Create smaller WebsiteContent chunks and preserve metadata - cit_chunks.append(WebsiteContent( + # TODO: Remove + @staticmethod + def split_pdf_documents(pdf_documents: List[Document], chunk_size: int = 1200, chunk_overlap: int = 200) -> List[Document]: + """ + Split PDF Document objects using RecursiveCharacterTextSplitter into smaller chunks. + """ + pdf_chunks = [] + text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + + for doc in pdf_documents: + content = doc.page_content + chunks = text_splitter.split_text(content) + + # Create Document chunks and preserve metadata + for chunk in chunks: + pdf_chunks.append(Document(page_content=chunk, metadata=doc.metadata)) + logging.info(f"Split PDF document into {len(chunks)} chunks.") + + logging.info(f"Total PDF chunks: {len(pdf_chunks)}") + return pdf_chunks + + # TODO: Remove + @staticmethod + def split_tum_documents(tum_documents: List[WebsiteContent], chunk_size: int = 1200, chunk_overlap: int = 200) -> List[WebsiteContent]: + """ + Split TUM WebsiteContent documents using RecursiveCharacterTextSplitter into smaller chunks. + """ + tum_chunks = [] + text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + + for doc in tum_documents: + content = doc.content + chunks = text_splitter.split_text(content) + + # Create WebsiteContent chunks and preserve metadata + for chunk in chunks: + tum_chunks.append(WebsiteContent( type=doc.type, - content=section, + content=chunk, link=doc.link, study_program=doc.study_program )) - logging.info(f"Split CIT document into {len(sections)} sections.") + logging.info(f"Split TUM document into {len(chunks)} chunks.") + + logging.info(f"Total TUM chunks: {len(tum_chunks)}") + return tum_chunks + + # TODO: Remove + @staticmethod + def split_cit_documents(cit_documents: List[WebsiteContent]) -> List[WebsiteContent]: + """ + Split CIT WebsiteContent documents into smaller chunks based on a predefined separator. 
+ """ + cit_chunks = [] + + for doc in cit_documents: + sections = doc.content.split('----------------------------------------') + for section in sections: + section = section.strip() + if section: + # Create smaller WebsiteContent chunks and preserve metadata + cit_chunks.append(WebsiteContent( + type=doc.type, + content=section, + link=doc.link, + study_program=doc.study_program + )) + logging.info(f"Split CIT document into {len(sections)} sections.") - logging.info(f"Total TUM chunks: {len(cit_chunks)}") - return cit_chunks \ No newline at end of file + logging.info(f"Total TUM chunks: {len(cit_chunks)}") + return cit_chunks \ No newline at end of file diff --git a/app/injestion/injestion_handler.py b/app/injestion/injestion_handler.py new file mode 100644 index 0000000..bbba8cc --- /dev/null +++ b/app/injestion/injestion_handler.py @@ -0,0 +1,128 @@ +from typing import List, Optional + +from app.managers.weaviate_manager import WeaviateManager +from app.injestion.document_splitter import DocumentSplitter +from app.data.knowledge_base_requests import AddWebsiteRequest, EditDocumentRequest, EditSampleQuestionRequest, EditWebsiteRequest, AddDocumentRequest, AddSampleQuestionRequest, RefreshContentRequest +from app.data.database_requests import DatabaseDocument, DatabaseDocumentMetadata, DatabaseSampleQuestion + +class InjestionHandler: + def __init__(self, weaviate_manager: WeaviateManager, document_splitter: DocumentSplitter): + self.weaviate_manager = weaviate_manager + self.document_splitter = document_splitter + + def add_website(self, website: AddWebsiteRequest): + website_docs: List[DatabaseDocument] = [] + if website.type == "CIT": + chunks = self.document_splitter.split_cit_content(website.content) + for chunk in chunks: + website_docs.append( + DatabaseDocument( + id=website.id, + content=chunk, + link=website.link, + study_programs=self.prepare_study_programs(website.studyPrograms) + ) + ) + else: + chunks = self.document_splitter.split_tum_content(website.content) + for chunk in chunks: + website_docs.append( + DatabaseDocument( + id=website.id, + content=chunk, + link=website.link, + study_programs=self.prepare_study_programs(website.studyPrograms) + ) + ) + self.weaviate_manager.add_documents(website_docs) + + def add_document(self, document: AddDocumentRequest): + website_docs: List[DatabaseDocument] = [] + chunks = self.document_splitter.split_pdf_document(document.content) + for chunk in chunks: + website_docs.append( + DatabaseDocument( + id=document.id, + content=chunk, + study_programs=self.prepare_study_programs(document.studyPrograms) + ) + ) + self.weaviate_manager.add_documents(website_docs) + + def update_database_document(self, id: int, metadata: DatabaseDocumentMetadata): + self.weaviate_manager.update_documents(id, self.prepare_study_programs(metadata.study_programs)) + + def refresh_content(self, id: int, content: str): + metadata: Optional[DatabaseDocumentMetadata] = self.weaviate_manager.delete_by_kb_id(kb_id=id, return_metadata=True) + if metadata is not None: + website_docs: List[DatabaseDocument] = [] + if metadata.link is None: + chunks = self.document_splitter.split_pdf_document(content) + for chunk in chunks: + website_docs.append( + DatabaseDocument( + id=id, + content=chunk, + study_programs=self.prepare_study_programs(metadata.study_programs) + ) + ) + else: + if "cit.tum.de" in metadata.link: + chunks = self.document_splitter.split_cit_content(content) + for chunk in chunks: + website_docs.append( + DatabaseDocument( + id=id, + content=chunk, + 
link=metadata.link, + study_programs=self.prepare_study_programs(metadata.study_programs) + ) + ) + else: + chunks = self.document_splitter.split_tum_content(content) + for chunk in chunks: + website_docs.append( + DatabaseDocument( + id=id, + content=chunk, + link=metadata.link, + study_programs=self.prepare_study_programs(metadata.study_programs) + ) + ) + self.weaviate_manager.add_documents(website_docs) + + def delete_document(self, id: str): + self.weaviate_manager.delete_by_kb_id(kb_id=id, return_metadata=False) + + def add_sample_question(self, sample_question: AddSampleQuestionRequest): + database_sq = DatabaseSampleQuestion( + id=sample_question.id, + topic=sample_question.topic, + question=sample_question.question, + answer=sample_question.answer, + study_programs=sample_question.studyPrograms + ) + self.weaviate_manager.add_sample_question(database_sq) + + def update_sample_question(self, kb_id: int, sample_question: EditSampleQuestionRequest): + database_sq = DatabaseSampleQuestion( + id=kb_id, + topic=sample_question.topic, + question=sample_question.question, + answer=sample_question.answer, + study_programs=sample_question.studyPrograms + ) + self.weaviate_manager.update_sample_question(database_sq) + + def delete_sample_question(self, id: int): + self.weaviate_manager.delete_sample_question(id=id) + + # Handle content not specific to study programs + def prepare_study_programs(self, study_programs: List[str]) -> List[str]: + if len(study_programs) == 0: + return ["general"] + else: + return study_programs + + + \ No newline at end of file diff --git a/app/injestion/vector_store_initializer.py b/app/injestion/vector_store_initializer.py index 2aaff68..a17696b 100644 --- a/app/injestion/vector_store_initializer.py +++ b/app/injestion/vector_store_initializer.py @@ -4,12 +4,13 @@ from langchain.docstore.document import Document from app.injestion.document_loader import load_website_content_from_folder, load_pdf_documents_from_folder, load_qa_pairs_from_folder -from app.injestion.document_splitter import split_cit_documents, split_tum_documents, split_pdf_documents +from app.injestion.document_splitter import DocumentSplitter from app.managers.weaviate_manager import WeaviateManager from app.utils.environment import config from app.data.user_requests import SampleQuestion, WebsiteContent +# TODO: Remove def initialize_vectorstores(base_folder: str, qa_folder: str, weaviate_manager: WeaviateManager): """ Initializes vector stores by adding documents to Weaviate with their embeddings. 
@@ -40,15 +41,15 @@ def initialize_vectorstores(base_folder: str, qa_folder: str, weaviate_manager:
     # PDF files
     pdf_docs: List[Document] = load_pdf_documents_from_folder(base_folder)
-    pdf_docs_split: List[Document] = split_pdf_documents(pdf_documents=pdf_docs)
+    pdf_docs_split: List[Document] = DocumentSplitter.split_pdf_documents(pdf_documents=pdf_docs)
     weaviate_manager.add_documents(pdf_docs_split)
 
     # Website content
     website_content: List[WebsiteContent] = load_website_content_from_folder(base_folder)
     cit_content = [content for content in website_content if content.type == "CIT"]
     tum_content = [content for content in website_content if content.type == "TUM"]
-    cit_chunks = split_cit_documents(cit_content)
-    tum_chunks = split_tum_documents(tum_content)
+    cit_chunks = DocumentSplitter.split_cit_documents(cit_content)
+    tum_chunks = DocumentSplitter.split_tum_documents(tum_content)
     split_website_content: List[WebsiteContent] = cit_chunks + tum_chunks
     weaviate_manager.add_website_content(split_website_content)
diff --git a/app/main.py b/app/main.py
index 7d0514d..3f76dff 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,6 +1,9 @@
 import logging
 
-from app.api.question_router import question_router, admin_router, auth_router
+from app.api.question_router import question_router
+from app.api.admin_router import admin_router
+from app.api.auth_router import auth_router
+from app.api.knowledge_router import knowledge_router
 from app.utils.setup_logging import setup_logging
 
 setup_logging()
@@ -30,7 +33,7 @@ async def lifespan(app: FastAPI):
 
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["http://localhost:4200"],
+    allow_origins=["http://localhost:9007"],
    allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -50,6 +53,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
 app.include_router(router=question_router)
 app.include_router(router=admin_router)
 app.include_router(router=auth_router)
+app.include_router(router=knowledge_router)
 
 if __name__ == "__main__":
     logging.info("Starting FastAPI server")
diff --git a/app/managers/weaviate_manager.py b/app/managers/weaviate_manager.py
index b3daab3..772ffa6 100644
--- a/app/managers/weaviate_manager.py
+++ b/app/managers/weaviate_manager.py
@@ -1,6 +1,6 @@
 import logging
 from enum import Enum
-from typing import List, Union, Tuple
+from typing import List, Union, Tuple, Optional
 
 import weaviate
 import weaviate.classes as wvc
@@ -15,6 +15,7 @@
 from app.retrieval_strategies.reranker import Reranker
 from app.utils.environment import config
 from app.data.user_requests import SampleQuestion, WebsiteContent
+from app.data.database_requests import DatabaseDocument, DatabaseSampleQuestion, DatabaseDocumentMetadata
 
 
 class DocumentSchema(Enum):
@@ -22,7 +23,8 @@ class DocumentSchema(Enum):
     Schema for the embedded chunks
     """
     COLLECTION_NAME = "CITKnowledgeBase"
-    STUDY_PROGRAM = "study_program"
+    KNOWLEDGE_BASE_ID = "kb_id"
+    STUDY_PROGRAMS = "study_programs"
     CONTENT = "content"
     LINK = "link"
 
@@ -32,8 +34,9 @@ class QASchema(Enum):
     Schema for the QA Collection
     """
     COLLECTION_NAME = "QACollection"
+    KNOWLEDGE_BASE_ID = "kb_id"
     TOPIC = "topic"
-    STUDY_PROGRAM = "study_program"
+    STUDY_PROGRAMS = "study_programs"
     QUESTION = "question"
     ANSWER = "answer"
 
@@ -65,9 +68,17 @@ def initialize_schema(self) -> Collection:
         # Define properties for the collection
         properties = [
             Property(
-                name=DocumentSchema.STUDY_PROGRAM.value,
+                name=DocumentSchema.KNOWLEDGE_BASE_ID.value,
+                description="The Angelos ID of the document",
+                data_type=DataType.INT,
+                index_filterable=True,
+                index_range_filters=False,
+                index_searchable=False
+            ),
+            Property(
+                name=DocumentSchema.STUDY_PROGRAMS.value,
                 description="The study program of the document",
-                data_type=DataType.TEXT,
+                data_type=[DataType.TEXT],
                 index_filterable=True,
                 index_range_filters=True,
                 index_searchable=True
@@ -132,16 +143,24 @@ def initialize_qa_schema(self) -> Collection:
         # Define properties for the QA collection
         properties = [
+            Property(
+                name=QASchema.KNOWLEDGE_BASE_ID.value,
+                description="The Angelos ID of the sample question",
+                data_type=DataType.INT,
+                index_filterable=True,
+                index_range_filters=False,
+                index_searchable=False
+            ),
             Property(
                 name=QASchema.TOPIC.value,
-                description="The topic of the conversation",
+                description="The topic of the sample question",
                 data_type=DataType.TEXT,
                 index_inverted=False
             ),
             Property(
-                name=QASchema.STUDY_PROGRAM.value,
-                description="The study program of the student",
-                data_type=DataType.TEXT,
+                name=QASchema.STUDY_PROGRAMS.value,
+                description="The relevant study program",
+                data_type=[DataType.TEXT],
                 index_filterable=True,
                 index_range_filters=True,
                 index_searchable=True
             ),
@@ -360,22 +379,10 @@ def delete_collection(self):
         else:
             logging.warning(f"Collection {collection_name} does not exist")
             return False
-
-    def add_document(self, text: str, study_program: str):
-        """Add a document with classification to Weaviate."""
-        try:
-            text_embedding = self.model.embed(text)
-            # logging.info(f"Adding document with embedding: {text_embedding}")
-            self.documents.data.insert(properties={DocumentSchema.CONTENT.value: text,
-                                                   DocumentSchema.STUDY_PROGRAM.value: study_program},
-                                       vector=text_embedding)
-            logging.info(f"Document successfully added with study program: {study_program}")
-        except Exception as e:
-            logging.error(f"Failed to add document: {e}")
-
-    def add_documents(self, chunks: List[Document]):
+
+    def add_documents(self, chunks: List[DatabaseDocument]):
         """
-        Add chunks of Document objects to the vector database.
+        Add chunks of DatabaseDocument objects to the vector database.
""" try: batch_size = 500 @@ -387,19 +394,21 @@ def add_documents(self, chunks: List[Document]): chunk_batch = chunks[i:i + batch_size] if isinstance(self.model, OllamaModel): # For other models, embed each chunk one at a time - embeddings = [self.model.embed(chunk.page_content) for chunk in chunk_batch] + embeddings = [self.model.embed(chunk.content) for chunk in chunk_batch] else: - texts = [chunk.page_content for chunk in chunk_batch] + texts = [chunk.content for chunk in chunk_batch] embeddings = self.model.embed_batch(texts) # Embed in batch # Add the chunks to the vector database in a batch with self.documents.batch.rate_limit(requests_per_minute=600) as batch: for index, chunk in enumerate(chunk_batch): - study_program = chunk.metadata.get("study_program", "general") + study_programs = chunk.study_programs # Prepare properties properties = { - DocumentSchema.CONTENT.value: chunk.page_content, - DocumentSchema.STUDY_PROGRAM.value: study_program, + DocumentSchema.KNOWLEDGE_BASE_ID.value: chunk.id, + DocumentSchema.CONTENT.value: chunk.content, + DocumentSchema.LINK.value: chunk.link, + DocumentSchema.STUDY_PROGRAMS.value: study_programs, } # Add the document chunk to the batch @@ -407,75 +416,164 @@ def add_documents(self, chunks: List[Document]): except Exception as e: logging.error(f"Error adding document: {e}") - - def add_website_content(self, website_contents: List[WebsiteContent]): + raise + + def delete_by_kb_id(self, kb_id: int, return_metadata: bool) -> Optional[DatabaseDocumentMetadata]: """ - Add chunks of WebsiteContent objects to the vector database, handling metadata and optional fields. + Delete all database entries by kb_id and return other properties """ try: - batch_size = 500 - num_chunks = len(website_contents) - logging.info(f"Adding {num_chunks} website contents in batches of {batch_size}") - - # Split the contents into batches of 500 for embedding and adding - for i in range(0, num_chunks, batch_size): - content_batch = website_contents[i:i + batch_size] + if return_metadata: + query_result = self.documents.query.fetch_objects( + filters=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(kb_id) + ) - # If using OpenAI model, embed in batches, otherwise embed one by one - if isinstance(self.model, OllamaModel): - embeddings = [self.model.embed(content.content) for content in content_batch] + if not query_result.objects: + logging.info(f"No documents found with knowledge_base_id: {kb_id}") + return None else: - texts = [content.content for content in content_batch] - embeddings = self.model.embed_batch(texts) # Batch embed + result = query_result.objects[0] + properties = result.properties + metadata = DatabaseDocumentMetadata( + link=properties[DocumentSchema.LINK], + study_programs=properties[DocumentSchema.STUDY_PROGRAMS] + ) + self.documents.data.delete_many( + where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID).equal(kb_id) + ) + return metadata + else: + self.documents.data.delete_many( + where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID).equal(kb_id) + ) + return None + + except Exception as e: + logging.error(f"Error deleting documents: {e}") + + def update_documents(self, kb_id: int, document: DatabaseDocumentMetadata): + try: + query_result = self.documents.query.fetch_objects( + filters=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(document.id) + ) - # Add the contents to the vector database in a batch - with self.documents.batch.rate_limit(requests_per_minute=600) as batch: - for index, content in 
enumerate(content_batch): - properties = { - DocumentSchema.CONTENT.value: content.content, - DocumentSchema.STUDY_PROGRAM.value: content.study_program, - DocumentSchema.LINK.value: content.link - } + if not query_result.objects: + logging.info(f"No documents found with knowledge_base_id: {document.id}") + return - # Add the content chunk to the batch - batch.add_object(properties=properties, vector=embeddings[index]) + # Iterate through the results and update the properties + for result in query_result.objects: + uuid = result.uuid + properties = result.properties + properties[DocumentSchema.LINK] = document.link # Update the link + properties[DocumentSchema.STUDY_PROGRAMS] = document.study_programs # Update the study programs + + # Reinsert the object with the updated properties + self.documents.data.update( + uuid=uuid, + properties=properties + ) + logging.info(f"Updated title for documents with knowledge_base_id: {kb_id}") except Exception as e: - logging.error(f"Error adding website content: {e}") + logging.error(f"Error updating title for knowledge_base_id {kb_id}: {e}") + raise - def add_qa_pairs(self, qa_pairs: List[SampleQuestion]): + def add_sample_question(self, sample_question: DatabaseSampleQuestion): """ - Adds QA pairs to the QA collection in Weaviate. + Adds a sample question to the QA collection in Weaviate. Args: - - qa_pairs: List of dictionaries, each containing 'topic', 'question', and 'answer' fields. + - The SampleQuestion to add Returns: - None """ - for qa_pair in qa_pairs: - try: - # Prepare the data entry for insertion - topic = qa_pair.topic - study_program = qa_pair.study_program - question = qa_pair.question - answer = qa_pair.answer - - # Add to QA collection in Weaviate - embedding = self.model.embed(question) - - self.qa_collection.data.insert( - properties={ - QASchema.TOPIC.value: topic, - QASchema.STUDY_PROGRAM.value: study_program, - QASchema.QUESTION.value: question, - QASchema.ANSWER.value: answer - }, + try: + # Prepare the data entry for insertion + kb_id = sample_question.id + topic = sample_question.topic + study_programs = sample_question.study_programs + question = sample_question.question + answer = sample_question.answer + + # Add to QA collection in Weaviate + embedding = self.model.embed(question) + + self.qa_collection.data.insert( + properties={ + QASchema.KNOWLEDGE_BASE_ID.value: kb_id, + QASchema.TOPIC.value: topic, + QASchema.STUDY_PROGRAMS.value: study_programs, + QASchema.QUESTION.value: question, + QASchema.ANSWER.value: answer + }, + vector=embedding + ) + logging.info(f"Inserted QA pair with topic: {sample_question.topic}") + except Exception as e: + logging.error(f"Failed to insert sample question with topic {sample_question.topic}: {e}") + raise + + def update_sample_question(self, sample_question: DatabaseSampleQuestion): + """ + Adds a sample question to the QA collection in Weaviate. 
+
+        Args:
+        - The SampleQuestion to update
+
+        Returns:
+        - None
+        """
+        try:
+            # Prepare the data entry for insertion
+            kb_id = sample_question.id
+            topic = sample_question.topic
+            study_programs = sample_question.study_programs
+            question = sample_question.question
+            answer = sample_question.answer
+
+            # Add to QA collection in Weaviate
+            embedding = self.model.embed(question)
+
+            query_result = self.qa_collection.query.fetch_objects(
+                filters=Filter.by_property(QASchema.KNOWLEDGE_BASE_ID.value).equal(sample_question.id)
+            )
+
+            if not query_result.objects:
+                logging.info(f"No sample questions found with knowledge_base_id: {sample_question.id}")
+                return
+
+            # Iterate through the results and update the properties
+            for result in query_result.objects:
+                uuid = result.uuid
+                properties = result.properties
+                properties[QASchema.KNOWLEDGE_BASE_ID.value] = kb_id
+                properties[QASchema.TOPIC.value] = topic
+                properties[QASchema.STUDY_PROGRAMS.value] = study_programs
+                properties[QASchema.QUESTION.value] = question
+                properties[QASchema.ANSWER.value] = answer
+
+                # Reinsert the object with the updated properties
+                self.qa_collection.data.update(
+                    uuid=uuid,
+                    properties=properties,
                     vector=embedding
                 )
-                logging.info(f"Inserted QA pair with topic: {qa_pair.topic}")
-            except Exception as e:
-                logging.error(f"Failed to insert QA pair with topic {qa_pair.topic}: {e}")
+
+            logging.info(f"Updated sample question with knowledge_base_id: {kb_id}")
+        except Exception as e:
+            logging.error(f"Failed to update sample question with topic {sample_question.topic}: {e}")
+            raise
+
+    def delete_sample_question(self, id: int):
+        try:
+            self.qa_collection.data.delete_many(
+                where=Filter.by_property(QASchema.KNOWLEDGE_BASE_ID.value).equal(id)
+            )
+        except Exception as e:
+            logging.error(f"Failed to delete sample question with ID {id}: {e}")
+            raise
 
     @staticmethod
     def normalize_study_program_name(study_program: str) -> str:
diff --git a/app/utils/dependencies.py b/app/utils/dependencies.py
index 835638f..3ef033a 100644
--- a/app/utils/dependencies.py
+++ b/app/utils/dependencies.py
@@ -4,6 +4,8 @@
 from app.retrieval_strategies.reranker import Reranker
 from app.models.model_loader import get_model
 from app.prompt.prompt_manager import PromptManager
+from app.injestion.document_splitter import DocumentSplitter
+from app.injestion.injestion_handler import InjestionHandler
 from app.utils.environment import config
 
 # Initialize resources
@@ -11,8 +13,10 @@
 reranker = Reranker(model=model, api_key_en=config.COHERE_API_KEY_EN, api_key_multi=config.COHERE_API_KEY_MULTI)
 weaviate_manager = WeaviateManager(config.WEAVIATE_URL, embedding_model=model, reranker=reranker)
 prompt_manager = PromptManager()
+document_splitter = DocumentSplitter()
 request_handler = RequestHandler(weaviate_manager=weaviate_manager, model=model, prompt_manager=prompt_manager)
 auth_handler = AuthHandler(angelos_api_key=config.ANGELOS_APP_API_KEY, secret_key=config.API_ENDPOINT_KEY,
                            algorithm=config.ALGORITHM, access_token_expires_minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
+injestion_handler = InjestionHandler(weaviate_manager=weaviate_manager, document_splitter=document_splitter)
 
 # Provide a shutdown mechanism for the model
 def shutdown_model():

From 6191b646157bf1a6d8261681b0d4ca8b21b76abd Mon Sep 17 00:00:00 2001
From: Nino Richter
Date: Mon, 23 Dec 2024 11:45:08 +0100
Subject: [PATCH 2/7] Adjust injestion to handle str ids

---
 app/api/knowledge_router.py         |  59 ++++++-----
 app/data/database_requests.py       |   7 +-
 app/data/knowledge_base_requests.py |  20 +++--
 app/injestion/document_splitter.py  |   3 +-
 app/injestion/injestion_handler.py  |  86 ++++++++-----
 app/main.py                         |  13 ++-
 app/managers/weaviate_manager.py    | 130 ++++++++++++++--------
 docker/weaviate.local.yml           |  22 +++++
 docker/weaviate.yml                 |  16 ++--
 9 files changed, 247 insertions(+), 109 deletions(-)
 create mode 100644 docker/weaviate.local.yml

diff --git a/app/api/knowledge_router.py b/app/api/knowledge_router.py
index 19cb019..b6777bc 100644
--- a/app/api/knowledge_router.py
+++ b/app/api/knowledge_router.py
@@ -1,28 +1,37 @@
+from typing import List
 from fastapi import HTTPException, APIRouter, status, Response, Depends
 from app.data.knowledge_base_requests import AddWebsiteRequest, EditDocumentRequest, EditSampleQuestionRequest, EditWebsiteRequest, AddDocumentRequest, AddSampleQuestionRequest, RefreshContentRequest
-from app.utils.dependencies import injestion_handler
+from app.utils.dependencies import injestion_handler, auth_handler
 from app.data.database_requests import DatabaseDocumentMetadata
 
-knowledge_router = APIRouter(prefix="/knowledge", tags=["knowledge"])
+knowledge_router = APIRouter(prefix="/api/knowledge", tags=["knowledge"])
 
-@knowledge_router.post("/website/add")
+@knowledge_router.post("/website/add", dependencies=[Depends(auth_handler.verify_api_key)])
 async def add_website(body: AddWebsiteRequest):
     try:
         injestion_handler.add_website(body)
         return Response(status_code=200)
     except Exception as e:
         return Response(status_code=500)
+
+@knowledge_router.post("/website/addBatch", dependencies=[Depends(auth_handler.verify_api_key)])
+async def add_websites(body: List[AddWebsiteRequest]):
+    try:
+        injestion_handler.add_websites(body)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
 
-@knowledge_router.post("/website/{id}/refresh")
-async def refresh_website(id: int, body: RefreshContentRequest):
+@knowledge_router.post("/website/{id}/refresh", dependencies=[Depends(auth_handler.verify_api_key)])
+async def refresh_website(id: str, body: RefreshContentRequest):
     try:
         injestion_handler.refresh_content(id=id, content=body.content)
         return Response(status_code=200)
     except Exception as e:
         return Response(status_code=500)
 
-@knowledge_router.post("/website/{id}/update")
-async def update_website(id: int, body: EditWebsiteRequest):
+@knowledge_router.post("/website/{id}/update", dependencies=[Depends(auth_handler.verify_api_key)])
+async def update_website(id: str, body: EditWebsiteRequest):
     try:
         metadata: DatabaseDocumentMetadata = DatabaseDocumentMetadata(
             study_programs=body.studyPrograms
@@ -32,8 +41,8 @@ async def update_website(id: str, body: EditWebsiteRequest):
     except Exception as e:
         return Response(status_code=500)
 
-@knowledge_router.delete("/website/{id}/delete")
-async def delete_website(id: int):
+@knowledge_router.delete("/website/{id}/delete", dependencies=[Depends(auth_handler.verify_api_key)])
+async def delete_website(id: str):
     try:
         injestion_handler.delete_document(id=id)
         return Response(status_code=200)
@@ -43,24 +52,32 @@ async def delete_website(id: str):
 
 # === Document Endpoints ===
 
-@knowledge_router.post("/document/add")
+@knowledge_router.post("/document/add", dependencies=[Depends(auth_handler.verify_api_key)])
 async def add_document(body: AddDocumentRequest):
     try:
         injestion_handler.add_document(body)
         return Response(status_code=200)
     except Exception as e:
         return Response(status_code=500)
+
+@knowledge_router.post("/sample-question/addBatch", dependencies=[Depends(auth_handler.verify_api_key)])
+async def add_sample_questions(body: List[AddSampleQuestionRequest]):
+    try:
+        injestion_handler.add_sample_questions(body)
+        return Response(status_code=200)
+    except Exception as e:
+        return Response(status_code=500)
 
-@knowledge_router.post("/document/{id}/refresh")
-async def refresh_document(id: int, body: RefreshContentRequest):
+@knowledge_router.post("/document/{id}/refresh", dependencies=[Depends(auth_handler.verify_api_key)])
+async def refresh_document(id: str, body: RefreshContentRequest):
     try:
         injestion_handler.refresh_content(id=id, content=body.content)
         return Response(status_code=200)
     except Exception as e:
         return Response(status_code=500)
 
-@knowledge_router.post("/document/{id}/edit")
-async def edit_document(id: int, body: EditDocumentRequest):
+@knowledge_router.post("/document/{id}/edit", dependencies=[Depends(auth_handler.verify_api_key)])
+async def edit_document(id: str, body: EditDocumentRequest):
     try:
         metadata: DatabaseDocumentMetadata = DatabaseDocumentMetadata(
             study_programs=body.studyPrograms
@@ -70,8 +87,8 @@ async def edit_document(id: str, body: EditDocumentRequest):
     except Exception as e:
         return Response(status_code=500)
 
-@knowledge_router.delete("/document/{id}/delete")
-async def delete_document(id: int):
+@knowledge_router.delete("/document/{id}/delete", dependencies=[Depends(auth_handler.verify_api_key)])
+async def delete_document(id: str):
     try:
         injestion_handler.delete_document(id=id)
         return Response(status_code=200)
@@ -81,7 +98,7 @@ async def delete_document(id: str):
 
 # === Sample Question Endpoints ===
 
-@knowledge_router.post("/sample-question/add")
+@knowledge_router.post("/sample-question/add", dependencies=[Depends(auth_handler.verify_api_key)])
 async def add_sample_question(body: AddSampleQuestionRequest):
     try:
         injestion_handler.add_sample_question(sample_question=body)
@@ -89,16 +106,16 @@ async def add_sample_question(body: AddSampleQuestionRequest):
     except Exception as e:
         return Response(status_code=500)
 
-@knowledge_router.post("/sample-question/{id}/edit")
-async def edit_sample_question(id: int, body: EditSampleQuestionRequest):
+@knowledge_router.post("/sample-question/{id}/edit", dependencies=[Depends(auth_handler.verify_api_key)])
+async def edit_sample_question(id: str, body: EditSampleQuestionRequest):
     try:
         injestion_handler.update_sample_question(kb_id=id, sample_question=body)
         return Response(status_code=200)
     except Exception as e:
         return Response(status_code=500)
 
-@knowledge_router.delete("/sample-question/{id}/delete")
-async def delete_sample_question(id: int):
+@knowledge_router.delete("/sample-question/{id}/delete", dependencies=[Depends(auth_handler.verify_api_key)])
+async def delete_sample_question(id: str):
     try:
         injestion_handler.delete_sample_question(id=id)
         return Response(status_code=200)
diff --git a/app/data/database_requests.py b/app/data/database_requests.py
index b2d6da3..d3c506e 100644
--- a/app/data/database_requests.py
+++ b/app/data/database_requests.py
@@ -2,18 +2,21 @@
 from typing import List, Optional
 
 class DatabaseDocument(BaseModel):
-    id: int
+    id: str
     link: Optional[str] = None
     study_programs: List[str]
     content: str
+    org_id: int
 
 class DatabaseDocumentMetadata(BaseModel):
     link: Optional[str] = None
     study_programs: List[str]
+    org_id: int
 
 class DatabaseSampleQuestion(BaseModel):
-    id: int
+    id: str
     topic: str
     question: str
     answer: str
     study_programs: List[str]
+    org_id: int
diff --git a/app/data/knowledge_base_requests.py b/app/data/knowledge_base_requests.py
index d682975..0a1c9ff 100644
--- a/app/data/knowledge_base_requests.py
+++ b/app/data/knowledge_base_requests.py
@@ -3,7 +3,8 @@
 
 class AddWebsiteRequest(BaseModel):
-    id: int
+    id: str
+    orgId: int
     title: str
     link: str
     studyPrograms: List[str]
@@ -15,27 +16,30 @@ class RefreshContentRequest(BaseModel):
 
 class EditWebsiteRequest(BaseModel):
     title: str
-    studyPrograms: List[int]
+    studyPrograms: List[str]
 
 class AddDocumentRequest(BaseModel):
-    id: int
+    id: str
+    orgId: int
     title: str
-    studyPrograms: List[int]
+    studyPrograms: List[str]
     content: str
 
 class EditDocumentRequest(BaseModel):
     title: str
-    studyPrograms: List[int]
+    studyPrograms: List[str]
 
 class AddSampleQuestionRequest(BaseModel):
-    id: int
+    id: str
+    orgId: int
     question: str
     answer: str
     topic: str
-    studyPrograms: List[int]
+    studyPrograms: List[str]
 
 class EditSampleQuestionRequest(BaseModel):
     question: str
     answer: str
     topic: str
-    studyPrograms: List[int]
\ No newline at end of file
+    studyPrograms: List[str]
+    orgId: int
\ No newline at end of file
diff --git a/app/injestion/document_splitter.py b/app/injestion/document_splitter.py
index 5389a7c..3adf8ae 100644
--- a/app/injestion/document_splitter.py
+++ b/app/injestion/document_splitter.py
@@ -13,7 +13,8 @@ def split_cit_content(self, content: str):
         sections = content.split('----------------------------------------')
         for section in sections:
             section = section.strip()
-            result.append(section)
+            if section:
+                result.append(section)
         return result
 
     def split_tum_content(self, content: str, chunk_size: int = 1200, chunk_overlap: int = 200):
diff --git a/app/injestion/injestion_handler.py b/app/injestion/injestion_handler.py
index bbba8cc..28497ad 100644
--- a/app/injestion/injestion_handler.py
+++ b/app/injestion/injestion_handler.py
@@ -14,27 +14,42 @@ def add_website(self, website: AddWebsiteRequest):
         website_docs: List[DatabaseDocument] = []
         if website.type == "CIT":
             chunks = self.document_splitter.split_cit_content(website.content)
-            for chunk in chunks:
-                website_docs.append(
-                    DatabaseDocument(
-                        id=website.id,
-                        content=chunk,
-                        link=website.link,
-                        study_programs=self.prepare_study_programs(website.studyPrograms)
-                    )
-                )
         else:
-            chunks = self.document_splitter.split_tum_content(website.content)
+            chunks = self.document_splitter.split_tum_content(website.content)
+
+        for chunk in chunks:
+            website_docs.append(
+                DatabaseDocument(
+                    id=website.id,
+                    content=chunk,
+                    link=website.link,
+                    study_programs=self.prepare_study_programs(website.studyPrograms),
+                    org_id=website.orgId
+                )
+            )
+
+        self.weaviate_manager.add_documents(website_docs)
+
+    def add_websites(self, websites: List[AddWebsiteRequest]):
+        all_website_docs: List[DatabaseDocument] = []
+        for website in websites:
+            if website.type == "CIT":
+                chunks = self.document_splitter.split_cit_content(website.content)
+            else:
+                chunks = self.document_splitter.split_tum_content(website.content)
+
             for chunk in chunks:
-                website_docs.append(
+                all_website_docs.append(
                     DatabaseDocument(
                         id=website.id,
                         content=chunk,
                         link=website.link,
-                        study_programs=self.prepare_study_programs(website.studyPrograms)
+                        study_programs=self.prepare_study_programs(website.studyPrograms),
+                        org_id=website.orgId
                     )
                 )
-        self.weaviate_manager.add_documents(website_docs)
+
+        self.weaviate_manager.add_documents(all_website_docs)
 
     def add_document(self, document: AddDocumentRequest):
         website_docs: List[DatabaseDocument] = []
@@ -44,15 +59,16 @@ def add_document(self, document: AddDocumentRequest):
                 DatabaseDocument(
                     id=document.id,
                     content=chunk,
-                    study_programs=self.prepare_study_programs(document.studyPrograms)
+                    study_programs=self.prepare_study_programs(document.studyPrograms),
+                    org_id=document.orgId
                 )
             )
         self.weaviate_manager.add_documents(website_docs)
 
-    def update_database_document(self, id: int, metadata: DatabaseDocumentMetadata):
+    def update_database_document(self, id: str, metadata: DatabaseDocumentMetadata):
         self.weaviate_manager.update_documents(id, metadata)
 
-    def refresh_content(self, id: int, content: str):
+    def refresh_content(self, id: str, content: str):
         metadata: Optional[DatabaseDocumentMetadata] = self.weaviate_manager.delete_by_kb_id(kb_id=id, return_metadata=True)
         if metadata is not None:
             website_docs: List[DatabaseDocument] = []
@@ -63,7 +79,8 @@ def refresh_content(self, id: str, content: str):
                         DatabaseDocument(
                             id=id,
                             content=chunk,
-                            study_programs=self.prepare_study_programs(metadata.study_programs)
+                            study_programs=self.prepare_study_programs(metadata.study_programs),
+                            org_id=metadata.org_id
                         )
                     )
             else:
@@ -75,7 +92,8 @@ def refresh_content(self, id: str, content: str):
                             id=id,
                             content=chunk,
                             link=metadata.link,
-                            study_programs=self.prepare_study_programs(metadata.study_programs)
+                            study_programs=self.prepare_study_programs(metadata.study_programs),
+                            org_id=metadata.org_id
                         )
                     )
             else:
@@ -86,7 +104,8 @@ def refresh_content(self, id: str, content: str):
                             id=id,
                             content=chunk,
                             link=metadata.link,
-                            study_programs=self.prepare_study_programs(metadata.study_programs)
+                            study_programs=self.prepare_study_programs(metadata.study_programs),
+                            org_id=metadata.org_id
                         )
                     )
             self.weaviate_manager.add_documents(website_docs)
@@ -100,21 +119,38 @@ def add_sample_question(self, sample_question: AddSampleQuestionRequest):
             topic=sample_question.topic,
             question=sample_question.question,
             answer=sample_question.answer,
-            study_programs=sample_question.studyPrograms
+            study_programs=self.prepare_study_programs(sample_question.studyPrograms),
+            org_id=sample_question.orgId
         )
         self.weaviate_manager.add_sample_question(database_sq)
 
-    def update_sample_question(self, kb_id: int, sample_question: EditSampleQuestionRequest):
+    def add_sample_questions(self, sample_questions: List[AddSampleQuestionRequest]):
+        db_questions = []
+        for sq in sample_questions:
+            db_questions.append(
+                DatabaseSampleQuestion(
+                    id=sq.id,
+                    topic=sq.topic,
+                    question=sq.question,
+                    answer=sq.answer,
+                    study_programs=self.prepare_study_programs(sq.studyPrograms),
+                    org_id=sq.orgId
+                )
+            )
+        self.weaviate_manager.add_sample_questions(db_questions)
+
+    def update_sample_question(self, kb_id: str, sample_question: EditSampleQuestionRequest):
         database_sq = DatabaseSampleQuestion(
             id=kb_id,
             topic=sample_question.topic,
             question=sample_question.question,
             answer=sample_question.answer,
-            study_programs=sample_question.studyPrograms
+            study_programs=self.prepare_study_programs(sample_question.studyPrograms),
+            org_id=sample_question.orgId
         )
         self.weaviate_manager.update_sample_question(database_sq)
 
-    def delete_sample_question(self, id: int):
+    def delete_sample_question(self, id: str):
         self.weaviate_manager.delete_sample_question(id=id)
 
     # Handle content not specific to study programs
@@ -122,7 +158,5 @@ def prepare_study_programs(self, study_programs: List[str]) -> List[str]:
         if len(study_programs) == 0:
             return ["general"]
         else:
-            return study_programs
-
-
+            return [sp.replace(" ", "-").lower() for sp in study_programs]
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
index 3f76dff..7024f85 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,14 +1,15 @@
 import logging
+import uvicorn
+
+from app.utils.setup_logging import setup_logging
+
+setup_logging()
 
 from app.api.question_router import question_router
 from app.api.admin_router import admin_router
 from app.api.auth_router import auth_router
 from app.api.knowledge_router import knowledge_router
-from app.utils.setup_logging import setup_logging
-
-setup_logging()
 
-import uvicorn
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, status
 from fastapi.exceptions import RequestValidationError
@@ -18,8 +19,6 @@
 
 from app.utils.dependencies import shutdown_model
 
-logging.info("Starting application...")
-
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
@@ -33,7 +32,7 @@ async def lifespan(app: FastAPI):
 
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["http://localhost:9007"],
+    allow_origins=["http://localhost:9007", "http://localhost:4200"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
diff --git a/app/managers/weaviate_manager.py b/app/managers/weaviate_manager.py
index 772ffa6..53b94e0 100644
--- a/app/managers/weaviate_manager.py
+++ b/app/managers/weaviate_manager.py
@@ -27,6 +27,7 @@ class DocumentSchema(Enum):
     STUDY_PROGRAMS = "study_programs"
     CONTENT = "content"
     LINK = "link"
+    ORGANISATION_ID = "org_id"
 
 
 class QASchema(Enum):
@@ -39,6 +40,7 @@ class QASchema(Enum):
     STUDY_PROGRAMS = "study_programs"
     QUESTION = "question"
     ANSWER = "answer"
+    ORGANISATION_ID = "org_id"
 
 
 class WeaviateManager:
@@ -70,7 +72,7 @@ def initialize_schema(self) -> Collection:
             Property(
                 name=DocumentSchema.KNOWLEDGE_BASE_ID.value,
                 description="The Angelos ID of the document",
-                data_type=DataType.INT,
+                data_type=DataType.TEXT,
                 index_filterable=True,
                 index_range_filters=False,
                 index_searchable=False
@@ -78,23 +80,32 @@ def initialize_schema(self) -> Collection:
             Property(
                 name=DocumentSchema.STUDY_PROGRAMS.value,
                 description="The study program of the document",
-                data_type=[DataType.TEXT],
+                data_type=DataType.TEXT_ARRAY,
                 index_filterable=True,
-                index_range_filters=True,
-                index_searchable=True
+                index_range_filters=False,
+                index_searchable=False,
+                index_inverted=True
             ),
             Property(
                 name=DocumentSchema.CONTENT.value,
                 description="The content of the document",
                 data_type=DataType.TEXT,
-                index_inverted=False  # Disable inverted index if not needed
+                index_inverted=False
             ),
             Property(
                 name=DocumentSchema.LINK.value,
                 description="The link of the document",
                 data_type=DataType.TEXT,
-                index_inverted=False  # Disable inverted index if not needed
-            )
+                index_inverted=False
+            ),
+            Property(
+                name=DocumentSchema.ORGANISATION_ID.value,
+                description="The Organisation ID of the document",
+                data_type=DataType.INT,
+                index_filterable=True,
+                index_range_filters=False,
+                index_searchable=False
+            ),
         ]
 
         # Define vector index configuration (use cosine distance metric)
@@ -146,7 +157,7 @@ def initialize_qa_schema(self) -> Collection:
             Property(
                 name=QASchema.KNOWLEDGE_BASE_ID.value,
                 description="The Angelos ID of the sample question",
-                data_type=DataType.INT,
+                data_type=DataType.TEXT,
                 index_filterable=True,
                 index_range_filters=False,
                 index_searchable=False
@@ -160,10 +171,11 @@ def initialize_qa_schema(self) -> Collection:
             Property(
                 name=QASchema.STUDY_PROGRAMS.value,
                 description="The relevant study program",
-                data_type=[DataType.TEXT],
+                data_type=DataType.TEXT_ARRAY,
                 index_filterable=True,
-                index_range_filters=True,
-                index_searchable=True
+                index_range_filters=False,
+                index_searchable=False,
+                index_inverted=True
             ),
diff --git a/app/managers/weaviate_manager.py b/app/managers/weaviate_manager.py
index 772ffa6..53b94e0 100644
--- a/app/managers/weaviate_manager.py
+++ b/app/managers/weaviate_manager.py
@@ -27,6 +27,7 @@ class DocumentSchema(Enum):
     STUDY_PROGRAMS = "study_programs"
     CONTENT = "content"
     LINK = "link"
+    ORGANISATION_ID = "org_id"
 
 
 class QASchema(Enum):
@@ -39,6 +40,7 @@ class QASchema(Enum):
     STUDY_PROGRAMS = "study_programs"
     QUESTION = "question"
     ANSWER = "answer"
+    ORGANISATION_ID = "org_id"
 
 
 class WeaviateManager:
@@ -70,7 +72,7 @@ def initialize_schema(self) -> Collection:
                 Property(
                     name=DocumentSchema.KNOWLEDGE_BASE_ID.value,
                     description="The Angelos ID of the document",
-                    data_type=DataType.INT,
+                    data_type=DataType.TEXT,
                     index_filterable=True,
                     index_range_filters=False,
                     index_searchable=False
                 ),
@@ -78,23 +80,32 @@ def initialize_schema(self) -> Collection:
                 Property(
                     name=DocumentSchema.STUDY_PROGRAMS.value,
                     description="The study program of the document",
-                    data_type=[DataType.TEXT],
+                    data_type=DataType.TEXT_ARRAY,
                     index_filterable=True,
-                    index_range_filters=True,
-                    index_searchable=True
+                    index_range_filters=False,
+                    index_searchable=False,
+                    index_inverted=True
                 ),
                 Property(
                     name=DocumentSchema.CONTENT.value,
                     description="The content of the document",
                     data_type=DataType.TEXT,
-                    index_inverted=False  # Disable inverted index if not needed
+                    index_inverted=False
                 ),
                 Property(
                     name=DocumentSchema.LINK.value,
                     description="The link of the document",
                     data_type=DataType.TEXT,
-                    index_inverted=False  # Disable inverted index if not needed
-                )
+                    index_inverted=False
+                ),
+                Property(
+                    name=DocumentSchema.ORGANISATION_ID.value,
+                    description="The Organisation ID of the document",
+                    data_type=DataType.INT,
+                    index_filterable=True,
+                    index_range_filters=False,
+                    index_searchable=False
+                ),
             ]
 
             # Define vector index configuration (use cosine distance metric)
@@ -146,7 +157,7 @@ def initialize_qa_schema(self) -> Collection:
                 Property(
                     name=QASchema.KNOWLEDGE_BASE_ID.value,
                     description="The Angelos ID of the sample question",
-                    data_type=DataType.INT,
+                    data_type=DataType.TEXT,
                     index_filterable=True,
                     index_range_filters=False,
                     index_searchable=False
                 ),
@@ -160,10 +171,11 @@ def initialize_qa_schema(self) -> Collection:
                 Property(
                     name=QASchema.STUDY_PROGRAMS.value,
                     description="The relevant study program",
-                    data_type=[DataType.TEXT],
+                    data_type=DataType.TEXT_ARRAY,
                     index_filterable=True,
-                    index_range_filters=True,
-                    index_searchable=True
+                    index_range_filters=False,
+                    index_searchable=False,
+                    index_inverted=True
                 ),
                 Property(
                     name=QASchema.QUESTION.value,
@@ -177,6 +189,14 @@ def initialize_qa_schema(self) -> Collection:
                     data_type=DataType.TEXT,
                     index_inverted=False
                 ),
+                Property(
+                    name=QASchema.ORGANISATION_ID.value,
+                    description="The Organisation ID of the sample question",
+                    data_type=DataType.INT,
+                    index_filterable=True,
+                    index_range_filters=False,
+                    index_searchable=False
+                ),
             ]
 
             # Define vector index configuration
@@ -229,6 +249,7 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
         try:
             # Define the number of documents to retrieve
             min_relevance_score = 0.25
+
             if study_program.lower() != "general":
                 limit = 10
                 min_relevance_score = 0.15
@@ -238,14 +259,12 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
 
             # Normalize the study program name and calculate its length
             study_program = WeaviateManager.normalize_study_program_name(study_program)
-            study_program_length = len(study_program)
 
             # Perform the vector-based query with filters
             query_result = self.documents.query.near_vector(
                 near_vector=question_embedding,
                 filters=Filter.all_of([
-                    Filter.by_property(DocumentSchema.STUDY_PROGRAM.value).equal(study_program),
-                    Filter.by_property(DocumentSchema.STUDY_PROGRAM.value, length=True).equal(study_program_length),
+                    Filter.by_property(DocumentSchema.STUDY_PROGRAMS.value).contains_any([study_program])
                 ]),
                 limit=limit,
                 # include_vector=True,
                 return_metadata=wvc.query.MetadataQuery(certainty=True, score=True, distance=True)
@@ -263,7 +282,6 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
 
             # Remove exact duplicates from context_list
             content_content_list = WeaviateManager.remove_exact_duplicates(content_content_list)
-            # logging.info(f"Context list length after removing exact duplicates: {len(context_list)}")
 
             # Rerank the unique contexts using Cohere
             sorted_context = self.reranker.rerank_with_cohere(context_list=content_content_list, query=question,
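The replaced filter pair, an equality check plus a length check on a single TEXT property, was a workaround that the TEXT_ARRAY type makes unnecessary: one membership test now covers it. A minimal sketch of the new filter with the v4 client, assuming the property name from the schema above and a hypothetical program slug:

from weaviate.classes.query import Filter

# Matches every chunk whose study_programs array contains the normalized name
program_filter = Filter.by_property("study_programs").contains_any(["information-systems"])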
@@ -326,7 +344,7 @@ def get_relevant_sample_questions(self, question: str, language: str) -> List[Sa
                 topic = result.properties.get(QASchema.TOPIC.value, "")
                 retrieved_question = result.properties.get(QASchema.QUESTION.value, "")
                 answer = result.properties.get(QASchema.ANSWER.value, "")
-                study_program = result.properties.get(QASchema.STUDY_PROGRAM, "")
+                study_program = result.properties.get(QASchema.STUDY_PROGRAMS, [])
                 sample_questions.append(SampleQuestion(topic=topic, question=retrieved_question, answer=answer, study_program=study_program))
 
             # Rerank the sample questions using the reranker
@@ -393,22 +411,23 @@ def add_documents(self, chunks: List[DatabaseDocument]):
             for i in range(0, num_chunks, batch_size):
                 chunk_batch = chunks[i:i + batch_size]
                 if isinstance(self.model, OllamaModel):
-                    # For other models, embed each chunk one at a time
+                    # For Ollama models, embed each chunk one at a time
                     embeddings = [self.model.embed(chunk.content) for chunk in chunk_batch]
                 else:
                     texts = [chunk.content for chunk in chunk_batch]
                     embeddings = self.model.embed_batch(texts)  # Embed in batch
-
+                logging.info(f"Chunk batch size: {len(chunk_batch)}")
+
                 # Add the chunks to the vector database in a batch
                 with self.documents.batch.rate_limit(requests_per_minute=600) as batch:
                     for index, chunk in enumerate(chunk_batch):
-                        study_programs = chunk.study_programs
                         # Prepare properties
                         properties = {
                             DocumentSchema.KNOWLEDGE_BASE_ID.value: chunk.id,
                             DocumentSchema.CONTENT.value: chunk.content,
                             DocumentSchema.LINK.value: chunk.link,
-                            DocumentSchema.STUDY_PROGRAMS.value: study_programs,
+                            DocumentSchema.STUDY_PROGRAMS.value: chunk.study_programs,
+                            DocumentSchema.ORGANISATION_ID.value: chunk.org_id
                         }
 
                         # Add the document chunk to the batch
@@ -418,7 +437,7 @@ def add_documents(self, chunks: List[DatabaseDocument]):
                 logging.error(f"Error adding document: {e}")
                 raise
 
-    def delete_by_kb_id(self, kb_id: int, return_metadata: bool) -> Optional[DatabaseDocumentMetadata]:
+    def delete_by_kb_id(self, kb_id: str, return_metadata: bool) -> Optional[DatabaseDocumentMetadata]:
         """
         Delete all database entries by kb_id and return other properties
         """
@@ -435,23 +454,24 @@ def delete_by_kb_id(self, kb_id: str, return_metadata: bool) -> Optional[Databas
                 result = query_result.objects[0]
                 properties = result.properties
                 metadata = DatabaseDocumentMetadata(
-                    link=properties[DocumentSchema.LINK],
-                    study_programs=properties[DocumentSchema.STUDY_PROGRAMS]
+                    link=properties[DocumentSchema.LINK.value],
+                    study_programs=properties[DocumentSchema.STUDY_PROGRAMS.value],
+                    org_id=properties[DocumentSchema.STUDY_PROGRAMS.value]
                 )
                 self.documents.data.delete_many(
-                    where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID).equal(kb_id)
+                    where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(kb_id)
                 )
                 return metadata
             else:
                 self.documents.data.delete_many(
-                    where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID).equal(kb_id)
+                    where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(kb_id)
                 )
                 return None
         except Exception as e:
             logging.error(f"Error deleting documents: {e}")
 
-    def update_documents(self, kb_id: int, document: DatabaseDocumentMetadata):
+    def update_documents(self, kb_id: str, document: DatabaseDocumentMetadata):
         try:
             query_result = self.documents.query.fetch_objects(
                 filters=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(document.id)
             )
@@ -465,8 +485,8 @@ def update_documents(self, kb_id: str, document: DatabaseDocumentMetadata):
             for result in query_result.objects:
                 uuid = result.uuid
                 properties = result.properties
-                properties[DocumentSchema.LINK] = document.link  # Update the link
-                properties[DocumentSchema.STUDY_PROGRAMS] = document.study_programs  # Update the study programs
+                properties[DocumentSchema.LINK.value] = document.link  # Update the link
+                properties[DocumentSchema.STUDY_PROGRAMS.value] = document.study_programs  # Update the study programs
 
                 # Reinsert the object with the updated properties
                 self.documents.data.update(
@@ -496,6 +516,7 @@ def add_sample_question(self, sample_question: DatabaseSampleQuestion):
             study_programs = sample_question.study_programs
             question = sample_question.question
             answer = sample_question.answer
+            org_id = sample_question.org_id
 
             # Add to QA collection in Weaviate
             embedding = self.model.embed(question)
@@ -506,7 +527,8 @@ def add_sample_question(self, sample_question: DatabaseSampleQuestion):
                     QASchema.TOPIC.value: topic,
                     QASchema.STUDY_PROGRAMS.value: study_programs,
                     QASchema.QUESTION.value: question,
-                    QASchema.ANSWER.value: answer
+                    QASchema.ANSWER.value: answer,
+                    QASchema.ORGANISATION_ID.value: org_id
                 },
                 vector=embedding
             )
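update_documents, like update_sample_question further down, follows a fetch-then-patch pattern: select the objects by knowledge-base id, then write the changed properties back per UUID. A condensed sketch of the pattern, assuming a connected v4 collection and that kb_property names the schema's knowledge-base-id property:

from weaviate.classes.query import Filter

def patch_by_kb_id(collection, kb_property: str, kb_id: str, changes: dict) -> None:
    # Fetch every object that belongs to this knowledge base entry ...
    result = collection.query.fetch_objects(
        filters=Filter.by_property(kb_property).equal(kb_id)
    )
    # ... then write the changed properties back object by object
    for obj in result.objects:
        collection.data.update(uuid=obj.uuid, properties={**obj.properties, **changes})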
+ """ + try: + batch_size = 500 + num_questions = len(questions) + logging.info(f"Adding {num_questions} sample questions in batches of {batch_size}") + + for i in range(0, num_questions, batch_size): + question_batch = questions[i:i + batch_size] + + if isinstance(self.model, OllamaModel): + # For Ollama models, embed each question one at a time + embeddings = [self.model.embed(q.question) for q in question_batch] + else: + # For other models, embed in batch + texts = [q.question for q in question_batch] + embeddings = self.model.embed_batch(texts) + + # Insert into the QA collection in a batch + with self.qa_collection.batch.rate_limit(requests_per_minute=600) as batch: + for idx, sq in enumerate(question_batch): + properties = { + QASchema.KNOWLEDGE_BASE_ID.value: sq.id, + QASchema.TOPIC.value: sq.topic, + QASchema.STUDY_PROGRAMS.value: sq.study_programs, + QASchema.QUESTION.value: sq.question, + QASchema.ANSWER.value: sq.answer, + QASchema.ORGANISATION_ID.value: sq.org_id + } + + batch.add_object(properties=properties, vector=embeddings[idx]) + + logging.info(f"Successfully inserted {num_questions} sample questions.") + except Exception as e: + logging.error(f"Failed to insert sample questions: {e}") + raise + def update_sample_question(self, sample_question: DatabaseSampleQuestion): """ Adds a sample question to the QA collection in Weaviate. @@ -526,12 +588,11 @@ def update_sample_question(self, sample_question: DatabaseSampleQuestion): - None """ try: - # Prepare the data entry for insertion - kb_id = sample_question.id topic = sample_question.topic study_programs = sample_question.study_programs question = sample_question.question answer = sample_question.answer + # Add to QA collection in Weaviate embedding = self.model.embed(question) @@ -548,11 +609,10 @@ def update_sample_question(self, sample_question: DatabaseSampleQuestion): for result in query_result.objects: uuid = result.uuid properties = result.properties - properties[QASchema.KNOWLEDGE_BASE_ID.value] = kb_id properties[QASchema.TOPIC.value] = topic properties[QASchema.STUDY_PROGRAMS.value] = study_programs properties[QASchema.QUESTION.value] = question - properties[QASchema.ANSWER].value = answer + properties[QASchema.ANSWER.value] = answer # Reinsert the object with the updated properties self.qa_collection.data.update( @@ -561,12 +621,12 @@ def update_sample_question(self, sample_question: DatabaseSampleQuestion): vector=embedding ) - logging.info(f"Updated title for documents with knowledge_base_id: {kb_id}") + logging.info(f"Updated sample question with knowledge_base_id: {sample_question.id}") except Exception as e: logging.error(f"Failed to update sample question with topic {sample_question.topic}: {e}") raise - def delete_sample_question(self, id: int): + def delete_sample_question(self, id: str): try: self.documents.data.delete_many( where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID).equal(id) diff --git a/docker/weaviate.local.yml b/docker/weaviate.local.yml new file mode 100644 index 0000000..9366c96 --- /dev/null +++ b/docker/weaviate.local.yml @@ -0,0 +1,22 @@ +--- +services: + weaviate: + command: + - --host + - 0.0.0.0 + - --port + - '8001' + - --scheme + - http + image: cr.weaviate.io/semitechnologies/weaviate:1.25.3 + expose: + - "8001" + - "50051" + ports: + - "8001:8001" + - "50051:50051" + volumes: + - ${WEAVIATE_VOLUME_MOUNT:-./.docker-data/weaviate-data}:/var/lib/weaviate + restart: on-failure:3 + env_file: + - ./weaviate/default.env \ No newline at end of file diff --git 
diff --git a/docker/weaviate.local.yml b/docker/weaviate.local.yml
new file mode 100644
index 0000000..9366c96
--- /dev/null
+++ b/docker/weaviate.local.yml
@@ -0,0 +1,22 @@
+---
+services:
+  weaviate:
+    command:
+      - --host
+      - 0.0.0.0
+      - --port
+      - '8001'
+      - --scheme
+      - http
+    image: cr.weaviate.io/semitechnologies/weaviate:1.25.3
+    expose:
+      - "8001"
+      - "50051"
+    ports:
+      - "8001:8001"
+      - "50051:50051"
+    volumes:
+      - ${WEAVIATE_VOLUME_MOUNT:-./.docker-data/weaviate-data}:/var/lib/weaviate
+    restart: on-failure:3
+    env_file:
+      - ./weaviate/default.env
\ No newline at end of file
diff --git a/docker/weaviate.yml b/docker/weaviate.yml
index a55993e..9366c96 100644
--- a/docker/weaviate.yml
+++ b/docker/weaviate.yml
@@ -2,12 +2,12 @@
 services:
   weaviate:
     command:
-    - --host
-    - 0.0.0.0
-    - --port
-    - '8001'
-    - --scheme
-    - http
+      - --host
+      - 0.0.0.0
+      - --port
+      - '8001'
+      - --scheme
+      - http
     image: cr.weaviate.io/semitechnologies/weaviate:1.25.3
     expose:
       - "8001"
@@ -19,6 +19,4 @@ services:
       - ${WEAVIATE_VOLUME_MOUNT:-./.docker-data/weaviate-data}:/var/lib/weaviate
     restart: on-failure:3
     env_file:
-      - ./docker/weaviate/default.env
-    networks:
-      - angelos-network
\ No newline at end of file
+      - ./weaviate/default.env
\ No newline at end of file

From 94e2609ed376ccf190b371bfc17ee0fd9cf3e67a Mon Sep 17 00:00:00 2001
From: Nino Richter
Date: Mon, 23 Dec 2024 13:40:00 +0100
Subject: [PATCH 3/7] Bug fix in update function

---
 app/managers/weaviate_manager.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/app/managers/weaviate_manager.py b/app/managers/weaviate_manager.py
index 53b94e0..64b3e18 100644
--- a/app/managers/weaviate_manager.py
+++ b/app/managers/weaviate_manager.py
@@ -456,7 +456,7 @@ def delete_by_kb_id(self, kb_id: str, return_metadata: bool) -> Optional[Databas
                 metadata = DatabaseDocumentMetadata(
                     link=properties[DocumentSchema.LINK.value],
                     study_programs=properties[DocumentSchema.STUDY_PROGRAMS.value],
-                    org_id=properties[DocumentSchema.STUDY_PROGRAMS.value]
+                    org_id=properties[DocumentSchema.ORGANISATION_ID.value]
                 )
                 self.documents.data.delete_many(
                     where=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(kb_id)
@@ -474,7 +474,7 @@ def update_documents(self, kb_id: str, document: DatabaseDocumentMetadata):
         try:
             query_result = self.documents.query.fetch_objects(
-                filters=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(document.id)
+                filters=Filter.by_property(DocumentSchema.KNOWLEDGE_BASE_ID.value).equal(kb_id)
             )
 
             if not query_result.objects:
From 409f583b4ddf9d48e0367cdaaa06145d2b5d0e5a Mon Sep 17 00:00:00 2001
From: Nino Richter
Date: Thu, 26 Dec 2024 18:11:15 +0100
Subject: [PATCH 4/7] integrate org id

---
 app/api/question_router.py       | 27 +++++++++++++++------------
 app/data/user_requests.py        |  1 +
 app/managers/request_handler.py  | 18 ++++++++----------
 app/managers/weaviate_manager.py | 25 +++++++++++++++++--------
 4 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/app/api/question_router.py b/app/api/question_router.py
index 2b8d28d..84e8cf0 100644
--- a/app/api/question_router.py
+++ b/app/api/question_router.py
@@ -1,6 +1,6 @@
 import logging
 
-from fastapi import HTTPException, APIRouter, Depends
+from fastapi import HTTPException, APIRouter, Depends, Query
 
 from app.data.user_requests import UserChat, UserRequest
 from app.utils.dependencies import request_handler, auth_handler
@@ -46,19 +46,22 @@ async def ask(request: UserRequest):
     return {"answer": answer}
 
 
-@question_router.post("/chat", tags=["chatbot"], dependencies=[Depends(auth_handler.verify_token)])
-async def chat(request: UserChat):
+@question_router.post("/chat", tags=["chatbot"], dependencies=[Depends(auth_handler.verify_api_key)])
+async def chat(
+    request: UserChat,
+    filterByOrg: bool = Query(..., description="Indicates whether to filter context by organization")
+):
     messages = request.messages
+    org_id = request.orgId
+
     if not messages:
         raise HTTPException(status_code=400, detail="No messages have been provided")
+
+    answer = request_handler.handle_chat(
+        messages,
+        study_program=request.study_program,
+        org_id=org_id,
+        filter_by_org=filterByOrg
+    )
 
-    last_message = messages[-1].message
-    if len(last_message) > config.MAX_MESSAGE_LENGTH:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Message length exceeds the allowed limit of {config.MAX_MESSAGE_LENGTH} characters"
-        )
-
-    logging.info(f"Received messages.")
-    answer = request_handler.handle_chat(messages, study_program=request.study_program)
     return {"answer": answer}
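With this change, filterByOrg travels as a required query parameter while orgId sits in the request body. A hypothetical client call against the endpoint (URL, API key, ids, and the ChatMessage field shape are placeholder assumptions):

import requests

response = requests.post(
    "http://localhost:8000/chat",
    params={"filterByOrg": "true"},          # required query parameter
    headers={"x-api-key": "<api-key>"},      # checked by auth_handler.verify_api_key
    json={
        "messages": [{"message": "When is the application deadline?"}],
        "study_program": "information systems",
        "orgId": 1,
    },
)
print(response.json()["answer"])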
diff --git a/app/data/user_requests.py b/app/data/user_requests.py
index 27df300..bd9677c 100644
--- a/app/data/user_requests.py
+++ b/app/data/user_requests.py
@@ -8,6 +8,7 @@ class ChatMessage(BaseModel):
 class UserChat(BaseModel):
     messages: List[ChatMessage]
     study_program: Optional[str] = None
+    orgId: int
 
 class SampleQuestion(BaseModel):
     topic: str
diff --git a/app/managers/request_handler.py b/app/managers/request_handler.py
index 26f9b00..a24cb25 100644
--- a/app/managers/request_handler.py
+++ b/app/managers/request_handler.py
@@ -51,7 +51,7 @@ def handle_question_test_mode(self, question: str, classification: str, language
         answer, tokens = self.model.complete_with_tokens(messages)
         return answer, tokens, general_context_list, specific_context_list
 
-    def handle_chat(self, messages: List[ChatMessage], study_program: str):
+    def handle_chat(self, messages: List[ChatMessage], study_program: str, org_id: int, filter_by_org: bool):
         """Handles the question by fetching relevant documents and generating an answer."""
         # The last message is the user's current question
         last_message = messages[-1].message
@@ -73,7 +73,7 @@ def handle_chat(self, messages: List[ChatMessage], study_program: str, org_id: int, filter_by_org: bool):
 
         # Retrieve general context based on the last message
         general_context_last = self.weaviate_manager.get_relevant_context(
-            general_query, "general", lang, limit=context_limit, top_n=context_top_n
+            general_query, "general", lang, limit=context_limit, top_n=context_top_n, org_id=org_id, filter_by_org=filter_by_org
         )
         general_context = general_context_last
@@ -84,7 +84,7 @@ def handle_chat(self, messages: List[ChatMessage], study_program: str, org_id: int, filter_by_org: bool):
 
             # Retrieve general context using the chat history
             general_context_history = self.weaviate_manager.get_relevant_context(
-                chat_query, "general", lang, limit=4, top_n=2
+                chat_query, "general", lang, limit=4, top_n=2, org_id=org_id, filter_by_org=filter_by_org
             )
             # Combine the contexts
             general_context = f"{general_context_last}\n-----\n{general_context_history}"
@@ -94,24 +94,24 @@ def handle_chat(self, messages: List[ChatMessage], study_program: str, org_id: int, filter_by_org: bool):
         if study_program and study_program.lower() != "general":
             # Retrieve specific context based on the last message
             specific_context_last = self.weaviate_manager.get_relevant_context(
-                last_message, study_program, lang, limit=context_limit, top_n=context_top_n
+                last_message, study_program, lang, limit=context_limit, top_n=context_top_n, org_id=org_id, filter_by_org=filter_by_org
             )
             specific_context = specific_context_last
 
             if get_history_context:
                 # Retrieve specific context using the chat history
                 specific_context_history = self.weaviate_manager.get_relevant_context(
-                    chat_query, study_program, lang, limit=4, top_n=2
+                    chat_query, study_program, lang, limit=4, top_n=2, org_id=org_id, filter_by_org=filter_by_org
                 )
                 # Combine the contexts
                 specific_context = f"{specific_context_last}\n-----\n{specific_context_history}"
 
         # Retrieve and format sample questions
         sample_questions = self.weaviate_manager.get_relevant_sample_questions(
-            question=last_message, language=lang
+            question=last_message, language=lang, org_id=org_id
         )
         sample_questions_formatted = self.prompt_manager.format_sample_questions(
-            sample_questions, lang
+            sample_questions=sample_questions, language=lang
         )
 
         # Format chat history (excluding the last message)
@@ -133,6 +133,4 @@ def handle_chat(self, messages: List[ChatMessage], study_program: str, org_id: int, filter_by_org: bool):
 
         # Generate and return the answer
         return self.model.complete(messages_to_model)
-
-    def add_document(self, question: str, classification: str):
-        return self.weaviate_manager.add_document(question, classification)
+
diff --git a/app/managers/weaviate_manager.py b/app/managers/weaviate_manager.py
index 64b3e18..08c15f7 100644
--- a/app/managers/weaviate_manager.py
+++ b/app/managers/weaviate_manager.py
@@ -229,8 +229,8 @@ def initialize_qa_schema(self) -> Collection:
         except Exception as e:
             logging.error(f"Error creating schema for {collection_name}: {e}")
 
-    def get_relevant_context(self, question: str, study_program: str, language: str,
-                             test_mode: bool = False, limit = 10, top_n = 5) -> Union[str, Tuple[str, List[str]]]:
+    def get_relevant_context(self, question: str, study_program: str, language: str, org_id: Optional[int], test_mode: bool = False,
+                             limit = 10, top_n = 5, filter_by_org: bool = False) -> Union[str, Tuple[str, List[str]]]:
         """
         Retrieve relevant documents based on the question embedding and study program.
         Optionally returns both the concatenated context and the sorted context list for testing purposes.
@@ -250,6 +249,16 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
             # Define the number of documents to retrieve
             min_relevance_score = 0.25
 
+            # Define filter
+            if filter_by_org and org_id is not None:
+                filters=Filter.all_of([
+                    Filter.by_property(DocumentSchema.STUDY_PROGRAMS.value).contains_any([study_program]),
+                    Filter.by_property(DocumentSchema.ORGANISATION_ID.value).equal(org_id),
+                ])
+            else:
+                filters=Filter.by_property(DocumentSchema.STUDY_PROGRAMS.value).contains_any([study_program])
+
+            # If getting general context, adjust the parameters
             if study_program.lower() != "general":
                 limit = 10
                 min_relevance_score = 0.15
@@ -263,9 +273,7 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
             # Perform the vector-based query with filters
             query_result = self.documents.query.near_vector(
                 near_vector=question_embedding,
-                filters=Filter.all_of([
-                    Filter.by_property(DocumentSchema.STUDY_PROGRAMS.value).contains_any([study_program])
-                ]),
+                filters=filters,
                 limit=limit,
                 # include_vector=True,
                 return_metadata=wvc.query.MetadataQuery(certainty=True, score=True, distance=True)
@@ -312,7 +320,7 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
             # logging.error("Traceback:\n%s", tb)
             return "" if not test_mode else ("", [])
 
-    def get_relevant_sample_questions(self, question: str, language: str) -> List[SampleQuestion]:
+    def get_relevant_sample_questions(self, question: str, language: str, org_id: int) -> List[SampleQuestion]:
         """
         Retrieve relevant sample questions and answers based on the question embedding.
@@ -335,6 +343,7 @@ def get_relevant_sample_questions(self, question: str, language: str, org_id: in
             query_result = self.qa_collection.query.near_vector(
                 near_vector=question_embedding,
                 limit=limit,
+                filters=Filter.by_property(DocumentSchema.ORGANISATION_ID.value).equal(org_id),
                 return_metadata=wvc.query.MetadataQuery(certainty=True, score=True, distance=True)
             )
 
@@ -597,12 +606,12 @@ def update_sample_question(self, sample_question: DatabaseSampleQuestion):
             # Add to QA collection in Weaviate
             embedding = self.model.embed(question)
 
-            query_result = self.documents.query.fetch_objects(
+            query_result = self.qa_collection.query.fetch_objects(
                 filters=Filter.by_property(QASchema.KNOWLEDGE_BASE_ID.value).equal(sample_question.id)
             )
 
             if not query_result.objects:
-                logging.info(f"No documents found with knowledge_base_id: {sample_question.id}")
+                logging.info(f"No sample question found with knowledge_base_id: {sample_question.id}")
                 return
 
             # Iterate through the results and update the properties

From e4b5fd240cc6e3660632ebd1ecd538653b1f62d8 Mon Sep 17 00:00:00 2001
From: Nino Richter
Date: Fri, 27 Dec 2024 15:20:18 +0100
Subject: [PATCH 5/7] Make final adjustments to make new chat logic work

---
 app/data/user_requests.py        | 2 +-
 app/managers/weaviate_manager.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/app/data/user_requests.py b/app/data/user_requests.py
index bd9677c..b365619 100644
--- a/app/data/user_requests.py
+++ b/app/data/user_requests.py
@@ -14,7 +14,7 @@ class SampleQuestion(BaseModel):
     topic: str
     question: str
     answer: str
-    study_program: str
+    study_program: List[str]
 
 class WebsiteContent(BaseModel):
     type: str
diff --git a/app/managers/weaviate_manager.py b/app/managers/weaviate_manager.py
index 08c15f7..01491bd 100644
--- a/app/managers/weaviate_manager.py
+++ b/app/managers/weaviate_manager.py
@@ -358,6 +358,7 @@ def get_relevant_sample_questions(self, question: str, language: str, org_id: in
             # Rerank the sample questions using the reranker
             context_list = [sq.question for sq in sample_questions]
+            logging.info(f" SAMPLE QUESTION CONTEXT LIST: {context_list}")
             sorted_questions = self.reranker.rerank_with_cohere(
                 context_list=context_list, query=question, language=language, top_n=top_n,
                 min_relevance_score=min_relevance_score
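The branching that patch 4 introduces into get_relevant_context reduces to a small helper; written out on its own it is easier to reason about and test. A sketch under the same assumptions (weaviate-client v4, property names as in the schema):

from typing import Optional

from weaviate.classes.query import Filter

def build_context_filter(study_program: str, org_id: Optional[int], filter_by_org: bool):
    # Always scope to the study program; add the organisation only when requested
    program_filter = Filter.by_property("study_programs").contains_any([study_program])
    if filter_by_org and org_id is not None:
        return Filter.all_of([
            program_filter,
            Filter.by_property("org_id").equal(org_id),
        ])
    return program_filter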
From c02780ab0837fd0e113f9984dd0c2dfdb325da8c Mon Sep 17 00:00:00 2001
From: Nino Richter
Date: Fri, 3 Jan 2025 17:21:25 +0100
Subject: [PATCH 6/7] Clean up, new context retrieval logic for mail

---
 app/api/admin_router.py                   |  30 +----
 app/api/question_router.py                |   7 +-
 app/data/database_requests.py             |   6 +
 app/data/user_requests.py                 |  13 +-
 app/injestion/document_loader.py          | 153 ----------------------
 app/injestion/document_splitter.py        |  77 +----------
 app/injestion/vector_store_initializer.py |  56 --------
 app/managers/request_handler.py           |  13 +-
 app/managers/weaviate_manager.py          |  15 +--
 9 files changed, 30 insertions(+), 340 deletions(-)
 delete mode 100644 app/injestion/document_loader.py
 delete mode 100644 app/injestion/vector_store_initializer.py

diff --git a/app/api/admin_router.py b/app/api/admin_router.py
index 09e815a..3132617 100644
--- a/app/api/admin_router.py
+++ b/app/api/admin_router.py
@@ -1,9 +1,8 @@
 import logging
 
-from fastapi import HTTPException, APIRouter, status, Response, Depends
+from fastapi import APIRouter, Depends
 
-from app.utils.dependencies import request_handler, auth_handler, weaviate_manager, model
-from app.injestion.vector_store_initializer import initialize_vectorstores
+from app.utils.dependencies import auth_handler, model
 from app.utils.environment import config
 from app.data.user_requests import UserChat, UserRequest
 
@@ -11,31 +10,6 @@
 admin_router = APIRouter(prefix="/api/admin", tags=["settings", "admin"],
                          dependencies=[Depends(auth_handler.verify_token)])
 
-# TODO: Remove
-@admin_router.get("/initSchema",
-                  status_code=status.HTTP_202_ACCEPTED, )
-async def initializeDb():
-    initialize_vectorstores(config.KNOWLEDGE_BASE_FOLDER, config.QA_FOLDER, weaviate_manager)
-    return
-
-
-# TODO: Remove
-@admin_router.post("/document")
-async def add_document(request: UserRequest):
-    question = request.message
-    classification = request.study_program
-    if not question or not classification:
-        raise HTTPException(status_code=400, detail="No question or classification provided")
-
-    logging.info(f"Received document: {question} with classification: {classification}")
-    try:
-        request_handler.add_document(question, classification)
-        return Response(status_code=status.HTTP_200_OK)
-
-    except Exception as e:
-        logging.error(f"Failed to add document: {e}")
-        raise HTTPException(status_code=500, detail="Failed to add document")
-
 @admin_router.get("/ping")
 async def ping():
     logging.info(config.GPU_URL)
diff --git a/app/api/question_router.py b/app/api/question_router.py
index 84e8cf0..386f405 100644
--- a/app/api/question_router.py
+++ b/app/api/question_router.py
@@ -14,6 +14,8 @@ async def ask(request: UserRequest):
     question = request.message
     classification = request.study_program.lower()
     language = request.language.lower()
+    org_id = request.org_id
+
     if not question or not classification:
         raise HTTPException(status_code=400, detail="No question or classification provided")
 
@@ -28,7 +30,8 @@ async def ask(request: UserRequest):
     if config.TEST_MODE:
         answer, used_tokens, general_context, specific_context = request_handler.handle_question_test_mode(question,
                                                                                                            classification,
-                                                                                                           language)
+                                                                                                           language,
+                                                                                                           org_id=org_id)
         if language == "german":
             answer += "\n\n**Diese Antwort wurde automatisch generiert.**"
         else:
@@ -37,7 +40,7 @@ async def ask(request: UserRequest):
         return {"answer": answer, "used_tokens": used_tokens, "general_context": general_context,
                 "specific_context": specific_context}
     else:
-        answer = request_handler.handle_question(question, classification, language)
+        answer = request_handler.handle_question(question, classification, language, org_id=org_id)
         if language == "german":
             answer += "\n\n**Diese Antwort wurde automatisch generiert.**"
         else:
diff --git a/app/data/database_requests.py b/app/data/database_requests.py
index d3c506e..0332882 100644
--- a/app/data/database_requests.py
+++ b/app/data/database_requests.py
@@ -20,3 +20,9 @@ class DatabaseSampleQuestion(BaseModel):
     answer: str
     study_programs: List[str]
     org_id: int
+
+class SampleQuestion(BaseModel):
+    topic: str
+    question: str
+    answer: str
+    study_programs: List[str]
diff --git a/app/data/user_requests.py b/app/data/user_requests.py
index b365619..35540c4 100644
--- a/app/data/user_requests.py
+++ b/app/data/user_requests.py
@@ -9,20 +9,9 @@ class UserChat(BaseModel):
     messages: List[ChatMessage]
     study_program: Optional[str] = None
     orgId: int
-
-class SampleQuestion(BaseModel):
-    topic: str
-    question: str
-    answer: str
-    study_program: List[str]
-
-class WebsiteContent(BaseModel):
-    type: str
-    content: str
-    link: str
-    study_program: str
 
 class UserRequest(BaseModel):
+    org_id: int
     message: str
     study_program: str
     language: str
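After this patch, UserRequest carries org_id as a required field, so every caller of the ask endpoint has to supply the organisation scope. A hypothetical request body matching the updated model:

payload = {
    "org_id": 1,  # newly required organisation scope
    "message": "Which courses can I get credited?",
    "study_program": "information systems",
    "language": "english",
}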
diff --git a/app/injestion/document_loader.py b/app/injestion/document_loader.py
deleted file mode 100644
index d086cc0..0000000
--- a/app/injestion/document_loader.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import logging
-import os
-import json
-from typing import List, Dict, Any
-
-from langchain.docstore.document import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.document_loaders import PyPDFLoader
-from app.data.user_requests import SampleQuestion, WebsiteContent
-
-
-# TODO: Remove
-def load_pdf_documents_from_folder(base_folder: str, study_program: str = "general") -> List[Document]:
-    """
-    Traverse the base folder and all its subfolders to find and load PDF files into Document objects.
-    Attach the study program (subfolder name) to the Document metadata.
-
-    Args:
-    - base_folder: Path to the base folder containing PDF files and subfolders.
-    - study_program: Name of the study program (subfolder name).
-
-    Returns:
-    - A list of Document objects parsed from the PDF files.
-    """
-    documents: List[Document] = []
-
-    # Go through each subdirectory and file in the base folder
-    for subdir in os.listdir(base_folder):
-        subfolder_path = os.path.join(base_folder, subdir)
-
-        # If the current path is a directory, recurse into it
-        if os.path.isdir(subfolder_path):
-            # Use the folder name as the study program when recursing into the subfolder
-            new_study_program = subdir
-            documents.extend(load_pdf_documents_from_folder(subfolder_path, new_study_program))
-
-        # If it's a file and ends with .pdf, process it
-        elif subdir.lower().endswith(".pdf"):
-            file_path = os.path.join(base_folder, subdir)
-            try:
-                loader = PyPDFLoader(file_path)
-                loaded_docs = loader.load()
-
-                # Attach the study program name to each document's metadata
-                for doc in loaded_docs:
-                    doc.metadata["study_program"] = study_program
-                    doc.metadata["source"] = file_path
-
-                documents.extend(loaded_docs)
-                logging.info(f"Loaded {len(loaded_docs)} pages from .pdf file: {file_path} under study program: {study_program}")
-            except Exception as e:
-                logging.error(f"Failed to load PDF file {file_path}: {e}")
-
-    return documents
-
-
-# TODO: Remove
-def load_qa_pairs_from_folder(qa_folder: str) -> List[SampleQuestion]:
-    """
-    Reads JSON files from the qa_folder and extracts QA pairs.
-
-    Args:
-    - qa_folder: Path to the folder containing JSON QA files.
-
-    Returns:
-    - List of QA pairs, each represented as an instance of SampleQuestion with 'topic', 'question', and 'answer'.
-    """
-    qa_pairs: List[SampleQuestion] = []
-
-    for file_name in os.listdir(qa_folder):
-        file_path = os.path.join(qa_folder, file_name)
-        if file_name.endswith(".json"):
-            try:
-                with open(file_path, 'r', encoding='utf-8') as file:
-                    data: Dict[str, Any] = json.load(file)
-
-                # Default values and type validation
-                topic = data.get("topic", "Unknown Topic")
-                study_program = data.get("study_program", "general")
-                correspondence = data.get("correspondence", [])
-
-                if not isinstance(correspondence, list):
-                    logging.warning(f"Unexpected format in file: {file_path}")
-                    continue
-
-                question, answer = None, None
-                for entry in correspondence:
-                    if isinstance(entry, dict):
-                        sender = entry.get("sender")
-                        message = entry.get("message", "")
-                        order_key = entry.get("orderKey")
-                        if sender == "STUDENT" and order_key == 0:
-                            question = message
-                        elif sender == "AA" and order_key == 1:
-                            answer = message
-
-                if question and answer:
-                    # Append a new SampleQuestion object to the list
-                    qa_pairs.append(SampleQuestion(
-                        topic=topic,
-                        question=question,
-                        answer=answer,
-                        study_program=study_program
-                    ))
-                else:
-                    logging.error(f"Failed to parse JSON file {file_name}")
-            except (json.JSONDecodeError, FileNotFoundError) as e:
-                logging.error(f"Failed to load file {file_name}: {e}")
-
-    logging.info(f"Loaded {len(qa_pairs)} QA pairs from folder: {qa_folder}")
-    return qa_pairs
-
-
-# TODO: Remove
-def load_website_content_from_folder(base_folder: str) -> List[WebsiteContent]:
-    """
-    Traverse through the base folder and all subfolders to find and load JSON files
-    and convert them into WebsiteContent objects.
-
-    Args:
-    - base_folder: Path to the base folder containing JSON files and subfolders.
-
-    Returns:
-    - A list of WebsiteContent objects parsed from the JSON files.
-    """
-    website_contents: List[WebsiteContent] = []
-
-    # Traverse the base folder and all its subfolders
-    for root, _, files in os.walk(base_folder):
-        for file in files:
-            file_path = os.path.join(root, file)
-            if file.lower().endswith(".json"):
-                try:
-                    with open(file_path, 'r', encoding='utf-8') as f:
-                        data = json.load(f)
-
-                    # Validate if necessary fields are present
-                    if "type" in data and "content" in data and "link" in data and "study_program" in data:
-                        content_object = WebsiteContent(
-                            type=data["type"],
-                            content=data["content"],
-                            link=data["link"],
-                            study_program=data["study_program"]
-                        )
-                        website_contents.append(content_object)
-                        logging.info(f"Loaded WebsiteContent from {file_path}")
-                    else:
-                        logging.warning(f"Missing required fields in JSON file: {file_path}")
-                except (json.JSONDecodeError, FileNotFoundError) as e:
-                    logging.error(f"Failed to load file {file_path}: {e}")
-
-    logging.info(f"Total WebsiteContent objects loaded: {len(website_contents)}")
-    return website_contents
\ No newline at end of file
diff --git a/app/injestion/document_splitter.py b/app/injestion/document_splitter.py
index 3adf8ae..47135ee 100644
--- a/app/injestion/document_splitter.py
+++ b/app/injestion/document_splitter.py
@@ -3,9 +3,7 @@
 import json
 from typing import List, Dict, Any
 
-from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from app.data.user_requests import WebsiteContent
 
 class DocumentSplitter:
     def split_cit_content(self, content: str):
@@ -25,77 +23,4 @@ def split_tum_content(self, content: str, chunk_size: int = 1200, chunk_overlap:
     def split_pdf_document(self, content: str, chunk_size: int = 1200, chunk_overlap: int = 200):
         text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
         chunks = text_splitter.split_text(content)
-        return chunks
-
-
-    # TODO: Remove
-    @staticmethod
-    def split_pdf_documents(pdf_documents: List[Document], chunk_size: int = 1200, chunk_overlap: int = 200) -> List[Document]:
-        """
-        Split PDF Document objects using RecursiveCharacterTextSplitter into smaller chunks.
-        """
-        pdf_chunks = []
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-
-        for doc in pdf_documents:
-            content = doc.page_content
-            chunks = text_splitter.split_text(content)
-
-            # Create Document chunks and preserve metadata
-            for chunk in chunks:
-                pdf_chunks.append(Document(page_content=chunk, metadata=doc.metadata))
-            logging.info(f"Split PDF document into {len(chunks)} chunks.")
-
-        logging.info(f"Total PDF chunks: {len(pdf_chunks)}")
-        return pdf_chunks
-
-    # TODO: Remove
-    @staticmethod
-    def split_tum_documents(tum_documents: List[WebsiteContent], chunk_size: int = 1200, chunk_overlap: int = 200) -> List[WebsiteContent]:
-        """
-        Split TUM WebsiteContent documents using RecursiveCharacterTextSplitter into smaller chunks.
-        """
-        tum_chunks = []
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-
-        for doc in tum_documents:
-            content = doc.content
-            chunks = text_splitter.split_text(content)
-
-            # Create WebsiteContent chunks and preserve metadata
-            for chunk in chunks:
-                tum_chunks.append(WebsiteContent(
-                    type=doc.type,
-                    content=chunk,
-                    link=doc.link,
-                    study_program=doc.study_program
-                ))
-            logging.info(f"Split TUM document into {len(chunks)} chunks.")
-
-        logging.info(f"Total TUM chunks: {len(tum_chunks)}")
-        return tum_chunks
-
-    # TODO: Remove
-    @staticmethod
-    def split_cit_documents(cit_documents: List[WebsiteContent]) -> List[WebsiteContent]:
-        """
-        Split CIT WebsiteContent documents into smaller chunks based on a predefined separator.
-        """
-        cit_chunks = []
-
-        for doc in cit_documents:
-            sections = doc.content.split('----------------------------------------')
-            for section in sections:
-                section = section.strip()
-                if section:
-                    # Create smaller WebsiteContent chunks and preserve metadata
-                    cit_chunks.append(WebsiteContent(
-                        type=doc.type,
-                        content=section,
-                        link=doc.link,
-                        study_program=doc.study_program
-                    ))
-            logging.info(f"Split CIT document into {len(sections)} sections.")
-
-        logging.info(f"Total TUM chunks: {len(cit_chunks)}")
-        return cit_chunks
\ No newline at end of file
+        return chunks
\ No newline at end of file
diff --git a/app/injestion/vector_store_initializer.py b/app/injestion/vector_store_initializer.py
deleted file mode 100644
index a17696b..0000000
--- a/app/injestion/vector_store_initializer.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import logging
-import os
-from typing import List
-from langchain.docstore.document import Document
-
-from app.injestion.document_loader import load_website_content_from_folder, load_pdf_documents_from_folder, load_qa_pairs_from_folder
-from app.injestion.document_splitter import DocumentSplitter
-from app.managers.weaviate_manager import WeaviateManager
-from app.utils.environment import config
-from app.data.user_requests import SampleQuestion, WebsiteContent
-
-
-# TODO: Remove
-def initialize_vectorstores(base_folder: str, qa_folder: str, weaviate_manager: WeaviateManager):
-    """
-    Initializes vector stores by adding documents to Weaviate with their embeddings.
-
-    Args:
-    - base_folder: Path to the base folder containing general and program-specific documents.
-    - weaviate_manager: Instance of WeaviateManager to manage embeddings and document insertion.
-
-    Returns:
-    - None (documents and embeddings are inserted into Weaviate).
- """ - delete_before_init = config.DELETE_BEFORE_INIT.lower() == "true" - - # Delete existing data if the DELETE_BEFORE_INIT is set to true - if delete_before_init: - logging.warning("Deleting existing data before initialization...") - weaviate_manager.delete_collections() - weaviate_manager.initialize_schema() - weaviate_manager.initialize_qa_schema() - else: - logging.info("Skipping data deletion...") - - logging.info("Initializing vector stores...") - - # Process QA pairs - qa_pairs: List[SampleQuestion] = load_qa_pairs_from_folder(qa_folder) - weaviate_manager.add_qa_pairs(qa_pairs) - - # PDF files - pdf_docs: List[Document] = load_pdf_documents_from_folder(base_folder) - pdf_docs_split: List[Document] = DocumentSplitter.split_pdf_documents(pdf_documents=pdf_docs) - weaviate_manager.add_documents(pdf_docs_split) - - # Website content - website_content: List[WebsiteContent] = load_website_content_from_folder(base_folder) - cit_content = [content for content in website_content if content.type == "CIT"] - tum_content = [content for content in website_content if content.type == "TUM"] - cit_chunks = DocumentSplitter.split_cit_documents(cit_content) - tum_chunks = DocumentSplitter.split_tum_documents(tum_content) - split_website_content: List[WebsiteContent] = cit_chunks + tum_chunks - weaviate_manager.add_website_content(split_website_content) - - logging.info("Vector stores initialized and documents saved to Weaviate.") diff --git a/app/managers/request_handler.py b/app/managers/request_handler.py index a24cb25..186ec78 100644 --- a/app/managers/request_handler.py +++ b/app/managers/request_handler.py @@ -14,15 +14,16 @@ def __init__(self, weaviate_manager: WeaviateManager, model: BaseModelClient, pr self.model = model self.prompt_manager = prompt_manager - def handle_question(self, question: str, classification: str, language: str): + def handle_question(self, question: str, classification: str, language: str, org_id: int): """Handles the question by fetching relevant documents and generating an answer.""" general_context = self.weaviate_manager.get_relevant_context(question=question, study_program="general", - language=language) + language=language, org_id=org_id) specific_context = None if classification != "general": specific_context = self.weaviate_manager.get_relevant_context(question=question, study_program=classification, - language=language) + language=language, + org_id=org_id) sample_questions = self.weaviate_manager.get_relevant_sample_questions(question=question, language=language) sample_questions_formatted = self.prompt_manager.format_sample_questions(sample_questions, language) messages = self.prompt_manager.create_messages(general_context, specific_context, sample_questions_formatted, @@ -30,21 +31,23 @@ def handle_question(self, question: str, classification: str, language: str): return self.model.complete(messages) - def handle_question_test_mode(self, question: str, classification: str, language: str): + def handle_question_test_mode(self, question: str, classification: str, language: str, org_id: int): """Handles the question by fetching relevant documents and generating an answer.""" general_context, general_context_list = self.weaviate_manager.get_relevant_context(question=question, study_program="general", language=language, + org_id=org_id, test_mode=True) specific_context = None if classification != "general": specific_context, specific_context_list = self.weaviate_manager.get_relevant_context(question=question, study_program=classification, language=language, + 
org_id=org_id, test_mode=True) else: specific_context_list = [] - sample_questions = self.weaviate_manager.get_relevant_sample_questions(question=question, language=language) + sample_questions = self.weaviate_manager.get_relevant_sample_questions(question=question, language=language, org_id=org_id) sample_questions_formatted = self.prompt_manager.format_sample_questions(sample_questions, language) messages = self.prompt_manager.create_messages(general_context, specific_context, sample_questions_formatted, question, language, classification) diff --git a/app/managers/weaviate_manager.py b/app/managers/weaviate_manager.py index 01491bd..024d47b 100644 --- a/app/managers/weaviate_manager.py +++ b/app/managers/weaviate_manager.py @@ -14,8 +14,7 @@ from app.models.ollama_model import OllamaModel from app.retrieval_strategies.reranker import Reranker from app.utils.environment import config -from app.data.user_requests import SampleQuestion, WebsiteContent -from app.data.database_requests import DatabaseDocument, DatabaseSampleQuestion, DatabaseDocumentMetadata +from app.data.database_requests import DatabaseDocument, DatabaseSampleQuestion, DatabaseDocumentMetadata, SampleQuestion class DocumentSchema(Enum): @@ -230,7 +229,7 @@ def initialize_qa_schema(self) -> Collection: logging.error(f"Error creating schema for {collection_name}: {e}") def get_relevant_context(self, question: str, study_program: str, language: str, org_id: Optional[int], test_mode: bool = False, - limit = 10, top_n = 5, filter_by_org: bool = False) -> Union[str, Tuple[str, List[str]]]: + limit = 10, top_n = 5, filter_by_org: bool = True) -> Union[str, Tuple[str, List[str]]]: """ Retrieve relevant documents based on the question embedding and study program. Optionally returns both the concatenated context and the sorted context list for testing purposes. 
@@ -250,6 +249,9 @@ def get_relevant_context(self, question: str, study_program: str, language: str, # Define the number of documents to retrieve min_relevance_score = 0.25 + # Normalize the study program name + study_program = WeaviateManager.normalize_study_program_name(study_program) + # Define filter if filter_by_org and org_id is not None: filters=Filter.all_of([ @@ -267,9 +269,6 @@ def get_relevant_context(self, question: str, study_program: str, language: str, # Embed the question using the embedding model question_embedding = self.model.embed(question) - # Normalize the study program name and calculate its length - study_program = WeaviateManager.normalize_study_program_name(study_program) - # Perform the vector-based query with filters query_result = self.documents.query.near_vector( near_vector=question_embedding, @@ -353,8 +352,8 @@ def get_relevant_sample_questions(self, question: str, language: str, org_id: in topic = result.properties.get(QASchema.TOPIC.value, "") retrieved_question = result.properties.get(QASchema.QUESTION.value, "") answer = result.properties.get(QASchema.ANSWER.value, "") - study_program = result.properties.get(QASchema.STUDY_PROGRAMS, []) - sample_questions.append(SampleQuestion(topic=topic, question=retrieved_question, answer=answer, study_program=study_program)) + study_programs = result.properties.get(QASchema.STUDY_PROGRAMS, []) + sample_questions.append(SampleQuestion(topic=topic, question=retrieved_question, answer=answer, study_programs=study_programs)) # Rerank the sample questions using the reranker context_list = [sq.question for sq in sample_questions] From 6ce9a09d38272739330cf6b0ec253a47e6bb8269 Mon Sep 17 00:00:00 2001 From: Nino Richter Date: Mon, 6 Jan 2025 10:40:24 +0100 Subject: [PATCH 7/7] Changes for deployment and clean up --- .github/workflows/deploy.yml | 8 +++----- app/api/admin_router.py | 1 - app/api/auth_router.py | 19 ------------------- app/main.py | 2 -- app/managers/auth_handler.py | 28 +--------------------------- app/utils/dependencies.py | 2 +- app/utils/environment.py | 7 ------- docker-compose.yml | 13 +++---------- docker/weaviate.yml | 16 +++++++++------- 9 files changed, 17 insertions(+), 79 deletions(-) delete mode 100644 app/api/auth_router.py diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 323bb4e..dd5b34a 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -106,8 +106,10 @@ jobs: proxy_key: ${{ secrets.DEPLOYMENT_GATEWAY_SSH_KEY }} proxy_port: ${{ vars.DEPLOYMENT_GATEWAY_PORT }} script: | - rm /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod + rm /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod || true touch /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod + + # Add relevant environment variables echo "WEAVIATE_URL=${{ vars.WEAVIATE_URL }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod echo "WEAVIATE_PORT=${{ vars.WEAVIATE_PORT }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod echo "KNOWLEDGE_BASE_FOLDER=${{ vars.KNOWLEDGE_BASE_FOLDER }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod @@ -133,10 +135,6 @@ jobs: echo "COHERE_API_KEY_MULTI=${{ secrets.COHERE_API_KEY_MULTI }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod echo "COHERE_API_KEY_EN=${{ secrets.COHERE_API_KEY_EN }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod echo "ANGELOS_APP_API_KEY=${{ secrets.ANGELOS_APP_API_KEY }}" >> /home/${{ 
-            echo "API_ENDPOINT_KEY=${{ secrets.API_ENDPOINT_KEY }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod
-            echo "EXPECTED_USERNAME=${{ secrets.EXPECTED_USERNAME }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod
-            echo "EXPECTED_PASSWORD=${{ secrets.EXPECTED_PASSWORD }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod
-            echo "WITHOUT_USER_LOGIN=${{ vars.WITHOUT_USER_LOGIN }}" >> /home/${{ vars.VM_USERNAME }}/${{ github.repository }}/.env.prod
 
     - name: SSH to VM and Execute Docker-Compose Up
diff --git a/app/api/admin_router.py b/app/api/admin_router.py
index 3132617..d58367a 100644
--- a/app/api/admin_router.py
+++ b/app/api/admin_router.py
@@ -4,7 +4,6 @@
 
 from app.utils.dependencies import auth_handler, model
 from app.utils.environment import config
-from app.data.user_requests import UserChat, UserRequest
 
 
 admin_router = APIRouter(prefix="/api/admin", tags=["settings", "admin"],
diff --git a/app/api/auth_router.py b/app/api/auth_router.py
deleted file mode 100644
index 3e9268a..0000000
--- a/app/api/auth_router.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import logging
-
-from fastapi import HTTPException, APIRouter, status, Response, Depends
-
-from app.managers.auth_handler import LoginRequest
-from app.utils.dependencies import auth_handler
-from app.utils.environment import config
-
-auth_router = APIRouter(prefix="/api", tags=["authorization"], dependencies=[Depends(auth_handler.verify_api_key)])
-
-@auth_router.post("/token")
-async def login(login_request: LoginRequest):
-    if config.WITHOUT_USER_LOGIN == "true" or (
-            login_request.username == config.EXPECTED_USERNAME and login_request.password == config.EXPECTED_PASSWORD):
-        token_data = {"sub": "angular_app"}
-        access_token = auth_handler.create_access_token(data=token_data)
-        return {"access_token": access_token, "token_type": "bearer"}
-    else:
-        raise HTTPException(status_code=401, detail="Invalid username or password")
\ No newline at end of file
diff --git a/app/main.py b/app/main.py
index 7024f85..2f6026a 100644
--- a/app/main.py
+++ b/app/main.py
@@ -7,7 +7,6 @@
 
 from app.api.question_router import question_router
 from app.api.admin_router import admin_router
-from app.api.auth_router import auth_router
 from app.api.knowledge_router import knowledge_router
 
 from contextlib import asynccontextmanager
@@ -51,7 +50,6 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
 
 app.include_router(router=question_router)
 app.include_router(router=admin_router)
-app.include_router(router=auth_router)
 app.include_router(router=knowledge_router)
 
 if __name__ == "__main__":
- """ - to_encode = data.copy() - expire = datetime.now(timezone.utc) + timedelta(minutes=self.access_token_expires_minutes) - to_encode.update({"exp": expire}) - return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm) async def verify_api_key(self, x_api_key: str = Header(None)): if x_api_key != self.angelos_api_key: logging.error("Unauthorized access") raise HTTPException(status_code=403, detail="Unauthorized access") - - async def verify_token(self, authorization: str = Header(...)): - """ - Dependency to validate the JWT token in the Authorization header. - """ - try: - token = authorization.split(" ")[1] - payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) - except jwt.ExpiredSignatureError: - logging.error("Token has expired") - raise HTTPException(status_code=403, detail="Token has expired") - except jwt.PyJWTError: - logging.error("Invalid token") - raise HTTPException(status_code=403, detail="Invalid token") diff --git a/app/utils/dependencies.py b/app/utils/dependencies.py index 3ef033a..a57aedc 100644 --- a/app/utils/dependencies.py +++ b/app/utils/dependencies.py @@ -15,7 +15,7 @@ prompt_manager = PromptManager() document_splitter = DocumentSplitter() request_handler = RequestHandler(weaviate_manager=weaviate_manager, model=model, prompt_manager=prompt_manager) -auth_handler = AuthHandler(angelos_api_key=config.ANGELOS_APP_API_KEY, secret_key=config.API_ENDPOINT_KEY, algorithm=config.ALGORITHM, access_token_expires_minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES) +auth_handler = AuthHandler(angelos_api_key=config.ANGELOS_APP_API_KEY) injestion_handler = InjestionHandler(weaviate_manager=weaviate_manager, document_splitter=document_splitter) # Provide a shutdown mechanism for the model diff --git a/app/utils/environment.py b/app/utils/environment.py index 7250e4e..0ff12ba 100644 --- a/app/utils/environment.py +++ b/app/utils/environment.py @@ -40,14 +40,7 @@ class Config: COHERE_API_KEY_MULTI = os.getenv("COHERE_API_KEY_MULTI") COHERE_API_KEY_EN = os.getenv("COHERE_API_KEY_EN") # Safeguard - API_ENDPOINT_KEY = os.getenv("API_ENDPOINT_KEY") ANGELOS_APP_API_KEY = os.getenv("ANGELOS_APP_API_KEY") - ALGORITHM = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES = 60 - MAX_MESSAGE_LENGTH = 3000 - EXPECTED_USERNAME = os.getenv("EXPECTED_USERNAME") - EXPECTED_PASSWORD = os.getenv("EXPECTED_PASSWORD") - WITHOUT_USER_LOGIN = os.getenv("WITHOUT_USER_LOGIN") config = Config() diff --git a/docker-compose.yml b/docker-compose.yml index 8f0fd0a..50a0568 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,7 @@ services: volumes: - ./knowledge:/app/knowledge # env_file: - # - development.env + # - .env.prod command: [ "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000" ] environment: # Weaviate Database @@ -17,12 +17,9 @@ services: - KNOWLEDGE_BASE_FOLDER - QA_FOLDER # Development config - - TEST_MODE - DELETE_BEFORE_INIT # Ollama - USE_OLLAMA - - API_KEY - - URL - GPU_URL - GPU_USER - GPU_PASSWORD @@ -34,7 +31,7 @@ services: - OPENAI_MODEL - OPENAI_EMBEDDING_MODEL # Azure OpenAI - - USE_AZURE=true + - USE_AZURE - AZURE_OPENAI_API_KEY - AZURE_OPENAI_DEPLOYMENT - AZURE_OPENAI_EMBEDDING_DEPLOYMENT @@ -44,12 +41,8 @@ services: - COHERE_API_KEY - COHERE_API_KEY_MULTI - COHERE_API_KEY_EN - # some auth - - API_ENDPOINT_KEY + # Authentication - ANGELOS_APP_API_KEY - - EXPECTED_PASSWORD - - EXPECTED_USERNAME - - WITHOUT_USER_LOGIN networks: - angelos-network diff --git a/docker/weaviate.yml b/docker/weaviate.yml index 9366c96..a55993e 100644 --- 
diff --git a/docker/weaviate.yml b/docker/weaviate.yml
index 9366c96..a55993e 100644
--- a/docker/weaviate.yml
+++ b/docker/weaviate.yml
@@ -2,12 +2,12 @@
 services:
   weaviate:
     command:
-      - --host
-      - 0.0.0.0
-      - --port
-      - '8001'
-      - --scheme
-      - http
+    - --host
+    - 0.0.0.0
+    - --port
+    - '8001'
+    - --scheme
+    - http
     image: cr.weaviate.io/semitechnologies/weaviate:1.25.3
     expose:
       - "8001"
@@ -19,4 +19,6 @@ services:
       - ${WEAVIATE_VOLUME_MOUNT:-./.docker-data/weaviate-data}:/var/lib/weaviate
     restart: on-failure:3
     env_file:
-      - ./weaviate/default.env
\ No newline at end of file
+      - ./docker/weaviate/default.env
+    networks:
+      - angelos-network
\ No newline at end of file
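Taken together, patch 7 leaves a single authentication mechanism: a static API key checked by AuthHandler.verify_api_key, on which every protected router depends. A minimal sketch of the surviving flow (key value and route are placeholders):

from fastapi import Depends, FastAPI, Header, HTTPException

app = FastAPI()
ANGELOS_APP_API_KEY = "change-me"  # loaded from the environment in the real app

async def verify_api_key(x_api_key: str = Header(None)):
    # Reject any request whose x-api-key header does not match the configured key
    if x_api_key != ANGELOS_APP_API_KEY:
        raise HTTPException(status_code=403, detail="Unauthorized access")

@app.get("/api/admin/ping", dependencies=[Depends(verify_api_key)])
async def ping():
    return {"answer": "Server running."}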