import os

# This placeholder key is not actually used, but it is required by
# llama-index and must be set first, before any llama_index imports.
os.environ["OPENAI_API_KEY"] = "sk-abc123"

import streamlit as st

import utils.logs as logs

from llama_index.core import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    Settings,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

###################################
#
# Setup Embedding Model
#
###################################
@st.cache_resource(show_spinner=False)
def setup_embedding_model(
    model: str,
):
    """
    Sets up an embedding model using the Hugging Face library.

    Args:
        model (str): The name of the embedding model to use.

    Returns:
        None. The configured HuggingFaceEmbedding instance is assigned to
        `Settings.embed_model` for use by subsequent llama-index calls.

    Notes:
        The embedding computations run on the GPU ('cuda') when CUDA is
        available; otherwise they fall back to the CPU.
    """
    try:
        from torch import cuda

        device = "cpu" if not cuda.is_available() else "cuda"
    except ImportError:
        device = "cpu"
    finally:
        logs.log.info(f"Using {device} to generate embeddings")

    try:
        Settings.embed_model = HuggingFaceEmbedding(
            model_name=model,
            device=device,
        )

        logs.log.info("Embedding model created successfully")

        return
    except Exception as err:
        logs.log.error(f"Failed to setup the embedding model: {err}")
###################################
#
# Load Documents
#
###################################
def load_documents(data_dir: str):
    """
    Loads documents from a directory of files.

    Args:
        data_dir (str): The path to the directory containing the documents to be loaded.

    Returns:
        A list of llama-index Document objects, one or more per file read.

    Raises:
        Exception: If there is an error loading the documents.

    Notes:
        The function reads every file in `data_dir` (recursively), then
        deletes the local files once loading completes, leaving only
        `.gitkeep` in place.
    """
    try:
        files = SimpleDirectoryReader(input_dir=data_dir, recursive=True)
        documents = files.load_data()
        logs.log.info(f"Loaded {len(documents):,} documents from files")
        return documents
    except Exception as err:
        logs.log.error(f"Error creating data index: {err}")
        raise Exception(f"Error creating data index: {err}")
    finally:
        for file in os.scandir(data_dir):
            if file.is_file() and file.name != ".gitkeep":
                os.remove(file.path)
        logs.log.info("Document loading complete; removing local file(s)")
###################################
#
# Create Document Index
#
###################################
@st.cache_data(show_spinner=False)
def create_index(_documents):
    """
    Creates an index from the provided documents.

    Args:
        _documents (list): The llama-index Document objects to be indexed.
            The leading underscore tells Streamlit not to hash this
            argument when caching.

    Returns:
        An instance of `VectorStoreIndex`, containing the indexed data.

    Raises:
        Exception: If there is an error creating the index.
    """
    try:
        index = VectorStoreIndex.from_documents(
            documents=_documents, show_progress=True
        )

        logs.log.info("Index created from loaded documents successfully")

        return index
    except Exception as err:
        logs.log.error(f"Index creation failed: {err}")
        raise Exception(f"Index creation failed: {err}")
###################################
#
# Create Query Engine
#
###################################
# @st.cache_resource(show_spinner=False)
def create_query_engine(_documents):
    """
    Creates a query engine from the provided documents.

    Args:
        _documents (list): The llama-index Document objects to be indexed.

    Returns:
        An instance of `QueryEngine`, containing the indexed data and
        allowing for querying of the data using a variety of parameters.

    Raises:
        Exception: If there is an error creating the query engine.

    Notes:
        This function uses `create_index` to build an index from the
        provided documents, then derives a streaming query engine from it.
        The number of top-ranked items to return (`similarity_top_k`) and
        the response mode (`response_mode`) are read from Streamlit session
        state (`st.session_state["top_k"]` and
        `st.session_state["chat_mode"]`).
    """
    try:
        index = create_index(_documents)

        query_engine = index.as_query_engine(
            similarity_top_k=st.session_state["top_k"],
            response_mode=st.session_state["chat_mode"],
            streaming=True,
        )

        st.session_state["query_engine"] = query_engine

        logs.log.info("Query Engine created successfully")

        return query_engine
    except Exception as e:
        logs.log.error(f"Error when creating Query Engine: {e}")
        raise Exception(f"Error when creating Query Engine: {e}")