Skip to content

Commit

Permalink
Merge pull request #17 from ls1intum/handle-dates-and-chat-history
Browse files Browse the repository at this point in the history
Handle dates and chat history
  • Loading branch information
ninori9 authored Nov 5, 2024
2 parents 74c932a + f6d35a8 commit 243524e
Show file tree
Hide file tree
Showing 12 changed files with 248 additions and 456 deletions.
56 changes: 6 additions & 50 deletions app/api/question_router.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,19 @@
import logging
from datetime import datetime, timezone, timedelta

import jwt
from fastapi import HTTPException, APIRouter, status, Response, Header, Depends
from pydantic import BaseModel
from fastapi import HTTPException, APIRouter, status, Response, Depends

from app.data.user_requests import UserChat
from app.data.user_requests import UserChat, UserRequest
from app.injestion.vector_store_initializer import initialize_vectorstores
from app.utils.dependencies import request_handler, weaviate_manager, model
from app.utils.dependencies import request_handler, auth_handler, weaviate_manager, model
from app.utils.environment import config


class UserRequest(BaseModel):
message: str
study_program: str
language: str


SECRET_KEY = config.API_ENDPOINT_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60


def create_access_token(data: dict):
"""
Generates a JWT token with an expiration time.
"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


async def verify_api_key(x_api_key: str = Header(None)):
if x_api_key != config.ANGELOS_APP_API_KEY:
raise HTTPException(status_code=403, detail="Unauthorized access")


async def verify_token(authorization: str = Header(...)):
"""
Dependency to validate the JWT token in the Authorization header.
"""
try:
token = authorization.split(" ")[1]
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=403, detail="Token has expired")
except jwt.PyJWTError:
raise HTTPException(status_code=403, detail="Invalid token")


router = APIRouter(prefix="/api/v1/question", tags=["response"], dependencies=[Depends(verify_token)])
auth = APIRouter(prefix="/api", tags=["response"], dependencies=[Depends(verify_api_key)])

router = APIRouter(prefix="/api/v1/question", tags=["response"], dependencies=[Depends(auth_handler.verify_token)])
auth = APIRouter(prefix="/api", tags=["response"], dependencies=[Depends(auth_handler.verify_api_key)])

@auth.post("/token")
async def login():
token_data = {"sub": "angular_app"}
access_token = create_access_token(data=token_data)
access_token = auth_handler.create_access_token(data=token_data)
return {"access_token": access_token, "token_type": "bearer"}


Expand Down
5 changes: 5 additions & 0 deletions app/data/user_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ class WebsiteContent(BaseModel):
content: str
link: str
study_program: str

class UserRequest(BaseModel):
message: str
study_program: str
language: str
41 changes: 41 additions & 0 deletions app/managers/auth_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from datetime import datetime, timezone, timedelta
import jwt
import logging
from fastapi import HTTPException, Header

class AuthHandler:
def __init__(self, angelos_api_key: str, secret_key: str, algorithm: str, access_token_expires_minutes: float):
self.angelos_api_key = angelos_api_key
self.secret_key = secret_key
self.algorithm = algorithm
self.access_token_expires_minutes = access_token_expires_minutes

def create_access_token(self, data: dict):
"""
Generates a JWT token with an expiration time.
"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(minutes=self.access_token_expires_minutes)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)


async def verify_api_key(self, x_api_key: str = Header(None)):
if x_api_key != self.angelos_api_key:
logging.error("Unauthorized access")
raise HTTPException(status_code=403, detail="Unauthorized access")


async def verify_token(self, authorization: str = Header(...)):
"""
Dependency to validate the JWT token in the Authorization header.
"""
try:
token = authorization.split(" ")[1]
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
except jwt.ExpiredSignatureError:
logging.error("Token has expired")
raise HTTPException(status_code=403, detail="Token has expired")
except jwt.PyJWTError:
logging.error("Invalid token")
raise HTTPException(status_code=403, detail="Invalid token")
64 changes: 58 additions & 6 deletions app/managers/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List
import re
import logging

from app.data.user_requests import ChatMessage
from app.managers.weaviate_manager import WeaviateManager
Expand Down Expand Up @@ -54,21 +56,71 @@ def handle_chat(self, messages: List[ChatMessage], study_program: str):
"""Handles the question by fetching relevant documents and generating an answer."""
# The last message is the user's current question
last_message = messages[-1].message

# Determine language
lang = LanguageDetector.get_language(last_message)
# Get context
general_context = self.weaviate_manager.get_relevant_context(last_message, "general", lang)

# Decide whether to retrieve context based on history
if len(messages) <= 2:
get_history_context = False
general_query = f"{re.sub(r'-', ' ', study_program).title()}: {last_message}"
context_limit = 10
context_top_n = 5
else:
get_history_context = True
general_query = last_message
context_limit = 8
context_top_n = 4

# Retrieve general context based on the last message
general_context_last = self.weaviate_manager.get_relevant_context(
general_query, "general", lang, limit=context_limit, top_n=context_top_n
)
general_context = general_context_last

# If applicable, retrieve additional context based on history
if get_history_context:
# Build a query from the chat history
chat_query = self.prompt_manager.build_chat_query(messages, study_program, num_messages=3)

# Retrieve general context using the chat history
general_context_history = self.weaviate_manager.get_relevant_context(
chat_query, "general", lang, limit=4, top_n=2
)
# Combine the contexts
general_context = f"{general_context_last}\n-----\n{general_context_history}"

# Retrieve specific context if a study program is specified
specific_context = None
if study_program and study_program.lower() != "general":
specific_context = self.weaviate_manager.get_relevant_context(last_message, study_program, lang)
# Retrieve specific context based on the last message
specific_context_last = self.weaviate_manager.get_relevant_context(
last_message, study_program, lang, limit=context_limit, top_n=context_top_n
)
specific_context = specific_context_last

if get_history_context:
# Retrieve specific context using the chat history
specific_context_history = self.weaviate_manager.get_relevant_context(
chat_query, study_program, lang, limit=4, top_n=2
)
# Combine the contexts
specific_context = f"{specific_context_last}\n-----\n{specific_context_history}"

# Retrieve and format sample questions
sample_questions = self.weaviate_manager.get_relevant_sample_questions(question=last_message, language=lang)
sample_questions_formatted = self.prompt_manager.format_sample_questions(sample_questions, lang)
sample_questions = self.weaviate_manager.get_relevant_sample_questions(
question=last_message, language=lang
)
sample_questions_formatted = self.prompt_manager.format_sample_questions(
sample_questions, lang
)

# Format chat history (excluding the last message)
if len(messages) > 1:
history_formatted = self.prompt_manager.format_chat_history(messages[:-1], lang)
history_formatted = self.prompt_manager.format_chat_history(messages, lang)
else:
history_formatted = None # No history to format

# Create messages for the model
messages_to_model = self.prompt_manager.create_messages_with_history(
general_context=general_context,
Expand Down
6 changes: 2 additions & 4 deletions app/managers/weaviate_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def initialize_qa_schema(self) -> Collection:
logging.error(f"Error creating schema for {collection_name}: {e}")

def get_relevant_context(self, question: str, study_program: str, language: str,
test_mode: bool = False) -> Union[str, Tuple[str, List[str]]]:
test_mode: bool = False, limit = 10, top_n = 5) -> Union[str, Tuple[str, List[str]]]:
"""
Retrieve relevant documents based on the question embedding and study program.
Optionally returns both the concatenated context and the sorted context list for testing purposes.
Expand All @@ -209,7 +209,6 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
"""
try:
# Define the number of documents to retrieve
limit = 10
min_relevance_score = 0.25
if study_program.lower() != "general":
limit = 10
Expand All @@ -235,7 +234,6 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
# include_vector=True,
return_metadata=wvc.query.MetadataQuery(certainty=True, score=True, distance=True)
)
logging.info(f"No error yet after {study_program} getting relevant context")
# documents_with_embeddings: List[DocumentWithEmbedding] = []
# for result in query_result.objects:
# logging.info(
Expand All @@ -260,7 +258,7 @@ def get_relevant_context(self, question: str, study_program: str, language: str,
# Rerank the unique contexts using Cohere
sorted_context = self.reranker.rerank_with_cohere(context_list=content_content_list, query=question,
language=language,
min_relevance_score=min_relevance_score, top_n=5)
min_relevance_score=min_relevance_score, top_n=top_n)
# Integrate links
sorted_context_with_links = []
for sorted_content in sorted_context:
Expand Down
31 changes: 27 additions & 4 deletions app/prompt/prompt_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple, Union
import logging
import re

Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(self):
- Re-read the question carefully.
- Analyze the provided general information and, if available, study program-specific context.
- You are part of an ongoing conversation. The 'History' section contains previous exchanges with the student, which you should refer to in order to maintain continuity and avoid repeating information.
- Use the 'History' to understand the flow of the conversation and ensure your answer fits within the context of the ongoing dialogue.
- Use the 'History' to understand the question and the flow of the conversation and ensure your answer fits within the context of the ongoing dialogue.
- If a provided similar question from a student is thematically very similar to the question asked, rely heavily on the respective sample answer from academic advising.
- Else, prioritize study program-specific context over general information.
- If no specific context is provided, base your answer solely on the general context.
Expand Down Expand Up @@ -176,7 +176,7 @@ def __init__(self):
- Lesen Sie die Frage sorgfältig durch.
- Analysieren Sie die bereitgestellten allgemeinen Informationen und, falls vorhanden, die studiengangspezifischen Informationen. Analysiere zudem, falls vorhanden die bereitgestellten ähnlichen Fragen und Antworten basierend auf früheren Anfragen.
- Du bist Teil eines laufenden Gesprächs. Der Abschnitt 'Verlauf' enthält frühere Nachrichten der Unterhaltung zwischen zwischen Ihnen und dem Studenten, auf die du dich beziehen solltest, um die Kontinuität aufrechtzuerhalten und Wiederholungen zu vermeiden.
- Nutze den 'Verlauf', um den Gesprächsfluss zu verstehen und sicherzustellen, dass deine Antwort in den Kontext des laufenden Dialogs passt.
- Nutze den 'Verlauf', um den Gesprächsfluss und Frage zu verstehen und sicherzustellen, dass deine Antwort in den Kontext des laufenden Dialogs passt.
- Wenn eine ähnliche Frage eines Studenten thematisch sehr ähnlich zur gestellten Frage ist, stützen Sie sich stark auf die jeweilige Beispielsantwort der Studienberatung.
- Sonst priorisieren Sie studiengangspezifische Informationen über allgemeine Informationen.
- Stellen Sie keine Vermutungen an, bieten Sie keine Interpretationen an und schaffen Sie keine neuen Informationen. Antworten Sie nur auf der Grundlage der bereitgestellten Informationen.
Expand Down Expand Up @@ -346,4 +346,27 @@ def format_study_program(self, study_program: str, language: str) -> str:
if language.lower() == "english":
return f"The study program of the student is {formatted_program}"
else:
return f"Der Studiengang des Studenten ist {formatted_program}"
return f"Der Studiengang des Studenten ist {formatted_program}"

def build_chat_query(self, messages: List[ChatMessage], study_program: str, num_messages: int = 3) -> str:
"""
Builds a query string from the last num_messages user messages.
Args:
messages (List[ChatMessage]): The list of chat messages.
num_messages (int): The number of recent user messages to include.
Returns:
str: The concatenated query string.
"""
# Extract messages of type 'user'
user_messages = [msg.message for msg in messages if msg.type == 'user']
# Take the last num_messages
recent_user_messages = user_messages[-num_messages:]
# Concatenate them into one query string
query = " ".join(recent_user_messages)
# Integrate study program
formatted_program = re.sub(r'-', ' ', study_program).title()
query_with_program = f"{formatted_program}: {query}"
return query_with_program

3 changes: 0 additions & 3 deletions app/retrieval_strategies/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def rerank_with_embeddings(self, context_list: List[DocumentWithEmbedding], keyw
# Rank the documents based on cosine similarity (in descending order)
ranked_indices = np.argsort(-np.array(cosine_similarities))

logging.info(f"Ranked indices: {ranked_indices}")

# Extract the content from the ranked DocumentWithEmbedding objects
ranked_context_list = [context_list[i].content for i in ranked_indices]

Expand Down Expand Up @@ -110,7 +108,6 @@ def rerank_with_cohere(self, context_list: List[str], query: str, language: str,
response_json = response.json()

# Log the full response from the API for debugging
logging.info(f"Cohere full response: {response_json}")
results = response_json.get('results', [])

# Log the ranked documents that are in the top_n
Expand Down
3 changes: 2 additions & 1 deletion app/utils/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from app.managers.request_handler import RequestHandler
from app.managers.weaviate_manager import WeaviateManager
from app.managers.auth_handler import AuthHandler
from app.retrieval_strategies.reranker import Reranker
from app.models.model_loader import get_model
from app.prompt.prompt_manager import PromptManager
Expand All @@ -11,7 +12,7 @@
weaviate_manager = WeaviateManager(config.WEAVIATE_URL, embedding_model=model, reranker=reranker)
prompt_manager = PromptManager()
request_handler = RequestHandler(weaviate_manager=weaviate_manager, model=model, prompt_manager=prompt_manager)

auth_handler = AuthHandler(angelos_api_key=config.ANGELOS_APP_API_KEY, secret_key=config.API_ENDPOINT_KEY, algorithm=config.ALGORITHM, access_token_expires_minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)

# Provide a shutdown mechanism for the model
def shutdown_model():
Expand Down
7 changes: 4 additions & 3 deletions app/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class Config:
MAX_MESSAGE_LENGTH = 3000
# Weaviate Database
WEAVIATE_URL = os.getenv("WEAVIATE_URL", "localhost")
WEAVIATE_PORT = os.getenv("WEAVIATE_PORT", "8001")
Expand Down Expand Up @@ -40,10 +39,12 @@ class Config:
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
COHERE_API_KEY_MULTI = os.getenv("COHERE_API_KEY_MULTI")
COHERE_API_KEY_EN = os.getenv("COHERE_API_KEY_EN")

# safeguard
# Safeguard
API_ENDPOINT_KEY = os.getenv("API_ENDPOINT_KEY")
ANGELOS_APP_API_KEY = os.getenv("ANGELOS_APP_API_KEY")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MAX_MESSAGE_LENGTH = 3000


config = Config()
29 changes: 29 additions & 0 deletions requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#-------------app dependencies-----------------------
fastapi==0.112.4
python-dotenv==1.0.1
requests==2.32.3
openai==1.44.1
uvicorn==0.30.6
starlette==0.38.6
langchain==0.3.3
numpy==1.26.4
scikit-learn==1.5.2
cohere==5.11.1
langchain-community==0.3.2
langchain-openai==0.2.2
langchain-core==0.3.10
weaviate-client==4.8.1
pydantic==2.9.2
pyjwt==2.9.0
langdetect==1.0.9

#-------------testing dependencies-----------------------
pytest==8.3.3
pytest-asyncio==0.24.0
pandas==2.2.3
datetime==5.5
tiktoken==0.8.0
bs4==0.0.2
pypdf==5.0.1
cryptography==43.0.1
deepeval==1.3.5
Loading

0 comments on commit 243524e

Please sign in to comment.