Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug that the agent could not find the context #3795

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions api/apps/sdk/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
@token_required
def create(tenant_id, chat_id):
def create(tenant_id,chat_id):
req = request.json
req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
Expand Down Expand Up @@ -77,9 +77,10 @@ def create_agent_session(tenant_id, agent_id):
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": req.get("usr_id", "") if isinstance(req, dict) else "",
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
"source": "agent",
"dsl":json.loads(cvs.dsl)
}
API4ConversationService.save(**conv)
conv["agent_id"] = conv.pop("dialog_id")
Expand All @@ -88,11 +89,11 @@ def create_agent_session(tenant_id, agent_id):

@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
@token_required
def update(tenant_id, chat_id, session_id):
def update(tenant_id,chat_id,session_id):
req = request.json
req["dialog_id"] = chat_id
conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
if not conv:
return get_error_data_result(message="Session does not exist")
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
Expand Down Expand Up @@ -123,12 +124,12 @@ def completion(tenant_id, chat_id):
return get_error_data_result(message="`name` can not be empty.")
ConversationService.save(**conv)
e, conv = ConversationService.get_by_id(conv["id"])
session_id = conv.id
session_id=conv.id
else:
session_id = req.get("session_id")
if not req.get("question"):
return get_error_data_result(message="Please input your question.")
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
conv = ConversationService.query(id=session_id,dialog_id=chat_id)
if not conv:
return get_error_data_result(message="Session does not exist")
conv = conv[0]
Expand Down Expand Up @@ -182,18 +183,18 @@ def fillin_conv(ans):
chunk_list.append(new_chunk)
reference["chunks"] = chunk_list
ans["id"] = message_id
ans["session_id"] = session_id
ans["session_id"]=session_id

def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, **req):
fillin_conv(ans)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
"data": {"answer": "**ERROR**: " + str(e),"reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"

Expand Down Expand Up @@ -237,7 +238,8 @@ def agent_completion(tenant_id, agent_id):
"dialog_id": cvs.id,
"user_id": req.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
"source": "agent",
"dsl": json.loads(cvs.dsl)
}
API4ConversationService.save(**conv)
conv = API4Conversation(**conv)
Expand All @@ -246,6 +248,7 @@ def agent_completion(tenant_id, agent_id):
e, conv = API4ConversationService.get_by_id(req["session_id"])
if not e:
return get_error_data_result(message="Session not found!")
canvas = Canvas(json.dumps(conv.dsl), tenant_id)

messages = conv.message
question = req.get("question")
Expand All @@ -267,11 +270,11 @@ def agent_completion(tenant_id, agent_id):
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]

if "quote" not in req: req["quote"] = False
stream = req.get("stream", True)

def fillin_conv(ans):
reference = ans["reference"]
print(reference,flush=True)
temp_reference = deepcopy(ans["reference"])
nonlocal conv, message_id
if not conv.reference:
Expand Down Expand Up @@ -322,7 +325,7 @@ def rename_field(ans):
def sse():
nonlocal answer, cvs
try:
for ans in canvas.run(stream=True):
for ans in canvas.run(stream=stream):
if ans.get("running_status"):
yield "data:" + json.dumps({"code": 0, "message": "",
"data": {"answer": ans["content"],
Expand All @@ -341,10 +344,10 @@ def sse():
canvas.history.append(("assistant", final_ans["content"]))
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
cvs.dsl = json.loads(str(canvas))
conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
Expand All @@ -364,18 +367,17 @@ def sse():
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
conv.dsl = json.loads(str(canvas))

result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
fillin_conv(result)
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_result(data=result)


@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
@token_required
def list_session(chat_id, tenant_id):
def list_session(chat_id,tenant_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id")
Expand All @@ -387,7 +389,7 @@ def list_session(chat_id, tenant_id):
desc = False
else:
desc = True
convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name)
convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
if not convs:
return get_result(data=[])
for conv in convs:
Expand Down Expand Up @@ -429,30 +431,29 @@ def list_session(chat_id, tenant_id):

@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
@token_required
def delete(tenant_id, chat_id):
def delete(tenant_id,chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You don't own the chat")
req = request.json
convs = ConversationService.query(dialog_id=chat_id)
if not req:
ids = None
else:
ids = req.get("ids")
ids=req.get("ids")

if not ids:
conv_list = []
for conv in convs:
conv_list.append(conv.id)
else:
conv_list = ids
conv_list=ids
for id in conv_list:
conv = ConversationService.query(id=id, dialog_id=chat_id)
conv = ConversationService.query(id=id,dialog_id=chat_id)
if not conv:
return get_error_data_result(message="The chat doesn't own the session")
ConversationService.delete_by_id(id)
return get_result()


@manager.route('/sessions/ask', methods=['POST'])
@token_required
def ask_about(tenant_id):
Expand All @@ -461,18 +462,17 @@ def ask_about(tenant_id):
return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.")
if not isinstance(req.get("dataset_ids"), list):
if not isinstance(req.get("dataset_ids"),list):
return get_error_data_result("`dataset_ids` should be a list.")
req["kb_ids"] = req.pop("dataset_ids")
req["kb_ids"]=req.pop("dataset_ids")
for kb_id in req["kb_ids"]:
if not KnowledgebaseService.accessible(kb_id, tenant_id):
if not KnowledgebaseService.accessible(kb_id,tenant_id):
return get_error_data_result(f"You don't own the dataset {kb_id}.")
kbs = KnowledgebaseService.query(id=kb_id)
kb = kbs[0]
if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
uid = tenant_id

def stream():
nonlocal req, uid
try:
Expand Down
9 changes: 8 additions & 1 deletion api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@ class API4Conversation(DataBaseModel):
reference = JSONField(null=True, default=[])
tokens = IntegerField(default=0)
source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)

dsl = JSONField(null=True, default={})
duration = FloatField(default=0, index=True)
round = IntegerField(default=0, index=True)
thumb_up = IntegerField(default=0, index=True)
Expand Down Expand Up @@ -1070,3 +1070,10 @@ def migrate_db():
)
except Exception:
pass
try:
migrate(
migrator.add_column("api_4_conversation","dsl",JSONField(null=True, default={}))
)
except Exception:
pass