From 7539d142a9a9da8390290e5616d293da98eca8dc Mon Sep 17 00:00:00 2001 From: yungongzi Date: Mon, 26 Aug 2024 13:34:29 +0800 Subject: [PATCH] VolcEngine SDK V3 adaptation (#2082) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1) Configuration interface update 2) Back-end adaptation API update Note: The official no longer supports the Skylark1/2 series, and all have been switched to the Doubao series ![image](https://github.com/user-attachments/assets/f6fd8782-0cdf-4c0b-ac8f-9eb130f667a5) ### What problem does this PR solve? _Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._ ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe): Co-authored-by: 海贼宅 --- api/apps/llm_app.py | 11 ++- conf/llm_factories.json | 10 ++- rag/llm/chat_model.py | 68 ++----------------- web/src/locales/en.ts | 11 ++- web/src/locales/zh-traditional.ts | 10 +-- web/src/locales/zh.ts | 10 +-- .../setting-model/volcengine-modal/index.tsx | 20 +++--- 7 files changed, 44 insertions(+), 96 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index b705fc55171..53d530706f2 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -113,13 +113,10 @@ def add_llm(): if factory == "VolcEngine": # For VolcEngine, due to its special authentication method - # Assemble volc_ak, volc_sk, endpoint_id into api_key - temp = list(ast.literal_eval(req["llm_name"]).items())[0] - llm_name = temp[0] - endpoint_id = temp[1] - api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \ - f'"volc_sk": "{req.get("volc_sk", "")}", ' \ - f'"ep_id": "{endpoint_id}", ' + '}' + # Assemble ark_api_key endpoint_id into api_key + llm_name = req["llm_name"] + api_key = '{' + f'"ark_api_key": "{req.get("ark_api_key", "")}", ' \ + f'"ep_id": "{req.get("endpoint_id", "")}", ' + '}' elif factory == "Tencent Hunyuan": api_key = '{' + f'"hunyuan_sid": "{req.get("hunyuan_sid", "")}", ' \ f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}' diff --git a/conf/llm_factories.json b/conf/llm_factories.json index ff256317681..4af341ee201 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -349,13 +349,19 @@ "status": "1", "llm": [ { - "llm_name": "Skylark2-pro-32k", + "llm_name": "Doubao-pro-128k", + "tags": "LLM,CHAT,128k", + "max_tokens": 131072, + "model_type": "chat" + }, + { + "llm_name": "Doubao-pro-32k", "tags": "LLM,CHAT,32k", "max_tokens": 32768, "model_type": "chat" }, { - "llm_name": "Skylark2-pro-4k", + "llm_name": "Doubao-pro-4k", "tags": "LLM,CHAT,4k", "max_tokens": 4096, "model_type": "chat" diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 12c41a39921..64c39912ffa 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -450,72 +450,16 @@ def chat_streamly(self, system, history, gen_conf): class VolcEngineChat(Base): - def __init__(self, key, model_name, base_url): + def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'): """ Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special, - Assemble ak, sk, ep_id into api_key, store it as a dictionary type, and parse it for use + Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use model_name is for display only """ - self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing') - self.volc_ak = eval(key).get('volc_ak', '') - self.volc_sk = eval(key).get('volc_sk', '') - self.client.set_ak(self.volc_ak) - self.client.set_sk(self.volc_sk) - self.model_name = eval(key).get('ep_id', '') - - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - try: - req = { - "parameters": { - "min_new_tokens": gen_conf.get("min_new_tokens", 1), - "top_k": gen_conf.get("top_k", 0), - "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000), - "temperature": gen_conf.get("temperature", 0.1), - "max_new_tokens": gen_conf.get("max_tokens", 1000), - "top_p": gen_conf.get("top_p", 0.3), - }, - "messages": history - } - response = self.client.chat(self.model_name, req) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english( - [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ans, response.usage.total_tokens - except Exception as e: - return "**ERROR**: " + str(e), 0 - - def chat_streamly(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - ans = "" - tk_count = 0 - try: - req = { - "parameters": { - "min_new_tokens": gen_conf.get("min_new_tokens", 1), - "top_k": gen_conf.get("top_k", 0), - "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000), - "temperature": gen_conf.get("temperature", 0.1), - "max_new_tokens": gen_conf.get("max_tokens", 1000), - "top_p": gen_conf.get("top_p", 0.3), - }, - "messages": history - } - stream = self.client.stream_chat(self.model_name, req) - for resp in stream: - if not resp.choices[0].message.content: - continue - ans += resp.choices[0].message.content - if resp.choices[0].finish_reason == "stop": - tk_count = resp.usage.total_tokens - yield ans - - except Exception as e: - yield ans + "\n**ERROR**: " + str(e) - yield tk_count + base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3' + ark_api_key = eval(key).get('ark_api_key', '') + model_name = eval(key).get('ep_id', '') + super().__init__(ark_api_key, model_name, base_url) class MiniMaxChat(Base): diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 84566d4885d..093687216f9 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -502,12 +502,11 @@ The above is the content you need to summarize.`, baseUrlNameMessage: 'Please input your base url!', vision: 'Does it support Vision?', ollamaLink: 'How to integrate {{name}}', - volcModelNameMessage: - 'Please input your model name! Format: {"ModelName":"EndpointID"}', - addVolcEngineAK: 'VOLC ACCESS_KEY', - volcAKMessage: 'Please input your VOLC_ACCESS_KEY', - addVolcEngineSK: 'VOLC SECRET_KEY', - volcSKMessage: 'Please input your SECRET_KEY', + volcModelNameMessage: 'Please input your model name!', + addEndpointID: 'EndpointID of the model', + endpointIDMessage: 'Please input your EndpointID of the model', + addArkApiKey: 'VOLC ARK_API_KEY', + ArkApiKeyMessage: 'Please input your ARK_API_KEY', bedrockModelNameMessage: 'Please input your model name!', addBedrockEngineAK: 'ACCESS KEY', bedrockAKMessage: 'Please input your ACCESS KEY', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index 7ba9fbf1531..55adfc7db04 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -465,11 +465,11 @@ export default { modelTypeMessage: '請輸入模型類型!', baseUrlNameMessage: '請輸入基礎 Url!', ollamaLink: '如何集成 {{name}}', - volcModelNameMessage: '請輸入模型名稱!格式:{"模型名稱":"EndpointID"}', - addVolcEngineAK: '火山 ACCESS_KEY', - volcAKMessage: '請輸入VOLC_ACCESS_KEY', - addVolcEngineSK: '火山 SECRET_KEY', - volcSKMessage: '請輸入VOLC_SECRET_KEY', + volcModelNameMessage: '請輸入模型名稱!', + addEndpointID: '模型 EndpointID', + endpointIDMessage: '請輸入模型對應的EndpointID', + addArkApiKey: '火山 ARK_API_KEY', + ArkApiKeyMessage: '請輸入火山創建的ARK_API_KEY', bedrockModelNameMessage: '請輸入名稱!', addBedrockEngineAK: 'ACCESS KEY', bedrockAKMessage: '請輸入 ACCESS KEY', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 9f04fbddac1..72b96662913 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -482,11 +482,11 @@ export default { modelTypeMessage: '请输入模型类型!', baseUrlNameMessage: '请输入基础 Url!', ollamaLink: '如何集成 {{name}}', - volcModelNameMessage: '请输入模型名称!格式:{"模型名称":"EndpointID"}', - addVolcEngineAK: '火山 ACCESS_KEY', - volcAKMessage: '请输入VOLC_ACCESS_KEY', - addVolcEngineSK: '火山 SECRET_KEY', - volcSKMessage: '请输入VOLC_SECRET_KEY', + volcModelNameMessage: '请输入模型名称!', + addEndpointID: '模型 EndpointID', + endpointIDMessage: '请输入模型对应的EndpointID', + addArkApiKey: '火山 ARK_API_KEY', + ArkApiKeyMessage: '请输入火山创建的ARK_API_KEY', bedrockModelNameMessage: '请输入名称!', addBedrockEngineAK: 'ACCESS KEY', bedrockAKMessage: '请输入 ACCESS KEY', diff --git a/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx b/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx index e88c15496eb..181fc0e1e06 100644 --- a/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/volcengine-modal/index.tsx @@ -8,6 +8,8 @@ type FieldType = IAddLlmRequestBody & { vision: boolean; volc_ak: string; volc_sk: string; + endpoint_id: string; + ark_api_key: string; }; const { Option } = Select; @@ -51,7 +53,7 @@ const VolcEngineModal = ({ return ( @@ -88,18 +90,18 @@ const VolcEngineModal = ({ - label={t('addVolcEngineAK')} - name="volc_ak" - rules={[{ required: true, message: t('volcAKMessage') }]} + label={t('addEndpointID')} + name="endpoint_id" + rules={[{ required: true, message: t('endpointIDMessage') }]} > - + - label={t('addVolcEngineSK')} - name="volc_sk" - rules={[{ required: true, message: t('volcAKMessage') }]} + label={t('addArkApiKey')} + name="ark_api_key" + rules={[{ required: true, message: t('ArkApiKeyMessage') }]} > - + {({ getFieldValue }) =>