forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathretrieval_conversation_universal.py
138 lines (127 loc) · 5.89 KB
/
retrieval_conversation_universal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Multilingual retrieval based conversation system
"""
from typing import List
from colossalqa.data_loader.document_loader import DocumentLoader
from colossalqa.mylogging import get_logger
from colossalqa.retrieval_conversation_en import EnglishRetrievalConversation
from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
from colossalqa.retriever import CustomRetriever
from colossalqa.text_splitter import ChineseTextSplitter
from colossalqa.utils import detect_lang_naive
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
logger = get_logger()
class UniversalRetrievalConversation:
"""
Wrapper class for bilingual retrieval conversation system
"""
def __init__(
self,
embedding_model_path: str = "moka-ai/m3e-base",
embedding_model_device: str = "cpu",
zh_model_path: str = None,
zh_model_name: str = None,
en_model_path: str = None,
en_model_name: str = None,
sql_file_path: str = None,
files_zh: List[List[str]] = None,
files_en: List[List[str]] = None,
text_splitter_chunk_size=100,
text_splitter_chunk_overlap=10,
) -> None:
"""
Wrapper for multilingual retrieval qa class (Chinese + English)
Args:
embedding_model_path: local or huggingface embedding model
embedding_model_device:
files_zh: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for Chinese retrieval QA
files_en: [[file_path, name_of_file, separator],...] defines the files used as supporting documents for English retrieval QA
"""
self.embedding = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs={"device": embedding_model_device},
encode_kwargs={"normalize_embeddings": False},
)
print("Select files for constructing Chinese retriever")
docs_zh = self.load_supporting_docs(
files=files_zh,
text_splitter=ChineseTextSplitter(
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
),
)
# Create retriever
self.information_retriever_zh = CustomRetriever(
k=3, sql_file_path=sql_file_path.replace(".db", "_zh.db"), verbose=True
)
self.information_retriever_zh.add_documents(
docs=docs_zh, cleanup="incremental", mode="by_source", embedding=self.embedding
)
print("Select files for constructing English retriever")
docs_en = self.load_supporting_docs(
files=files_en,
text_splitter=RecursiveCharacterTextSplitter(
chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
),
)
# Create retriever
self.information_retriever_en = CustomRetriever(
k=3, sql_file_path=sql_file_path.replace(".db", "_en.db"), verbose=True
)
self.information_retriever_en.add_documents(
docs=docs_en, cleanup="incremental", mode="by_source", embedding=self.embedding
)
self.chinese_retrieval_conversation = ChineseRetrievalConversation.from_retriever(
self.information_retriever_zh, model_path=zh_model_path, model_name=zh_model_name
)
self.english_retrieval_conversation = EnglishRetrievalConversation.from_retriever(
self.information_retriever_en, model_path=en_model_path, model_name=en_model_name
)
self.memory = None
def load_supporting_docs(self, files: List[List[str]] = None, text_splitter: TextSplitter = None):
"""
Load supporting documents, currently, all documents will be stored in one vector store
"""
documents = []
if files:
for file in files:
retriever_data = DocumentLoader([[file["data_path"], file["name"]]]).all_data
splits = text_splitter.split_documents(retriever_data)
documents.extend(splits)
else:
while True:
file = input("Select a file to load or press Enter to exit:")
if file == "":
break
data_name = input("Enter a short description of the data:")
separator = input(
"Enter a separator to force separating text into chunks, if no separator is given, the default separator is '\\n\\n', press ENTER directly to skip:"
)
separator = separator if separator != "" else "\n\n"
retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
# Split
splits = text_splitter.split_documents(retriever_data)
documents.extend(splits)
return documents
def start_test_session(self):
"""
Simple multilingual session for testing purpose, with naive language selection mechanism
"""
while True:
user_input = input("User: ")
lang = detect_lang_naive(user_input)
if "END" == user_input:
print("Agent: Happy to chat with you :)")
break
agent_response = self.run(user_input, which_language=lang)
print(f"Agent: {agent_response}")
def run(self, user_input: str, which_language=str):
"""
Generate the response given the user input and a str indicates the language requirement of the output string
"""
assert which_language in ["zh", "en"]
if which_language == "zh":
agent_response, self.memory = self.chinese_retrieval_conversation.run(user_input, self.memory)
else:
agent_response, self.memory = self.english_retrieval_conversation.run(user_input, self.memory)
return agent_response.split("\n")[0]