This repository has been archived by the owner on Jul 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathrag.py
72 lines (60 loc) · 2.45 KB
/
rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import faiss
import time
import numpy as np
from tqdm import tqdm
from lamini.api.embedding import Embedding
from lamini import Lamini
from directory_helper import DirectoryLoader
class LaminiIndex:
def __init__(self, loader):
self.loader = loader
self.build_index()
def build_index(self):
self.content_chunks = []
self.index = None
for chunk_batch in tqdm(self.loader):
embeddings = self.get_embeddings(chunk_batch)
if self.index is None:
self.index = faiss.IndexFlatL2(len(embeddings[0]))
self.index.add(embeddings)
self.content_chunks.extend(chunk_batch)
def get_embeddings(self, examples):
ebd = Embedding()
embeddings = ebd.generate(examples)
embedding_list = [embedding[0] for embedding in embeddings]
return np.array(embedding_list)
def query(self, query, k=5):
embedding = self.get_embeddings([query])[0]
embedding_array = np.array([embedding])
_, indices = self.index.search(embedding_array, k)
return [self.content_chunks[i] for i in indices[0]]
class QueryEngine:
def __init__(self, index, k=5):
self.index = index
self.k = k
self.model = Lamini(model_name="mistralai/Mistral-7B-Instruct-v0.1")
def answer_question(self, question):
most_similar = self.index.query(question, k=self.k)
prompt = "\n".join(reversed(most_similar)) + "\n\n" + question
print("------------------------------ Prompt ------------------------------\n" + prompt + "\n----------------------------- End Prompt -----------------------------")
return self.model.generate("<s>[INST]" + prompt + "[/INST]")
class RetrievalAugmentedRunner:
def __init__(self, dir, k=5):
self.k = k
self.loader = DirectoryLoader(dir)
def train(self):
self.index = LaminiIndex(self.loader)
def __call__(self, query):
query_engine = QueryEngine(self.index, k=self.k)
return query_engine.answer_question(query)
def main():
model = RetrievalAugmentedRunner(dir="data")
start = time.time()
model.train()
print("Time taken to build index: ", time.time() - start)
while True:
prompt = input("\n\nEnter another investment question (e.g. Have we invested in any generative AI companies in 2023?): ")
start = time.time()
print(model(prompt))
print("\nTime taken: ", time.time() - start)
main()