-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrag.py
49 lines (36 loc) · 1.5 KB
/
rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import logging

import dotenv
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

from store import Store
from model import Model
# Populate os.environ from a .env file at import time, i.e. before RAG() builds
# the store and model (which presumably read credentials/config from the
# environment -- TODO confirm against store.py / model.py).
dotenv.load_dotenv()
class RAG:
    """Retrieval-augmented question answering over a project ``Store``.

    ``invoke`` sends the question to the store's retriever, joins the
    retrieved documents into the prompt's ``{context}`` slot, fills in
    ``{question}``, and runs the prompt through the LLM. The model output is
    parsed to a plain string.
    """

    # Prompt sent to the LLM. Compared with the previous inline template, this
    # drops a stray '"' that was being sent after "Answer:". It also restores
    # the newline between the instructions and "Question:", which a trailing
    # backslash inside the triple-quoted string had been swallowing.
    PROMPT_TEMPLATE = (
        "You are an assistant for question-answering tasks. Use the following "
        "pieces of retrieved context to answer the question. If you don't know "
        "the answer, just say that you don't know. Use three sentences maximum "
        "and keep the answer concise.\n"
        "Question: {question}\n"
        "Context: {context}\n"
        "Answer:"
    )

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        embedding="bge",
        *,
        store_path="data/chroma",
        retriever_mode="parent",
        model="llama",
    ):
        """Build the store, retriever, prompt and LLM chain.

        Args:
            embedding: Embedding name forwarded to ``Store(embedding=...)``.
                Previously this argument was ignored and ``"bge"`` was always
                used, so the default is ``"bge"`` to keep ``RAG()`` unchanged.
            store_path: First positional argument to ``Store`` (presumably the
                Chroma persistence directory -- confirm in store.py).
            retriever_mode: Forwarded to ``Store(retriever_mode=...)``.
            model: Forwarded to ``Model(model=...)``.
        """
        self.store = Store(
            store_path, retriever_mode=retriever_mode, embedding=embedding
        )
        self.retriever = self.store.get_retriever()
        self.model = Model(model=model)
        self.prompt = ChatPromptTemplate.from_template(self.PROMPT_TEMPLATE)
        self.llm = self.model.get_llm()
        # The former trailing RunnablePassthrough() only echoed its input, so
        # it has been dropped; the chain output is the parsed string.
        self.llm_chain = self.llm | StrOutputParser()

    def _format_docs(self, docs):
        """Join the ``page_content`` of each document with blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)

    def invoke(self, str):
        """Answer a question using retrieved context.

        Args:
            str: The question text. The name shadows the builtin but is kept
                so keyword callers (``invoke(str=...)``) keep working.

        Returns:
            Whatever ``self.llm_chain.invoke`` returns for the filled prompt
            (a string when built by ``__init__``).
        """
        question = str
        docs = self.retriever.invoke(question)
        # Was a bare print(); log instead so callers' stdout is not polluted.
        self._logger.debug("Retrieved documents: %s", docs)
        prompt_value = self.prompt.invoke(
            {"context": self._format_docs(docs), "question": question}
        )
        return self.llm_chain.invoke(prompt_value)
if __name__ == "__main__":
    # Smoke run: build the full pipeline and answer a single sample question.
    pipeline = RAG()
    answer = pipeline.invoke("What is Task Decomposition?")
    print(answer)