Skip to content

Commit

Permalink
Fix guardrail and error handling (#392)
Browse files Browse the repository at this point in the history
  • Loading branch information
moria97 authored Feb 18, 2025
1 parent 765067a commit 490bd40
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 93 deletions.
203 changes: 112 additions & 91 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _make_chat_completion_response(session_id, response):
)


async def _make_chat_completion_response_with_text(session_id, text):
def _make_chat_completion_response_with_text(session_id, text):
logger.info(f"Finished response: {text}")
return ChatCompletion(
id=session_id,
Expand Down Expand Up @@ -253,7 +253,7 @@ async def _make_chat_completion_chunk_response_with_text(session_id, text):
role=MessageRole.ASSISTANT.value,
content=text,
),
finish_reason=None,
finish_reason="stop",
)
],
object="chat.completion.chunk",
Expand Down Expand Up @@ -342,24 +342,74 @@ async def achat(
self,
chat_request: ChatCompletionRequest,
):
if len(chat_request.messages) == 0:
raise Exception("消息列表为空.")
session_id = uuid_generator()

guardrail = resolve_llm_guardrail(self.config)
passed_guardrail = False if guardrail is not None else True
if len(chat_request.messages) == 0:
if chat_request.stream:
return _make_chat_completion_chunk_response_with_text(
session_id, "看起来你问了一个空问题,请问有什么能帮忙的吗?"
)
else:
return _make_chat_completion_response_with_text(
session_id, "看起来你问了一个空问题,请问有什么能帮忙的吗?"
)

try:
guardrail = resolve_llm_guardrail(self.config)
passed_guardrail = False if guardrail is not None else True

messages = chat_request.messages
system_prompt = None
if messages[0].role == MessageRole.SYSTEM:
system_prompt = messages[0].content
messages = messages[1:]

if not passed_guardrail:
user_messages = [
msg for msg in messages if msg.role == MessageRole.USER
]
# 只有一条对话,直接检查
if len(user_messages) == 1:
guardrail_result = await guardrail.acheck(user_messages[0].content)
if guardrail_result.reject:
if chat_request.stream:
return _make_chat_completion_chunk_response_with_text(
session_id, guardrail_result.advice
)
else:
return _make_chat_completion_response_with_text(
session_id, guardrail_result.advice
)
passed_guardrail = True

if self.config.system.default_web_search:
chat_request.search_web = True
session_config = self.config.model_copy()
index_entry = index_manager.get_index_by_name(chat_request.index_name)
session_config.embedding = index_entry.embedding_config
session_config.index.vector_store = index_entry.vector_store_config

question = messages[-1].content
chat_history = []
for msg in messages[:-1]:
if msg.role == MessageRole.USER:
role = "user"
else:
role = "bot"
chat_history.append({role: msg.content})

if not question:
return RagResponse(answer="请输入您的消息.", session_id=session_id)

messages = chat_request.messages
system_prompt = None
if messages[0].role == MessageRole.SYSTEM:
system_prompt = messages[0].content
messages = messages[1:]
openai_query_transform = resolve_openai_query_transform(session_config)
new_query_bundle = await openai_query_transform.arun(
chat_messages=messages,
)

if not passed_guardrail:
user_messages = [msg for msg in messages if msg.role == MessageRole.USER]
# 只有一条对话,直接检查
if len(user_messages) == 1:
guardrail_result = await guardrail.acheck(user_messages[0].content)
new_question = new_query_bundle.query_str
if not passed_guardrail:
# 多轮对话,用新查询检查
guardrail_result = await guardrail.acheck(new_question)
if guardrail_result.reject:
if chat_request.stream:
return _make_chat_completion_chunk_response_with_text(
Expand All @@ -371,73 +421,50 @@ async def achat(
)
passed_guardrail = True

if self.config.system.default_web_search:
chat_request.search_web = True
session_config = self.config.model_copy()
index_entry = index_manager.get_index_by_name(chat_request.index_name)
session_config.embedding = index_entry.embedding_config
session_config.index.vector_store = index_entry.vector_store_config
logger.info(f"Querying with question '{new_question}'.")
messages[-1].content = ",".join([question, new_question])

question = messages[-1].content
chat_history = []
for msg in messages[:-1]:
if msg.role == MessageRole.USER:
role = "user"
else:
role = "bot"
chat_history.append({role: msg.content})

if not question:
return RagResponse(answer="请输入您的消息.", session_id=session_id)
query_bundle = PaiQueryBundle(
query_str=new_question,
stream=chat_request.stream,
citation=chat_request.citation,
need_web_search=new_query_bundle.need_web_search,
chat_messages_str=messages_to_history_str(messages=messages[-8:]),
)

openai_query_transform = resolve_openai_query_transform(session_config)
new_query_bundle = await openai_query_transform.arun(
chat_messages=messages,
)
if chat_request.force_no_search:
chat_request.search_web = True
query_bundle.need_web_search = False
elif chat_request.force_search_web:
chat_request.search_web = True
query_bundle.need_web_search = True
elif chat_request.force_search_knowledgebase:
chat_request.search_web = False

if chat_request.search_web:
search_engine = resolve_searcher(session_config)
if not search_engine:
raise ValueError(
"AI search config is not valid. Please check your search api configuration."
)

new_question = new_query_bundle.query_str
if not passed_guardrail:
# 多轮对话,用新查询检查
guardrail_result = await guardrail.acheck(new_question)
if guardrail_result.reject:
response = await search_engine.aquery(
query_bundle,
system_role_str=system_prompt,
prompt_template_str=" " if system_prompt else None,
)
if chat_request.stream:
return _make_chat_completion_chunk_response_with_text(
session_id, guardrail_result.advice
return _make_chat_completion_chunk_response(
session_id=session_id,
response=response,
)
else:
return _make_chat_completion_response_with_text(
session_id, guardrail_result.advice
return _make_chat_completion_response(
session_id=session_id, response=response
)
passed_guardrail = True

logger.info(f"Querying with question '{new_question}'.")
messages[-1].content = ",".join([question, new_question])

query_bundle = PaiQueryBundle(
query_str=new_question,
stream=chat_request.stream,
citation=chat_request.citation,
need_web_search=new_query_bundle.need_web_search,
chat_messages_str=messages_to_history_str(messages=messages[-8:]),
)

if chat_request.force_no_search:
chat_request.search_web = True
query_bundle.need_web_search = False
elif chat_request.force_search_web:
chat_request.search_web = True
query_bundle.need_web_search = True
elif chat_request.force_search_knowledgebase:
chat_request.search_web = False

if chat_request.search_web:
search_engine = resolve_searcher(session_config)
if not search_engine:
raise ValueError(
"AI search config is not valid. Please check your search api configuration."
)

response = await search_engine.aquery(
query_engine = resolve_query_engine(session_config)
response = await query_engine.aquery(
query_bundle,
system_role_str=system_prompt,
prompt_template_str=" " if system_prompt else None,
Expand All @@ -451,22 +478,16 @@ async def achat(
return _make_chat_completion_response(
session_id=session_id, response=response
)

query_engine = resolve_query_engine(session_config)
response = await query_engine.aquery(
query_bundle,
system_role_str=system_prompt,
prompt_template_str=" " if system_prompt else None,
)
if chat_request.stream:
return _make_chat_completion_chunk_response(
session_id=session_id,
response=response,
)
else:
return _make_chat_completion_response(
session_id=session_id, response=response
)
except Exception as e:
logger.error(f"Error while processing request: {e}")
if chat_request.stream:
return _make_chat_completion_chunk_response_with_text(
session_id, "抱歉,系统错误,暂时无法处理这个请求。"
)
else:
return _make_chat_completion_response_with_text(
session_id, "抱歉,系统错误,暂时无法处理这个请求。"
)

async def aquery(
self,
Expand Down
15 changes: 13 additions & 2 deletions src/pai_rag/integrations/guardrail/pai_guardrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def acheck(self, text):
response = await self.client.text_moderation_plus_async(
textModerationPlusRequest
)
if response.status_code == 200:
if response.status_code == 200 and response.body.code == 200:
# 调用成功
risk_level = response.body.data.risk_level
reject = False
Expand Down Expand Up @@ -84,6 +84,17 @@ async def acheck(self, text):
response.status_code, response
)
)
return TextCheckResult(
reject=False,
reason="Check text failed.",
risk_level="low",
advice="",
)
except Exception as err:
logger.info(f"Unhandled error: check text failed due to {err}")
raise err
return TextCheckResult(
reject=False,
reason="Check text failed.",
risk_level="low",
advice="",
)

0 comments on commit 490bd40

Please sign in to comment.