From f65546a619cb3fc2ca383fb84480fc4207d0bc5f Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:17:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81ai=E5=AF=B9=E8=AF=9D=E8=8A=82=E7=82=B9=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E8=8A=82=E7=82=B9=E4=B8=8A=E4=B8=8B=E6=96=87=20(#1791?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai_chat_step_node/i_chat_node.py | 3 ++ .../ai_chat_step_node/impl/base_chat_node.py | 29 ++++++++++++++++--- apps/application/models/application.py | 3 ++ ui/src/workflow/common/NodeContainer.vue | 9 ++++-- ui/src/workflow/nodes/ai-chat-node/index.vue | 12 +++++++- 5 files changed, 49 insertions(+), 7 deletions(-) diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index b7dfecf6a2b..badc6961a35 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -26,6 +26,8 @@ class ChatNodeSerializer(serializers.Serializer): model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置")) + dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("上下文类型")) + class IChatNode(INode): type = 'ai-chat-node' @@ -39,5 +41,6 @@ def _run(self): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, + dialogue_type=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index d4835f6190e..af68c3131ae 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -12,7 +12,7 @@ from django.db.models import QuerySet from langchain.schema import HumanMessage, SystemMessage -from langchain_core.messages import BaseMessage +from langchain_core.messages import BaseMessage, AIMessage from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode @@ -72,6 +72,22 @@ def get_default_model_params_setting(model_id): return model_params_setting +def get_node_message(chat_record, runtime_node_id): + node_details = chat_record.get_node_details_runtime_node_id(runtime_node_id) + if node_details is None: + return [] + return [HumanMessage(node_details.get('question')), AIMessage(node_details.get('answer'))] + + +def get_workflow_message(chat_record): + return [chat_record.get_human_message(), chat_record.get_ai_message()] + + +def get_message(chat_record, dialogue_type, runtime_node_id): + return get_node_message(chat_record, runtime_node_id) if dialogue_type == 'NODE' else get_workflow_message( + chat_record) + + class BaseChatNode(IChatNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') @@ -80,12 +96,17 @@ def save_context(self, details, workflow_manage): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, + dialogue_type=None, **kwargs) -> NodeResult: + if dialogue_type is None: + dialogue_type = 'WORKFLOW' + if model_params_setting is None: model_params_setting = get_default_model_params_setting(model_id) chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) - history_message = self.get_history_message(history_chat_record, dialogue_number) + history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type, + self.runtime_node_id) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) self.context['question'] = question.content @@ -103,10 +124,10 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record _write_context=write_context) @staticmethod - def get_history_message(history_chat_record, dialogue_number): + def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id): start_index = len(history_chat_record) - dialogue_number history_message = reduce(lambda x, y: [*x, *y], [ - [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + get_message(history_chat_record[index], dialogue_type, runtime_node_id) for index in range(start_index if start_index > 0 else 0, len(history_chat_record))], []) return history_message diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 6ed33c48cdb..a05097afa66 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -167,5 +167,8 @@ def get_human_message(self): def get_ai_message(self): return AIMessage(content=self.answer_text) + def get_node_details_runtime_node_id(self, runtime_node_id): + return self.details.get(runtime_node_id, None) + class Meta: db_table = "application_chat_record" diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index 7679ea3f7f8..06c1aad2b6d 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -93,9 +93,8 @@ v-if="showAnchor" @mousemove.stop @mousedown.stop - @keydown.stop @click.stop - @wheel.stop + @wheel="handleWheel" :show="showAnchor" :id="id" style="left: 100%; top: 50%; transform: translate(0, -50%)" @@ -142,6 +141,12 @@ const showNode = computed({ return true } }) +const handleWheel = (event: any) => { + const isCombinationKeyPressed = event.ctrlKey || event.metaKey + if (!isCombinationKeyPressed) { + event.stopPropagation() + } +} const node_status = computed(() => { if (props.nodeModel.properties.status) { return props.nodeModel.properties.status diff --git a/ui/src/workflow/nodes/ai-chat-node/index.vue b/ui/src/workflow/nodes/ai-chat-node/index.vue index a8619cdfed6..6071c48c4f3 100644 --- a/ui/src/workflow/nodes/ai-chat-node/index.vue +++ b/ui/src/workflow/nodes/ai-chat-node/index.vue @@ -148,6 +148,15 @@ /> +