diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 3e0ff89f459..401afd9eb70 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -228,8 +228,9 @@ def tts(): def stream_audio(): try: - for chunk in tts_mdl.tts(text): - yield chunk + for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): + for chunk in tts_mdl.tts(txt): + yield chunk except Exception as e: yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 69bf7f220d7..19fbbece279 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -93,24 +93,27 @@ def set_api_key(): if msg: return get_data_error_result(retmsg=msg) - llm = { + llm_config = { "api_key": req["api_key"], "api_base": req.get("base_url", "") } for n in ["model_type", "llm_name"]: if n in req: - llm[n] = req[n] + llm_config[n] = req[n] - if not TenantLLMService.filter_update( - [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm): - for llm in LLMService.query(fid=factory): + for llm in LLMService.query(fid=factory): + if not TenantLLMService.filter_update( + [TenantLLM.tenant_id == current_user.id, + TenantLLM.llm_factory == factory, + TenantLLM.llm_name == llm.llm_name], + llm_config): TenantLLMService.save( tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, - api_key=req["api_key"], - api_base=req.get("base_url", "") + api_key=llm_config["api_key"], + api_base=llm_config["api_base"] ) return get_json_result(data=True) diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 4ffbc521cff..1af100723fa 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -161,6 +161,7 @@ def on_event(self, result: SpeechSynthesisResult): class OpenAITTS(Base): def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): + if not base_url: base_url="https://api.openai.com/v1" self.api_key = key self.model_name = model_name self.base_url = base_url @@ -181,6 +182,6 @@ def tts(self, text, voice="alloy"): if response.status_code != 200: raise Exception(f"**Error**: {response.status_code}, {response.text}") - for chunk in response.iter_content(chunk_size=1024): + for chunk in response.iter_content(): if chunk: yield chunk