Skip to content

Commit

Permalink
feat: implement basic rag functionality
Browse files Browse the repository at this point in the history
Co-authored-by: Sverre Nystad SverreNystad@users.noreply.github.com
  • Loading branch information
JonBergland committed Jan 9, 2025
1 parent 1eba055 commit 3073306
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
45 changes: 45 additions & 0 deletions rules-engine/src/knowledge_base/agent/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from knowledge_base.models.models import OpenAIModels

class RAG:
    """Minimal retrieval-augmented generation helper.

    Embeds a corpus of text documents once, retrieves the documents most
    similar to a query by cosine similarity, and asks an OpenAI chat model
    to answer the query using only those documents.
    """

    def __init__(self, model: "OpenAIModels"):
        """Create the chat model and embedding client.

        Args:
            model: OpenAI chat model identifier (project enum).
        """
        self.llm = ChatOpenAI(model=model)
        self.embeddings = OpenAIEmbeddings()
        # Both are populated by load_documents(); None until then.
        self.doc_embeddings: list[list[float]] | None = None
        self.docs: list[str] | None = None

    def load_documents(self, documents: list[str]) -> None:
        """Load documents and compute their embeddings."""
        self.docs = documents
        self.doc_embeddings = self.embeddings.embed_documents(documents)

    def get_most_relevant_docs(self, query: str, k: int = 5, threshold: float = 0.8) -> list[str]:
        """Return up to ``k`` documents most similar to ``query``.

        Documents are ranked by cosine similarity between the query
        embedding and each document embedding. Only documents whose
        similarity is at least ``threshold`` are kept; if none pass the
        threshold, the single best match is returned so the result is
        never empty.

        Args:
            query: Natural-language question to match against the corpus.
            k: Maximum number of documents to return.
            threshold: Minimum cosine similarity for inclusion.

        Returns:
            The matching documents, best match first.

        Raises:
            ValueError: If load_documents() has not been called.
        """
        if not self.docs or not self.doc_embeddings:
            raise ValueError("Documents and their embeddings are not loaded.")

        query_embedding = np.asarray(self.embeddings.embed_query(query))
        doc_matrix = np.asarray(self.doc_embeddings)

        # Cosine similarity, vectorized over all documents at once.
        similarities = (doc_matrix @ query_embedding) / (
            np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_embedding)
        )

        # Rank best-first, keep at most k documents above the threshold.
        ranked = np.argsort(similarities)[::-1]
        top = [int(i) for i in ranked[:k] if similarities[i] >= threshold]
        if not top:
            # Fall back to the single best match so callers always get a doc.
            top = [int(ranked[0])]
        return [self.docs[i] for i in top]

    def generate_answer(self, query: str, relevant_doc: list[str]) -> str:
        """Generate an answer for ``query`` grounded only in ``relevant_doc``."""
        prompt = f"question: {query}\n\nDocuments: {relevant_doc}"
        messages = [
            ("system", "You are a helpful assistant that answers questions based on given documents only."),
            ("human", prompt),
        ]
        ai_msg = self.llm.invoke(messages)
        return ai_msg.content
44 changes: 44 additions & 0 deletions rules-engine/src/knowledge_base/agent/rag_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from knowledge_base.agent.agent import Agent
from knowledge_base.agent.rag import RAG

from langchain_core.prompts import PromptTemplate
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
Expand Down Expand Up @@ -74,3 +75,46 @@ def format_docs(docs):
rules = result
return rules


def rag_setup():
    """Build a RAG pipeline over the SRD rules and answer a sample query.

    Reads the rules markdown file, splits it into overlapping chunks,
    embeds the chunks, retrieves those most relevant to a hard-coded
    question, generates an answer with the chat model, and prints the
    query, the retrieved chunks, and the answer.
    """
    from knowledge_base.models.models import OpenAIModels

    # Path is relative to the rules-engine/src working directory.
    filepath = "knowledge_base/rulesystems/cc-srd5.md"
    with open(filepath, encoding="utf-8") as f:
        rules_document = f.read()

    # Chunk the document; the overlap keeps context across chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_text(rules_document)

    rag = RAG(OpenAIModels.gpt_4o_mini)

    # Load documents
    rag.load_documents(splits)

    # Query and retrieve the most relevant document
    query = "How many hitpoints does an Barbarian have at level 1? "
    relevant_doc = rag.get_most_relevant_docs(query)

    # Generate an answer
    answer = rag.generate_answer(query, relevant_doc)

    print(f"Query: {query}")
    print(f"Relevant Document: {relevant_doc}")
    print(f"Answer: {answer}")




0 comments on commit 3073306

Please sign in to comment.