Refactor ask decorator #4116

Merged
merged 5 commits into from Dec 19, 2024
205 changes: 120 additions & 85 deletions api/db/services/dialog_service.py
@@ -23,7 +23,7 @@
from timeit import default_timer as timer
import datetime
from datetime import timedelta
-from api.db import LLMType, ParserType,StatusEnum
+from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -41,14 +41,14 @@ class DialogService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, tenant_id,
-page_number, items_per_page, orderby, desc, id , name):
+page_number, items_per_page, orderby, desc, id, name):
chats = cls.model.select()
if id:
chats = chats.where(cls.model.id == id)
if name:
chats = chats.where(cls.model.name == name)
chats = chats.where(
-(cls.model.tenant_id == tenant_id)
+(cls.model.tenant_id == tenant_id)
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
@@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens):

def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
-st = timer()
-llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
-llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
+
+chat_start_ts = timer()
+
+# Get llm model name and model provider name
+llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
+
+# Get llm model instance by model and provider name
+llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
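The rename from `fid` to `model_provider` clarifies the lookup chain above. For readers outside the codebase, a rough sketch of what the split is expected to return follows; the `name@provider` convention and the `None` fallback are assumptions, not confirmed by this diff:

```python
# Hypothetical illustration of TenantLLMService.split_model_name_and_factory;
# the real delimiter handling lives in that service, not in this diff.
def split_model_name_and_factory(llm_id: str):
    # e.g. "deepseek-chat@DeepSeek" -> ("deepseek-chat", "DeepSeek")
    if "@" in llm_id:
        name, provider = llm_id.rsplit("@", 1)
        return name, provider
    return llm_id, None  # no suffix: provider unknown
```

Whatever the real delimiter is, returning `None` for the provider keeps the two `query(...)` branches above symmetric.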

if not llm:
-llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
-TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
+# Model name is provided by tenant, but not system built-in
+llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
+TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 8192
else:
max_tokens = llm[0].max_tokens

+check_llm_ts = timer()
+
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
-embd_nms = list(set([kb.embd_id for kb in kbs]))
-if len(embd_nms) != 1:
+embedding_list = list(set([kb.embd_id for kb in kbs]))
+if len(embedding_list) != 1:
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

-is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
-retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
+embedding_model_name = embedding_list[0]
+
+is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
+retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs):
if "doc_ids" in m:
attachments.extend(m["doc_ids"])

-embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
+create_retriever_ts = timer()
+
+embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
if not embd_mdl:
-raise LookupError("Embedding model(%s) not found" % embd_nms[0])
+raise LookupError("Embedding model(%s) not found" % embedding_model_name)
+
+bind_embedding_ts = timer()

if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

+bind_llm_ts = timer()

prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
tts_mdl = None
@@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs):
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
else:
questions = questions[-1:]
-refineQ_tm = timer()
-keyword_tm = timer()
+
+refine_question_ts = timer()

rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

for _ in range(len(questions) // 2):
questions.append(questions[-1])
+bind_reranker_ts = timer()
+generate_keyword_ts = bind_reranker_ts

if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
else:
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-keyword_tm = timer()
+generate_keyword_ts = timer()

tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
-dialog.similarity_threshold,
-dialog.vector_similarity_weight,
-doc_ids=attachments,
-top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
+kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
+dialog.similarity_threshold,
+dialog.vector_similarity_weight,
+doc_ids=attachments,
+top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
+
+retrieval_ts = timer()
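The retrieval call above packs many positional arguments; the annotated restatement below maps each one to its apparent role. The interpretations come from how the values are built in this file; the retriever's own parameter names are not shown in this diff, so treat the comments as inferences:

```python
# Same call as above, annotated; comments are inferred from usage here.
kbinfos = retriever.retrieval(
    " ".join(questions),             # query: recent user turns, joined
    embd_mdl,                        # embedding model bound earlier
    tenant_ids,                      # tenants owning the knowledge bases
    dialog.kb_ids,                   # knowledge bases to search
    1,                               # page number
    dialog.top_n,                    # chunks kept for the prompt
    dialog.similarity_threshold,     # minimum hybrid similarity
    dialog.vector_similarity_weight, # vector share of the hybrid score
    doc_ids=attachments,             # optional filter to attached docs
    top=dialog.top_k,                # candidate pool before reranking
    aggs=False,                      # skip per-document aggregations
    rerank_mdl=rerank_mdl,           # optional reranker bound earlier
)
```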

knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
-retrieval_tm = timer()

if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
@@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs):
max_tokens - used_token_count)

def decorate_answer(answer):
-nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
+nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
+
+finish_chat_ts = timer()

refs = []
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
-answer, idx = retr.insert_citations(answer,
-[ck["content_ltks"]
-for ck in kbinfos["chunks"]],
-[ck["vector"]
-for ck in kbinfos["chunks"]],
-embd_mdl,
-tkweight=1 - dialog.vector_similarity_weight,
-vtweight=dialog.vector_similarity_weight)
+answer, idx = retriever.insert_citations(answer,
+[ck["content_ltks"]
+for ck in kbinfos["chunks"]],
+[ck["vector"]
+for ck in kbinfos["chunks"]],
+embd_mdl,
+tkweight=1 - dialog.vector_similarity_weight,
+vtweight=dialog.vector_similarity_weight)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -274,10 +298,20 @@ def decorate_answer(answer):

if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
-done_tm = timer()
-prompt += "\n\n### Elapsed\n - Refine Question: %.1f ms\n - Keywords: %.1f ms\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % (
-(refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000,
-(done_tm - retrieval_tm) * 1000)
+finish_chat_ts = timer()
+
+total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
+check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
+create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
+bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
+bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
+refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
+bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
+generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
+retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
+generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
+
+prompt = f"{prompt} ### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": answer, "reference": refs, "prompt": prompt}

if stream:
@@ -304,15 +338,15 @@ def decorate_answer(answer):


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
user_promt = """
表名:{};
数据库表字段说明如下:
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

-问题如下:
+The questions are as follows:
{}
-请写出SQL, 且只要SQL,不要有其他说明及文字。
+Please write the SQL, and only the SQL, without any other explanations or text.
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
@@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
tried_times = 0

def get_table():
-nonlocal sys_prompt, user_promt, question, tried_times
-sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
+nonlocal sys_prompt, user_prompt, question, tried_times
+sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
"temperature": 0.06})
-logging.debug(f"{question} ==> {user_promt} get SQL: {sql}")
+logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
@@ -352,21 +386,23 @@ def get_table():
if tbl is None:
return None
if tbl.get("error") and tried_times <= 2:
user_promt = """
表名:{};
数据库表字段说明如下:
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

-问题如下:
+The questions are as follows:
{}
+Please write the SQL, and only the SQL, without any other explanations or text.


-你上一次给出的错误SQL如下:
+The incorrect SQL you provided last time is as follows:
{}

-后台报错如下:
+The error reported by the database is as follows:
{}

-请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
+Please correct the error and write the SQL again, only the SQL, without any other explanations or text.
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
Expand All @@ -381,21 +417,21 @@ def get_table():

docid_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "doc_id"])
-docnm_idx = set([ii for ii, c in enumerate(
+doc_name_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "docnm_kwd"])
-clmn_idx = [ii for ii in range(
-len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
+column_idx = [ii for ii in range(
+len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

-# compose markdown table
-clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
-tbl["columns"][i]["name"])) for i in
-clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
+# compose Markdown table
+columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
+tbl["columns"][i]["name"])) for i in
+column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")

line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
("|------|" if docid_idx and docid_idx else "")

rows = ["|" +
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
"|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
"|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
if quota:
@@ -404,24 +440,24 @@ def get_table():
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
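To picture what `columns`, `line`, and `rows` compose into when `quota=True`, here is a toy rendering with hypothetical field names and values; the `##N$$` suffixes are the citation anchors appended a few lines above:

```python
# Hypothetical two-row result table; values are illustrative only.
columns = "|Title|Amount|Source|"
line = "|------|------|------|"
rows = "\n".join(["|Invoice A|42| ##0$$ |",
                  "|Invoice B|17| ##1$$ |"])
print("\n".join([columns, line, rows]))
# |Title|Amount|Source|
# |------|------|------|
# |Invoice A|42| ##0$$ |
# |Invoice B|17| ##1$$ |
```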

-if not docid_idx or not docnm_idx:
+if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql)
return {
"answer": "\n".join([clmns, line, rows]),
"answer": "\n".join([columns, line, rows]),
"reference": {"chunks": [], "doc_aggs": []},
"prompt": sys_prompt
}

docid_idx = list(docid_idx)[0]
-docnm_idx = list(docnm_idx)[0]
+doc_name_idx = list(doc_name_idx)[0]
doc_aggs = {}
for r in tbl["rows"]:
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
return {
"answer": "\n".join([clmns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
"answer": "\n".join([columns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
doc_aggs.items()]},
"prompt": sys_prompt
@@ -492,7 +528,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
-if kwd.find("**ERROR**") >=0:
+if kwd.find("**ERROR**") >= 0:
return ""
return kwd

@@ -605,16 +641,16 @@ def tts(tts_mdl, text):

def ask(question, kb_ids, tenant_id):
kbs = KnowledgebaseService.get_by_ids(kb_ids)
-embd_nms = list(set([kb.embd_id for kb in kbs]))
+embedding_list = list(set([kb.embd_id for kb in kbs]))

-is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
-retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
+is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
+retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

-embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
+embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
+kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
Role: You're a smart assistant. Your name is Miss R.
@@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id):

def decorate_answer(answer):
nonlocal knowledges, kbinfos, prompt
-answer, idx = retr.insert_citations(answer,
-[ck["content_ltks"]
-for ck in kbinfos["chunks"]],
-[ck["vector"]
-for ck in kbinfos["chunks"]],
-embd_mdl,
-tkweight=0.7,
-vtweight=0.3)
+answer, idx = retriever.insert_citations(answer,
+[ck["content_ltks"]
+for ck in kbinfos["chunks"]],
+[ck["vector"]
+for ck in kbinfos["chunks"]],
+embd_mdl,
+tkweight=0.7,
+vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -664,4 +700,3 @@ def decorate_answer(answer):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
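Both `insert_citations` call sites blend token-level and vector similarity with complementary weights: `1 - dialog.vector_similarity_weight` and `dialog.vector_similarity_weight` in `chat`, and a fixed 0.7/0.3 split in `ask`. A sketch of the weighted blend those parameters imply, assuming normalized component scores (the real scoring lives inside the retriever and is not part of this diff):

```python
# Sketch of the hybrid similarity implied by tkweight/vtweight; assumes
# both component scores are already normalized to [0, 1].
def hybrid_score(token_sim: float, vector_sim: float,
                 tkweight: float = 0.7, vtweight: float = 0.3) -> float:
    return tkweight * token_sim + vtweight * vector_sim

# A chunk with strong lexical overlap but weak embedding similarity
# still scores well under ask()'s token-heavy weights:
print(hybrid_score(0.9, 0.2))  # 0.69
```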
