Skip to content

Commit

Permalink
feat: 支持工作流ai对话节点添加节点上下文 (#1791)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaohuzhang1 authored Dec 9, 2024
1 parent 5c64d63 commit f65546a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class ChatNodeSerializer(serializers.Serializer):

model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))

dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("上下文类型"))


class IChatNode(INode):
type = 'ai-chat-node'
Expand All @@ -39,5 +41,6 @@ def _run(self):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
chat_record_id,
model_params_setting=None,
dialogue_type=None,
**kwargs) -> NodeResult:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import BaseMessage, AIMessage

from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
Expand Down Expand Up @@ -72,6 +72,22 @@ def get_default_model_params_setting(model_id):
return model_params_setting


def get_node_message(chat_record, runtime_node_id):
node_details = chat_record.get_node_details_runtime_node_id(runtime_node_id)
if node_details is None:
return []
return [HumanMessage(node_details.get('question')), AIMessage(node_details.get('answer'))]


def get_workflow_message(chat_record):
return [chat_record.get_human_message(), chat_record.get_ai_message()]


def get_message(chat_record, dialogue_type, runtime_node_id):
return get_node_message(chat_record, runtime_node_id) if dialogue_type == 'NODE' else get_workflow_message(
chat_record)


class BaseChatNode(IChatNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
Expand All @@ -80,12 +96,17 @@ def save_context(self, details, workflow_manage):

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
dialogue_type=None,
**kwargs) -> NodeResult:
if dialogue_type is None:
dialogue_type = 'WORKFLOW'

if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
self.runtime_node_id)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
Expand All @@ -103,10 +124,10 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
_write_context=write_context)

@staticmethod
def get_history_message(history_chat_record, dialogue_number):
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
get_message(history_chat_record[index], dialogue_type, runtime_node_id)
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message
Expand Down
3 changes: 3 additions & 0 deletions apps/application/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,8 @@ def get_human_message(self):
def get_ai_message(self):
return AIMessage(content=self.answer_text)

def get_node_details_runtime_node_id(self, runtime_node_id):
return self.details.get(runtime_node_id, None)

class Meta:
db_table = "application_chat_record"
9 changes: 7 additions & 2 deletions ui/src/workflow/common/NodeContainer.vue
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@
v-if="showAnchor"
@mousemove.stop
@mousedown.stop
@keydown.stop
@click.stop
@wheel.stop
@wheel="handleWheel"
:show="showAnchor"
:id="id"
style="left: 100%; top: 50%; transform: translate(0, -50%)"
Expand Down Expand Up @@ -142,6 +141,12 @@ const showNode = computed({
return true
}
})
const handleWheel = (event: any) => {
const isCombinationKeyPressed = event.ctrlKey || event.metaKey
if (!isCombinationKeyPressed) {
event.stopPropagation()
}
}
const node_status = computed(() => {
if (props.nodeModel.properties.status) {
return props.nodeModel.properties.status
Expand Down
12 changes: 11 additions & 1 deletion ui/src/workflow/nodes/ai-chat-node/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@
/>
</el-form-item>
<el-form-item label="历史聊天记录">
<template #label>
<div class="flex-between">
<div>历史聊天记录</div>
<el-select v-model="chat_data.dialogue_type" type="small" style="width: 100px">
<el-option label="节点" value="NODE" />
<el-option label="工作流" value="WORKFLOW" />
</el-select>
</div>
</template>
<el-input-number
v-model="chat_data.dialogue_number"
:min="0"
Expand Down Expand Up @@ -246,7 +255,8 @@ const form = {
dialogue_number: 1,
is_result: false,
temperature: null,
max_tokens: null
max_tokens: null,
dialogue_type: 'WORKFLOW'
}
const chat_data = computed({
Expand Down

0 comments on commit f65546a

Please sign in to comment.