From 8e78033893e2371d875da3dfced90ed32a44df0a Mon Sep 17 00:00:00 2001 From: liuhua <10215101452@stu.ecun.edu.cn> Date: Mon, 2 Dec 2024 14:51:15 +0800 Subject: [PATCH] Fix the bug that the agent could not find the context --- api/apps/sdk/session.py | 56 ++++++++++++++++++++--------------------- api/db/db_models.py | 9 ++++++- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 551f9f847a1..59606861d69 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -35,7 +35,7 @@ @manager.route('/chats//sessions', methods=['POST']) @token_required -def create(tenant_id, chat_id): +def create(tenant_id,chat_id): req = request.json req["dialog_id"] = chat_id dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) @@ -77,9 +77,10 @@ def create_agent_session(tenant_id, agent_id): conv = { "id": get_uuid(), "dialog_id": cvs.id, - "user_id": req.get("usr_id", "") if isinstance(req, dict) else "", + "user_id": req.get("usr_id","") if isinstance(req, dict) else "", "message": [{"role": "assistant", "content": canvas.get_prologue()}], - "source": "agent" + "source": "agent", + "dsl":json.loads(cvs.dsl) } API4ConversationService.save(**conv) conv["agent_id"] = conv.pop("dialog_id") @@ -88,11 +89,11 @@ def create_agent_session(tenant_id, agent_id): @manager.route('/chats//sessions/', methods=['PUT']) @token_required -def update(tenant_id, chat_id, session_id): +def update(tenant_id,chat_id,session_id): req = request.json req["dialog_id"] = chat_id conv_id = session_id - conv = ConversationService.query(id=conv_id, dialog_id=chat_id) + conv = ConversationService.query(id=conv_id,dialog_id=chat_id) if not conv: return get_error_data_result(message="Session does not exist") if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): @@ -123,12 +124,12 @@ def completion(tenant_id, chat_id): return get_error_data_result(message="`name` can not be empty.") ConversationService.save(**conv) e, conv = ConversationService.get_by_id(conv["id"]) - session_id = conv.id + session_id=conv.id else: session_id = req.get("session_id") if not req.get("question"): return get_error_data_result(message="Please input your question.") - conv = ConversationService.query(id=session_id, dialog_id=chat_id) + conv = ConversationService.query(id=session_id,dialog_id=chat_id) if not conv: return get_error_data_result(message="Session does not exist") conv = conv[0] @@ -182,18 +183,18 @@ def fillin_conv(ans): chunk_list.append(new_chunk) reference["chunks"] = chunk_list ans["id"] = message_id - ans["session_id"] = session_id + ans["session_id"]=session_id def stream(): nonlocal dia, msg, req, conv try: for ans in chat(dia, msg, **req): fillin_conv(ans) - yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as e: yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + "data": {"answer": "**ERROR**: " + str(e),"reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n" @@ -237,7 +238,8 @@ def agent_completion(tenant_id, agent_id): "dialog_id": cvs.id, "user_id": req.get("user_id", ""), "message": [{"role": "assistant", "content": canvas.get_prologue()}], - "source": "agent" + "source": "agent", + "dsl": json.loads(cvs.dsl) } API4ConversationService.save(**conv) conv = API4Conversation(**conv) @@ -246,6 +248,7 @@ def agent_completion(tenant_id, agent_id): e, conv = API4ConversationService.get_by_id(req["session_id"]) if not e: return get_error_data_result(message="Session not found!") + canvas = Canvas(json.dumps(conv.dsl), tenant_id) messages = conv.message question = req.get("question") @@ -267,11 +270,11 @@ def agent_completion(tenant_id, agent_id): if not msg[-1].get("id"): msg[-1]["id"] = get_uuid() message_id = msg[-1]["id"] - if "quote" not in req: req["quote"] = False stream = req.get("stream", True) def fillin_conv(ans): reference = ans["reference"] + print(reference,flush=True) temp_reference = deepcopy(ans["reference"]) nonlocal conv, message_id if not conv.reference: @@ -322,7 +325,7 @@ def rename_field(ans): def sse(): nonlocal answer, cvs try: - for ans in canvas.run(stream=True): + for ans in canvas.run(stream=stream): if ans.get("running_status"): yield "data:" + json.dumps({"code": 0, "message": "", "data": {"answer": ans["content"], @@ -341,10 +344,10 @@ def sse(): canvas.history.append(("assistant", final_ans["content"])) if final_ans.get("reference"): canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) + conv.dsl = json.loads(str(canvas)) API4ConversationService.append_message(conv.id, conv.to_dict()) except Exception as e: - cvs.dsl = json.loads(str(canvas)) + conv.dsl = json.loads(str(canvas)) API4ConversationService.append_message(conv.id, conv.to_dict()) yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, @@ -364,7 +367,7 @@ def sse(): canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) if final_ans.get("reference"): canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) + conv.dsl = json.loads(str(canvas)) result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} fillin_conv(result) @@ -372,10 +375,9 @@ def sse(): rename_field(result) return get_result(data=result) - @manager.route('/chats//sessions', methods=['GET']) @token_required -def list_session(chat_id, tenant_id): +def list_session(chat_id,tenant_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): return get_error_data_result(message=f"You don't own the assistant {chat_id}.") id = request.args.get("id") @@ -387,7 +389,7 @@ def list_session(chat_id, tenant_id): desc = False else: desc = True - convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name) + convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name) if not convs: return get_result(data=[]) for conv in convs: @@ -429,7 +431,7 @@ def list_session(chat_id, tenant_id): @manager.route('/chats//sessions', methods=["DELETE"]) @token_required -def delete(tenant_id, chat_id): +def delete(tenant_id,chat_id): if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): return get_error_data_result(message="You don't own the chat") req = request.json @@ -437,22 +439,21 @@ def delete(tenant_id, chat_id): if not req: ids = None else: - ids = req.get("ids") + ids=req.get("ids") if not ids: conv_list = [] for conv in convs: conv_list.append(conv.id) else: - conv_list = ids + conv_list=ids for id in conv_list: - conv = ConversationService.query(id=id, dialog_id=chat_id) + conv = ConversationService.query(id=id,dialog_id=chat_id) if not conv: return get_error_data_result(message="The chat doesn't own the session") ConversationService.delete_by_id(id) return get_result() - @manager.route('/sessions/ask', methods=['POST']) @token_required def ask_about(tenant_id): @@ -461,18 +462,17 @@ def ask_about(tenant_id): return get_error_data_result("`question` is required.") if not req.get("dataset_ids"): return get_error_data_result("`dataset_ids` is required.") - if not isinstance(req.get("dataset_ids"), list): + if not isinstance(req.get("dataset_ids"),list): return get_error_data_result("`dataset_ids` should be a list.") - req["kb_ids"] = req.pop("dataset_ids") + req["kb_ids"]=req.pop("dataset_ids") for kb_id in req["kb_ids"]: - if not KnowledgebaseService.accessible(kb_id, tenant_id): + if not KnowledgebaseService.accessible(kb_id,tenant_id): return get_error_data_result(f"You don't own the dataset {kb_id}.") kbs = KnowledgebaseService.query(id=kb_id) kb = kbs[0] if kb.chunk_num == 0: return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") uid = tenant_id - def stream(): nonlocal req, uid try: diff --git a/api/db/db_models.py b/api/db/db_models.py index f57309f0af3..b90f06f67cb 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -947,7 +947,7 @@ class API4Conversation(DataBaseModel): reference = JSONField(null=True, default=[]) tokens = IntegerField(default=0) source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True) - + dsl = JSONField(null=True, default={}) duration = FloatField(default=0, index=True) round = IntegerField(default=0, index=True) thumb_up = IntegerField(default=0, index=True) @@ -1070,3 +1070,10 @@ def migrate_db(): ) except Exception: pass + try: + migrate( + migrator.add_column("api_4_conversation","dsl",JSONField(null=True, default={})) + ) + except Exception: + pass +