chat.py
# -*- coding: utf-8 -*-
"""ACMI collection chat

Uses Langchain RetrievalQA to chat with our collection data.
https://github.com/langchain-ai/langchain/blob/master/cookbook/openai_functions_retrieval_qa.ipynb

Colab prototype: https://colab.research.google.com/drive/1RLe2LliEE63KaQgxXDv3xccmxCYpmmPx
"""

import json
import os

import openai
import requests
from furl import furl
from langchain.chains import (ConversationalRetrievalChain, LLMChain,
                              RetrievalQA, create_qa_with_sources_chain)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
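
# All configuration is read from environment variables, with the defaults below.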
DATABASE_PATH = os.getenv('DATABASE_PATH', '')
COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'works')
PERSIST_DIRECTORY = os.getenv('PERSIST_DIRECTORY', 'works_db')
MODEL = os.getenv('MODEL', 'gpt-4o')
EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL', None)
LLM_BASE_URL = os.getenv('LLM_BASE_URL', None)
REBUILD = os.getenv('REBUILD', 'false').lower() == 'true'
HISTORY = os.getenv('HISTORY', 'true').lower() == 'true'
ALL = os.getenv('ALL', 'false').lower() == 'true'
# Set true if you'd like langchain tracing via LangSmith https://smith.langchain.com
os.environ['LANGCHAIN_TRACING_V2'] = 'false'
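
# Use hosted OpenAI models when MODEL looks like a GPT name; otherwise fall
# back to a local Ollama model, optionally pointed at LLM_BASE_URL.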
if MODEL.startswith('gpt'):
    llm = ChatOpenAI(temperature=0, model=MODEL)
    embeddings = OpenAIEmbeddings(model=EMBEDDINGS_MODEL or 'text-embedding-ada-002')
else:
    llm = Ollama(model=MODEL)
    embeddings = OllamaEmbeddings(model=EMBEDDINGS_MODEL or MODEL)

if LLM_BASE_URL:
    llm.base_url = LLM_BASE_URL
    embeddings.base_url = LLM_BASE_URL
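
# Open (or create) the persistent Chroma vector store for the collection.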
docsearch = Chroma(
    collection_name=COLLECTION_NAME,
    embedding_function=embeddings,
    persist_directory=f'{DATABASE_PATH}{PERSIST_DIRECTORY}',
)
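
# Build the index when the store is empty, or when a rebuild is forced.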
if not docsearch.get()['ids'] or REBUILD:  # the Chroma wrapper doesn't implement len()
    json_data = {
        'results': [],
    }
    params = {'page': ''}
    TMP_FILE_PATH = 'data.json'

    if os.path.isfile(TMP_FILE_PATH):
        print('Loading works from the ACMI Public API data.json file you have already created...')
        with open(TMP_FILE_PATH, 'r', encoding='utf-8') as tmp_file:
            json_data = json.load(tmp_file)
    else:
        if ALL:
            print('Loading all of the works from the ACMI Public API')
            while True:
                page_data = requests.get(
                    'https://api.acmi.net.au/works/',
                    params=params,
                    timeout=10,
                ).json()
                json_data['results'].extend(page_data['results'])
                if not page_data.get('next'):
                    break
                params['page'] = furl(page_data.get('next')).args.get('page')
                if len(json_data['results']) % 1000 == 0:
                    print(f'Downloaded {len(json_data["results"])}...')
        else:
            print('Loading the first ten pages of works from the ACMI Public API')
            PAGES = 10
            json_data = {
                'results': [],
            }
            for index in range(1, (PAGES + 1)):
                page_data = requests.get(
                    'https://api.acmi.net.au/works/',
                    params=params,
                    timeout=10,
                )
                json_data['results'].extend(page_data.json()['results'])
                print(f'Downloaded {page_data.request.url}')
                params['page'] = furl(page_data.json().get('next')).args.get('page')
        print(f'Finished downloading {len(json_data["results"])} works.')
        with open(TMP_FILE_PATH, 'w', encoding='utf-8') as json_file:
            json.dump(json_data, json_file)
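
    # Each work in the saved data.json becomes one Document; the jq expression
    # splits the results array into individual items.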
    json_loader = JSONLoader(
        file_path=TMP_FILE_PATH,
        jq_schema='.results[]',
        text_content=False,
    )
    data = json_loader.load()

    # Add source metadata
    for i, item in enumerate(data):
        item.metadata['source'] = f'https://api.acmi.net.au/works/{json_data["results"][i]["id"]}'

    def chunks(input_list, number_per_chunk):
        """Yield successive chunks from the input_list."""
        for idx in range(0, len(input_list), number_per_chunk):
            yield input_list[idx:idx + number_per_chunk]

    # Add to the vector database in chunks to avoid OpenAI rate limits
    for i, sublist in enumerate(chunks(data, 10)):
        docsearch.add_documents(
            sublist,
        )
        print(f'Added {len(sublist)} items to the database... total {(i + 1) * len(sublist)}')
    print(f'Finished adding {len(data)} items to the database')
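
# Assemble the answer chain: a QA-with-sources chain wrapped in a
# StuffDocumentsChain, so each retrieved work is formatted as a
# 'Content: .../Source: ...' block for the model.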
qa_chain = create_qa_with_sources_chain(llm)
doc_prompt = PromptTemplate(
    template='Content: {page_content}\nSource: {source}',
    input_variables=['page_content', 'source'],
)
final_qa_chain = StuffDocumentsChain(
    llm_chain=qa_chain,
    document_variable_name='context',
    document_prompt=doc_prompt,
)
retrieval_qa = RetrievalQA(
    retriever=docsearch.as_retriever(),
    combine_documents_chain=final_qa_chain,
)
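
# With history enabled, swap in a conversational chain that first condenses
# each follow-up question into a standalone one before retrieval.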
if HISTORY:
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    TEMPLATE = """Given a chat history and the latest user question
which might reference context in the chat history,
formulate a standalone question which can be understood
without the chat history. Do NOT answer the question,
just reformulate it if needed and otherwise return it as is.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(TEMPLATE)
    condense_question_chain = LLMChain(
        llm=llm,
        prompt=CONDENSE_QUESTION_PROMPT,
    )
    retrieval_qa = ConversationalRetrievalChain(
        question_generator=condense_question_chain,
        retriever=docsearch.as_retriever(),
        memory=memory,
        combine_docs_chain=final_qa_chain,
    )
print('=========================')
print('ACMI collection chat v0.1')
print('=========================\n')
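
# Simple read-eval-print loop; exit with Ctrl+C.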
while True:
    try:
        query = input('Question: ')
        if HISTORY:
            response = retrieval_qa.invoke({'question': query}).get('answer')
        else:
            response = retrieval_qa.invoke(query).get('result')
        try:
            print(f'Answer: {json.loads(response)["answer"]}')
            print(f'Sources: {json.loads(response)["sources"]}\n')
        except (TypeError, json.JSONDecodeError):
            # Plain-text answers (e.g. from non-OpenAI models) aren't JSON
            print(f'Answer: {response}\n')
    except KeyboardInterrupt:
        print('\n\nNice chatting to you.\n')
        break
    except openai.BadRequestError:
        print('\n\nSorry, something went wrong with the OpenAI request.\n')
        break