feat: Support reasoning content #2135 (#2158)
liuruibin authored Feb 8, 2025
1 parent 0cc1d00 commit 061a41c
Showing 32 changed files with 816 additions and 166 deletions.
11 changes: 8 additions & 3 deletions apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -65,16 +65,21 @@ class InstanceSerializer(serializers.Serializer):
post_response_handler = InstanceField(model_type=PostResponseHandler,
error_messages=ErrMessage.base(_("Post-processor")))
# Completed question
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base(_("Completion Question")))
padding_problem_text = serializers.CharField(required=False,
error_messages=ErrMessage.base(_("Completion Question")))
# Whether to output as a stream
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
# No reference segments were found
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))
no_references_setting = NoReferencesSetting(required=True,
error_messages=ErrMessage.base(_("No reference segment settings")))

user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))

model_setting = serializers.DictField(required=True, allow_null=True,
error_messages=ErrMessage.dict(_("Model settings")))

model_params_setting = serializers.DictField(required=False, allow_null=True,
error_messages=ErrMessage.dict(_("Model parameter settings")))

@@ -101,5 +106,5 @@ def execute(self, message_list: List[BaseMessage],
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
no_references_setting=None, model_params_setting=None, **kwargs):
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
pass
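
The new model_setting dict added above carries the reasoning-content configuration through the pipeline. Judging from the lookups later in this commit (reasoning_content_enable, reasoning_content_start, reasoning_content_end), a typical payload might look like the following sketch; the key names come from the diff, the values are illustrative:

# Illustrative only -- key names are taken from this commit, values are examples.
model_setting = {
    'reasoning_content_enable': True,      # include reasoning text in responses
    'reasoning_content_start': '<think>',  # delimiter that opens a reasoning span
    'reasoning_content_end': '</think>',   # delimiter that closes a reasoning span
}
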
@@ -24,6 +24,7 @@
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from setting.models_provider.tools import get_model_instance_by_model_user_id
@@ -63,17 +64,37 @@ def event_content(response,
problem_text: str,
padding_problem_text: str = None,
client_id=None, client_type=None,
is_ai_chat: bool = None):
is_ai_chat: bool = None,
model_setting=None):
if model_setting is None:
model_setting = {}
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
reasoning = Reasoning(reasoning_content_start,
reasoning_content_end)
all_text = ''
reasoning_content = ''
try:
for chunk in response:
all_text += chunk.content
reasoning_chunk = reasoning.get_reasoning_content(chunk)
content_chunk = reasoning_chunk.get('content')
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
else:
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
all_text += content_chunk
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
reasoning_content += reasoning_content_chunk
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], chunk.content,
[], content_chunk,
False,
0, 0, {'node_is_end': False,
'view_type': 'many_view',
'node_type': 'ai-chat-node'})
'node_type': 'ai-chat-node',
'real_node_id': 'ai-chat-node',
'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''})
# Get token counts
if is_ai_chat:
try:
@@ -87,7 +108,8 @@ def event_content(response,
response_token = 0
write_context(step, manage, request_token, response_token, all_text)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id)
all_text, manage, step, padding_problem_text, client_id,
reasoning_content=reasoning_content if reasoning_content_enable else '')
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], '', True,
request_token, response_token,
@@ -122,17 +144,20 @@ def execute(self, message_list: List[BaseMessage],
client_id=None, client_type=None,
no_references_setting=None,
model_params_setting=None,
model_setting=None,
**kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
**model_params_setting) if model_id is not None else None
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting)
manage, padding_problem_text, client_id, client_type, no_references_setting,
model_setting)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting)
manage, padding_problem_text, client_id, client_type, no_references_setting,
model_setting)

def get_details(self, manage, **kwargs):
return {
@@ -187,14 +212,15 @@ def execute_stream(self, message_list: List[BaseMessage],
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None,
no_references_setting=None):
no_references_setting=None,
model_setting=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text)
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text, client_id, client_type, is_ai_chat),
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
content_type='text/event-stream;charset=utf-8')

r['Cache-Control'] = 'no-cache'
@@ -230,7 +256,13 @@ def execute_block(self, message_list: List[BaseMessage],
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None, no_references_setting=None):
client_id=None, client_type=None, no_references_setting=None,
model_setting=None):
if model_setting is None:
model_setting = {}
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
reasoning = Reasoning(reasoning_content_start,
reasoning_content_end)
chat_record_id = uuid.uuid1()
# Invoke the model
try:
@@ -243,14 +275,23 @@
request_token = 0
response_token = 0
write_context(self, manage, request_token, response_token, chat_result.content)
reasoning_result = reasoning.get_reasoning_content(chat_result)
content = reasoning_result.get('content')
if 'reasoning_content' in chat_result.response_metadata:
reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
else:
reasoning_content = reasoning_result.get('reasoning_content')
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
chat_result.content, manage, self, padding_problem_text, client_id)
chat_result.content, manage, self, padding_problem_text, client_id,
reasoning_content=reasoning_content if reasoning_content_enable else '')
add_access_num(client_id, client_type, manage.context.get('application_id'))
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
chat_result.content, True,
request_token, response_token)
content, True,
request_token, response_token,
{'reasoning_content': reasoning_content})
except Exception as e:
all_text = '异常' + str(e)
all_text = 'Exception:' + str(e)
write_context(self, manage, 0, 0, all_text)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, self, padding_problem_text, client_id)
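
The hunks above lean on the Reasoning helper imported from application.flow.tools, whose implementation is not part of this excerpt. A minimal, hypothetical sketch consistent with its call sites -- a stateful parser that splits (possibly chunked) model output into visible content and reasoning content between configurable tags -- might look like this; the class name and method signature match the diff, everything else is assumption:

class Reasoning:
    def __init__(self, start_tag: str, end_tag: str):
        self.start_tag = start_tag      # e.g. '<think>'
        self.end_tag = end_tag          # e.g. '</think>'
        self.buffer = ''                # unclassified text carried across chunks
        self.in_reasoning = False       # True while inside a reasoning span

    def _held_back(self, tag: str) -> int:
        # Length of the longest buffer suffix that could still grow into `tag`,
        # so a tag split across two stream chunks is not emitted as content.
        for k in range(min(len(tag) - 1, len(self.buffer)), 0, -1):
            if tag.startswith(self.buffer[-k:]):
                return k
        return 0

    def get_reasoning_content(self, chunk) -> dict:
        # `chunk` is anything with a `.content` string (e.g. a LangChain chunk).
        self.buffer += chunk.content
        content, reasoning = '', ''
        while self.buffer:
            tag = self.end_tag if self.in_reasoning else self.start_tag
            idx = self.buffer.find(tag)
            if idx == -1:
                emit = len(self.buffer) - self._held_back(tag)
                piece, self.buffer = self.buffer[:emit], self.buffer[emit:]
                if self.in_reasoning:
                    reasoning += piece
                else:
                    content += piece
                break
            piece, self.buffer = self.buffer[:idx], self.buffer[idx + len(tag):]
            if self.in_reasoning:
                reasoning += piece
            else:
                content += piece
            self.in_reasoning = not self.in_reasoning
        return {'content': content, 'reasoning_content': reasoning}

In the blocking path the helper is applied once to the whole message, as in this hypothetical run against the sketch above:

from types import SimpleNamespace

msg = SimpleNamespace(content='<think>weigh the options</think>The answer is 42.')
result = Reasoning('<think>', '</think>').get_reasoning_content(msg)
# result == {'content': 'The answer is 42.', 'reasoning_content': 'weigh the options'}
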
10 changes: 8 additions & 2 deletions apps/application/flow/common.py
@@ -9,16 +9,22 @@


class Answer:
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node):
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
reasoning_content):
self.view_type = view_type
self.content = content
self.reasoning_content = reasoning_content
self.runtime_node_id = runtime_node_id
self.chat_record_id = chat_record_id
self.child_node = child_node
self.real_node_id = real_node_id

def to_dict(self):
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
'chat_record_id': self.chat_record_id,
'child_node': self.child_node,
'reasoning_content': self.reasoning_content,
'real_node_id': self.real_node_id}


class NodeChunk:
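
A quick illustration of the widened Answer constructor and to_dict above (values are hypothetical):

answer = Answer(content='Final answer', view_type='many_view',
                runtime_node_id='ai-chat-node', chat_record_id='some-uuid',
                child_node={}, real_node_id='ai-chat-node',
                reasoning_content='weigh the options')
print(answer.to_dict())
# {'view_type': 'many_view', 'content': 'Final answer',
#  'runtime_node_id': 'ai-chat-node', 'chat_record_id': 'some-uuid',
#  'child_node': {}, 'reasoning_content': 'weigh the options',
#  'real_node_id': 'ai-chat-node'}
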
8 changes: 6 additions & 2 deletions apps/application/flow/i_step_node.py
@@ -62,7 +62,9 @@ def handler(self, chat_id,
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = workflow.get_answer_text_list()
answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list)
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
if workflow.chat_record is not None:
chat_record = workflow.chat_record
chat_record.answer_text = answer_text
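
The reworked join above reflects get_answer_text_list() now returning a list of per-node answer lists rather than a flat list of answers; a minimal illustration with made-up values:

answer_text_list = [
    [{'content': 'Node 1 answer'}],
    [{'content': 'Node 2, part a'}, {'content': 'Node 2, part b'}],
]
answer_text = '\n\n'.join(
    '\n\n'.join([a.get('content') for a in answer]) for answer in answer_text_list)
# 'Node 1 answer\n\nNode 2, part a\n\nNode 2, part b'
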
@@ -157,8 +159,10 @@ def save_context(self, details, workflow_manage):
def get_answer_list(self) -> List[Answer] | None:
if self.answer_text is None:
return None
reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
return [
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})]
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]

def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
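
The gate in get_answer_list above withholds reasoning text unless the node's model_setting opts in; a worked example of the same expression with hypothetical context values:

context = {'model_setting': {'reasoning_content_enable': False},
           'reasoning_content': 'internal chain of thought'}
enable = context.get('model_setting', {}).get('reasoning_content_enable', False)
reasoning = context.get('reasoning_content', '') if enable else ''
# reasoning == '' -- the recorded reasoning is dropped unless explicitly enabled
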
@@ -28,8 +28,9 @@ class ChatNodeSerializer(serializers.Serializer):
error_messages=ErrMessage.boolean(_('Whether to return content')))

model_params_setting = serializers.DictField(required=False,
error_messages=ErrMessage.integer(_("Model parameter settings")))

error_messages=ErrMessage.dict(_("Model parameter settings")))
model_setting = serializers.DictField(required=False,
error_messages=ErrMessage.dict(_('Model settings')))
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char(_("Context Type")))

@@ -47,5 +48,6 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
chat_record_id,
model_params_setting=None,
dialogue_type=None,
model_setting=None,
**kwargs) -> NodeResult:
pass