From 2b1b2a2213d93388a98571e83d44a9570e724647 Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Fri, 10 Jan 2025 14:51:39 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20`*=5FMODEL=5FLIST`?= =?UTF-8?q?=20env=20in=20new=20provider=20(#5350)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update * fix model list * fix tests * fix tests * improve code * update locales * update locales * fix * fix ui --- locales/ar/modelProvider.json | 2 +- locales/bg-BG/modelProvider.json | 2 +- locales/de-DE/modelProvider.json | 2 +- locales/en-US/modelProvider.json | 2 +- locales/es-ES/modelProvider.json | 2 +- locales/fa-IR/modelProvider.json | 2 +- locales/fr-FR/modelProvider.json | 2 +- locales/it-IT/modelProvider.json | 2 +- locales/ja-JP/modelProvider.json | 2 +- locales/ko-KR/modelProvider.json | 2 +- locales/nl-NL/modelProvider.json | 2 +- locales/pl-PL/modelProvider.json | 2 +- locales/pt-BR/modelProvider.json | 2 +- locales/ru-RU/modelProvider.json | 2 +- locales/tr-TR/modelProvider.json | 2 +- locales/vi-VN/modelProvider.json | 2 +- locales/zh-CN/modelProvider.json | 2 +- locales/zh-TW/modelProvider.json | 2 +- package.json | 4 +- src/config/aiModels/index.ts | 38 +++ src/config/modelProviders/index.ts | 3 + src/database/repositories/aiInfra/index.ts | 4 +- .../utils/openaiCompatibleFactory/index.ts | 1 + src/locales/default/modelProvider.ts | 2 +- src/migrations/FromV3ToV4/index.ts | 2 +- ...rLLMConfig.test.ts => _deprecated.test.ts} | 6 +- .../{genServerLLMConfig.ts => _deprecated.ts} | 2 +- .../globalConfig/genServerAiProviderConfig.ts | 42 +++ src/server/globalConfig/index.ts | 24 +- src/server/routers/lambda/aiModel.ts | 4 +- src/server/routers/lambda/aiProvider.ts | 4 +- src/types/aiModel.ts | 1 + src/types/serverConfig.ts | 1 + src/types/user/settings/modelProvider.ts | 2 + .../__snapshots__/parseModels.test.ts.snap | 42 ++- .../__snapshots__/parseModels.test.ts.snap | 112 +++++++ src/utils/_deprecated/parseModels.test.ts | 276 ++++++++++++++++++ src/utils/_deprecated/parseModels.ts | 161 ++++++++++ src/utils/parseModels.test.ts | 199 ++++++++++--- src/utils/parseModels.ts | 55 ++-- 40 files changed, 916 insertions(+), 105 deletions(-) rename src/server/globalConfig/{genServerLLMConfig.test.ts => _deprecated.test.ts} (94%) rename src/server/globalConfig/{genServerLLMConfig.ts => _deprecated.ts} (97%) create mode 100644 src/server/globalConfig/genServerAiProviderConfig.ts create mode 100644 src/utils/_deprecated/__snapshots__/parseModels.test.ts.snap create mode 100644 src/utils/_deprecated/parseModels.test.ts create mode 100644 src/utils/_deprecated/parseModels.ts diff --git a/locales/ar/modelProvider.json b/locales/ar/modelProvider.json index e87210e45129b..7b106c8f4670d 100644 --- a/locales/ar/modelProvider.json +++ b/locales/ar/modelProvider.json @@ -151,7 +151,7 @@ "title": "جارٍ تنزيل النموذج {{model}} " }, "endpoint": { - "desc": "أدخل عنوان واجهة برمجة التطبيقات الخاص بـ Ollama، إذا لم يتم تحديده محليًا، يمكن تركه فارغًا", + "desc": "يجب أن تحتوي على http(s)://، يمكن تركها فارغة إذا لم يتم تحديدها محليًا", "title": "عنوان وكيل الواجهة" }, "setup": { diff --git a/locales/bg-BG/modelProvider.json b/locales/bg-BG/modelProvider.json index 0779b950eef31..38f0dad8091a1 100644 --- a/locales/bg-BG/modelProvider.json +++ b/locales/bg-BG/modelProvider.json @@ -151,7 +151,7 @@ "title": "Изтегляне на модел {{model}} " }, "endpoint": { - "desc": "Въведете адрес на Ollama интерфейсния прокси, оставете празно, ако локално не е указано 
специално", + "desc": "Трябва да съдържа http(s)://, местният адрес може да остане празен, ако не е зададен допълнително", "title": "Адрес на прокси интерфейс" }, "setup": { diff --git a/locales/de-DE/modelProvider.json b/locales/de-DE/modelProvider.json index c5905fb4959ac..a6ae80f9549bc 100644 --- a/locales/de-DE/modelProvider.json +++ b/locales/de-DE/modelProvider.json @@ -151,7 +151,7 @@ "title": "Lade Modell {{model}} herunter" }, "endpoint": { - "desc": "Geben Sie die Proxy-Adresse der Ollama-Schnittstelle ein, leer lassen, wenn lokal nicht spezifiziert", + "desc": "Muss http(s):// enthalten, kann leer gelassen werden, wenn lokal nicht zusätzlich angegeben.", "title": "Schnittstellen-Proxy-Adresse" }, "setup": { diff --git a/locales/en-US/modelProvider.json b/locales/en-US/modelProvider.json index c09f1d8428122..d292b02887e25 100644 --- a/locales/en-US/modelProvider.json +++ b/locales/en-US/modelProvider.json @@ -151,7 +151,7 @@ "title": "Downloading model {{model}}" }, "endpoint": { - "desc": "Enter the Ollama interface proxy address, leave blank if not specified locally", + "desc": "Must include http(s)://; can be left blank if not specified locally.", "title": "Interface proxy address" }, "setup": { diff --git a/locales/es-ES/modelProvider.json b/locales/es-ES/modelProvider.json index 6f5fa34f73362..587184a429f6e 100644 --- a/locales/es-ES/modelProvider.json +++ b/locales/es-ES/modelProvider.json @@ -151,7 +151,7 @@ "title": "Descargando el modelo {{model}} " }, "endpoint": { - "desc": "Introduce la dirección del proxy de la interfaz de Ollama, déjalo en blanco si no se ha especificado localmente", + "desc": "Debe incluir http(s)://, se puede dejar vacío si no se especifica localmente", "title": "Dirección del proxy de la interfaz" }, "setup": { diff --git a/locales/fa-IR/modelProvider.json b/locales/fa-IR/modelProvider.json index 79457b89f2912..4d115094d511d 100644 --- a/locales/fa-IR/modelProvider.json +++ b/locales/fa-IR/modelProvider.json @@ -151,7 +151,7 @@ "title": "در حال دانلود مدل {{model}} " }, "endpoint": { - "desc": "آدرس پروکسی رابط Ollama را وارد کنید، اگر به صورت محلی تنظیم نشده است، می‌توانید خالی بگذارید", + "desc": "باید شامل http(s):// باشد، اگر محلی به طور اضافی مشخص نشده باشد می‌توان خالی گذاشت", "title": "آدرس سرویس Ollama" }, "setup": { diff --git a/locales/fr-FR/modelProvider.json b/locales/fr-FR/modelProvider.json index 27ec9dc52617a..e7ea42050dd17 100644 --- a/locales/fr-FR/modelProvider.json +++ b/locales/fr-FR/modelProvider.json @@ -151,7 +151,7 @@ "title": "Téléchargement du modèle {{model}} en cours" }, "endpoint": { - "desc": "Saisissez l'adresse du proxy Ollama, laissez vide si non spécifié localement", + "desc": "Doit inclure http(s)://, peut rester vide si non spécifié localement", "title": "Adresse du proxy" }, "setup": { diff --git a/locales/it-IT/modelProvider.json b/locales/it-IT/modelProvider.json index 9be4ad4142840..587aebdc93cea 100644 --- a/locales/it-IT/modelProvider.json +++ b/locales/it-IT/modelProvider.json @@ -151,7 +151,7 @@ "title": "Download del modello in corso {{model}}" }, "endpoint": { - "desc": "Inserisci l'indirizzo del proxy dell'interfaccia Ollama. 
Lascia vuoto se non specificato localmente", + "desc": "Deve includere http(s)://, può rimanere vuoto se non specificato localmente", "title": "Indirizzo del proxy dell'interfaccia" }, "setup": { diff --git a/locales/ja-JP/modelProvider.json b/locales/ja-JP/modelProvider.json index fde58db18e94e..8a1bcfaa36629 100644 --- a/locales/ja-JP/modelProvider.json +++ b/locales/ja-JP/modelProvider.json @@ -151,7 +151,7 @@ "title": "モデル{{model}}をダウンロード中" }, "endpoint": { - "desc": "Ollamaプロキシインターフェースアドレスを入力してください。ローカルで追加の指定がない場合は空白のままにしてください", + "desc": "http(s)://を含める必要があります。ローカルで特に指定がない場合は空白のままで構いません", "title": "プロキシインターフェースアドレス" }, "setup": { diff --git a/locales/ko-KR/modelProvider.json b/locales/ko-KR/modelProvider.json index 53ed7de115faa..80bb003e6ee3e 100644 --- a/locales/ko-KR/modelProvider.json +++ b/locales/ko-KR/modelProvider.json @@ -151,7 +151,7 @@ "title": "모델 {{model}} 다운로드 중" }, "endpoint": { - "desc": "Ollama 인터페이스 프록시 주소를 입력하세요. 로컬에서 별도로 지정하지 않은 경우 비워둘 수 있습니다", + "desc": "http(s)://를 포함해야 하며, 로컬에서 추가로 지정하지 않은 경우 비워둘 수 있습니다.", "title": "인터페이스 프록시 주소" }, "setup": { diff --git a/locales/nl-NL/modelProvider.json b/locales/nl-NL/modelProvider.json index 9ec6dabc856f4..551459e878d6a 100644 --- a/locales/nl-NL/modelProvider.json +++ b/locales/nl-NL/modelProvider.json @@ -151,7 +151,7 @@ "title": "Model {{model}} wordt gedownload" }, "endpoint": { - "desc": "Voer het Ollama interface proxyadres in, laat leeg indien niet specifiek aangegeven", + "desc": "Moet http(s):// bevatten, kan leeg gelaten worden als lokaal niet specifiek opgegeven", "title": "Interface Proxyadres" }, "setup": { diff --git a/locales/pl-PL/modelProvider.json b/locales/pl-PL/modelProvider.json index 81926714f8054..a7c57fb6a4186 100644 --- a/locales/pl-PL/modelProvider.json +++ b/locales/pl-PL/modelProvider.json @@ -151,7 +151,7 @@ "title": "Pobieranie modelu {{model}}" }, "endpoint": { - "desc": "Wprowadź adres rest API Ollama, jeśli lokalnie nie określono, pozostaw puste", + "desc": "Musi zawierać http(s)://, lokalnie, jeśli nie określono inaczej, można pozostawić puste", "title": "Adres proxy API" }, "setup": { diff --git a/locales/pt-BR/modelProvider.json b/locales/pt-BR/modelProvider.json index e7a9bfc4cd179..5b51fdcbd33ba 100644 --- a/locales/pt-BR/modelProvider.json +++ b/locales/pt-BR/modelProvider.json @@ -151,7 +151,7 @@ "title": "Baixando o modelo {{model}} " }, "endpoint": { - "desc": "Insira o endereço do proxy de interface da Ollama, se não foi especificado localmente, pode deixar em branco", + "desc": "Deve incluir http(s)://, pode deixar em branco se não houver especificação local adicional", "title": "Endereço do Proxy de Interface" }, "setup": { diff --git a/locales/ru-RU/modelProvider.json b/locales/ru-RU/modelProvider.json index 19848e035c2ee..1b9fb76e8ef03 100644 --- a/locales/ru-RU/modelProvider.json +++ b/locales/ru-RU/modelProvider.json @@ -151,7 +151,7 @@ "title": "Загрузка модели {{model}} " }, "endpoint": { - "desc": "Введите адрес прокси-интерфейса Ollama, если локально не указано иное, можете оставить пустым", + "desc": "Должен содержать http(s)://, если локально не указано иное, можно оставить пустым", "title": "Адрес прокси-интерфейса" }, "setup": { diff --git a/locales/tr-TR/modelProvider.json b/locales/tr-TR/modelProvider.json index 9a105367b7663..daaf1b9244bc8 100644 --- a/locales/tr-TR/modelProvider.json +++ b/locales/tr-TR/modelProvider.json @@ -151,7 +151,7 @@ "title": "正在下载模型 {{model}} " }, "endpoint": { - "desc": "Ollama arayüz proxy adresini girin, yerel olarak belirtilmemişse 
boş bırakılabilir", + "desc": "http(s):// içermelidir, yerel olarak belirtilmemişse boş bırakılabilir", "title": "Arayüz Proxy Adresi" }, "setup": { diff --git a/locales/vi-VN/modelProvider.json b/locales/vi-VN/modelProvider.json index adfbc6c5df1ff..b7f620d5bcc47 100644 --- a/locales/vi-VN/modelProvider.json +++ b/locales/vi-VN/modelProvider.json @@ -151,7 +151,7 @@ "title": "Đang tải mô hình {{model}}" }, "endpoint": { - "desc": "Nhập địa chỉ proxy API của Ollama, có thể để trống nếu không chỉ định cụ thể", + "desc": "Phải bao gồm http(s)://, có thể để trống nếu không chỉ định thêm cho địa phương", "title": "Địa chỉ proxy API" }, "setup": { diff --git a/locales/zh-CN/modelProvider.json b/locales/zh-CN/modelProvider.json index 533bd82235f1d..413fdabfc5d48 100644 --- a/locales/zh-CN/modelProvider.json +++ b/locales/zh-CN/modelProvider.json @@ -151,7 +151,7 @@ "title": "正在下载模型 {{model}} " }, "endpoint": { - "desc": "填入 Ollama 接口代理地址,本地未额外指定可留空", + "desc": "必须包含http(s)://,本地未额外指定可留空", "title": "Ollama 服务地址" }, "setup": { diff --git a/locales/zh-TW/modelProvider.json b/locales/zh-TW/modelProvider.json index d1772d398e122..016d6b9fa3ea7 100644 --- a/locales/zh-TW/modelProvider.json +++ b/locales/zh-TW/modelProvider.json @@ -151,7 +151,7 @@ "title": "正在下載模型 {{model}}" }, "endpoint": { - "desc": "填入 Ollama 接口代理地址,本地未額外指定可留空", + "desc": "必須包含http(s)://,本地未額外指定可留空", "title": "接口代理地址" }, "setup": { diff --git a/package.json b/package.json index 63513165114e0..6897d4c066d32 100644 --- a/package.json +++ b/package.json @@ -129,7 +129,7 @@ "@lobehub/chat-plugins-gateway": "^1.9.0", "@lobehub/icons": "^1.61.1", "@lobehub/tts": "^1.28.0", - "@lobehub/ui": "^1.163.0", + "@lobehub/ui": "^1.164.2", "@neondatabase/serverless": "^0.10.4", "@next/third-parties": "^15.1.4", "@react-spring/web": "^9.7.5", @@ -244,7 +244,7 @@ "@edge-runtime/vm": "^5.0.0", "@huggingface/tasks": "^0.13.13", "@lobehub/i18n-cli": "^1.20.3", - "@lobehub/lint": "^1.25.3", + "@lobehub/lint": "^1.25.5", "@lobehub/seo-cli": "^1.4.3", "@next/bundle-analyzer": "^15.1.4", "@next/eslint-plugin-next": "^15.1.4", diff --git a/src/config/aiModels/index.ts b/src/config/aiModels/index.ts index 66c7cdab0a898..ea0c113a12aed 100644 --- a/src/config/aiModels/index.ts +++ b/src/config/aiModels/index.ts @@ -96,3 +96,41 @@ export const LOBE_DEFAULT_MODEL_LIST = buildDefaultModelList({ zeroone, zhipu, }); + +export { default as ai21 } from './ai21'; +export { default as ai360 } from './ai360'; +export { default as anthropic } from './anthropic'; +export { default as azure } from './azure'; +export { default as baichuan } from './baichuan'; +export { default as bedrock } from './bedrock'; +export { default as cloudflare } from './cloudflare'; +export { default as deepseek } from './deepseek'; +export { default as fireworksai } from './fireworksai'; +export { default as giteeai } from './giteeai'; +export { default as github } from './github'; +export { default as google } from './google'; +export { default as groq } from './groq'; +export { default as higress } from './higress'; +export { default as huggingface } from './huggingface'; +export { default as hunyuan } from './hunyuan'; +export { default as internlm } from './internlm'; +export { default as minimax } from './minimax'; +export { default as mistral } from './mistral'; +export { default as moonshot } from './moonshot'; +export { default as novita } from './novita'; +export { default as ollama } from './ollama'; +export { default as openai } from './openai'; +export { default as openrouter } from 
'./openrouter'; +export { default as perplexity } from './perplexity'; +export { default as qwen } from './qwen'; +export { default as sensenova } from './sensenova'; +export { default as siliconcloud } from './siliconcloud'; +export { default as spark } from './spark'; +export { default as stepfun } from './stepfun'; +export { default as taichu } from './taichu'; +export { default as togetherai } from './togetherai'; +export { default as upstage } from './upstage'; +export { default as wenxin } from './wenxin'; +export { default as xai } from './xai'; +export { default as zeroone } from './zeroone'; +export { default as zhipu } from './zhipu'; diff --git a/src/config/modelProviders/index.ts b/src/config/modelProviders/index.ts index b4a77da4777a2..ebc2b12f28912 100644 --- a/src/config/modelProviders/index.ts +++ b/src/config/modelProviders/index.ts @@ -38,6 +38,9 @@ import XAIProvider from './xai'; import ZeroOneProvider from './zeroone'; import ZhiPuProvider from './zhipu'; +/** + * @deprecated + */ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [ OpenAIProvider.chatModels, QwenProvider.chatModels, diff --git a/src/database/repositories/aiInfra/index.ts b/src/database/repositories/aiInfra/index.ts index 7063a5af53030..bd05911aba24b 100644 --- a/src/database/repositories/aiInfra/index.ts +++ b/src/database/repositories/aiInfra/index.ts @@ -120,7 +120,9 @@ export class AiInfraRepos { ): Promise => { try { const { default: providerModels } = await import(`@/config/aiModels/${providerId}`); - return (providerModels as AIChatModelCard[]).map((m) => ({ + + const presetList = this.providerConfigs[providerId]?.serverModelLists || providerModels; + return (presetList as AIChatModelCard[]).map((m) => ({ ...m, enabled: m.enabled || false, source: AiModelSourceEnum.Builtin, diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index 48eef0beef3a9..e2a00164b886e 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -279,6 +279,7 @@ export const LobeOpenAICompatibleFactory = = any> return models.transformModel(item); } + // TODO: should refactor after remove v1 user/modelList code const knownModel = LOBE_DEFAULT_MODEL_LIST.find((model) => model.id === item.id); if (knownModel) { diff --git a/src/locales/default/modelProvider.ts b/src/locales/default/modelProvider.ts index 4d15d69210ed7..613df7996b5eb 100644 --- a/src/locales/default/modelProvider.ts +++ b/src/locales/default/modelProvider.ts @@ -152,7 +152,7 @@ export default { title: '正在下载模型 {{model}} ', }, endpoint: { - desc: '填入 Ollama 接口代理地址,本地未额外指定可留空', + desc: '必须包含http(s)://,本地未额外指定可留空', title: 'Ollama 服务地址', }, setup: { diff --git a/src/migrations/FromV3ToV4/index.ts b/src/migrations/FromV3ToV4/index.ts index 1714837e430af..72691cd683715 100644 --- a/src/migrations/FromV3ToV4/index.ts +++ b/src/migrations/FromV3ToV4/index.ts @@ -1,5 +1,5 @@ import type { Migration, MigrationData } from '@/migrations/VersionController'; -import { transformToChatModelCards } from '@/utils/parseModels'; +import { transformToChatModelCards } from '@/utils/_deprecated/parseModels'; import { V3ConfigState, V3LegacyConfig, V3OpenAIConfig, V3Settings } from './types/v3'; import { V4AzureOpenAIConfig, V4ConfigState, V4ProviderConfig, V4Settings } from './types/v4'; diff --git a/src/server/globalConfig/genServerLLMConfig.test.ts b/src/server/globalConfig/_deprecated.test.ts similarity 
index 94% rename from src/server/globalConfig/genServerLLMConfig.test.ts rename to src/server/globalConfig/_deprecated.test.ts index 141742cf959b5..5c1d95b9e4167 100644 --- a/src/server/globalConfig/genServerLLMConfig.test.ts +++ b/src/server/globalConfig/_deprecated.test.ts @@ -1,8 +1,6 @@ import { describe, expect, it, vi } from 'vitest'; -import { getLLMConfig } from '@/config/llm'; - -import { genServerLLMConfig } from './genServerLLMConfig'; +import { genServerLLMConfig } from './_deprecated'; // Mock ModelProvider enum vi.mock('@/libs/agent-runtime', () => ({ @@ -40,7 +38,7 @@ vi.mock('@/config/llm', () => ({ })); // Mock parse models utils -vi.mock('@/utils/parseModels', () => ({ +vi.mock('@/utils/_deprecated/parseModels', () => ({ extractEnabledModels: (modelString: string, withDeploymentName?: boolean) => { // Returns different format if withDeploymentName is true return withDeploymentName ? [`${modelString}_withDeployment`] : [modelString]; diff --git a/src/server/globalConfig/genServerLLMConfig.ts b/src/server/globalConfig/_deprecated.ts similarity index 97% rename from src/server/globalConfig/genServerLLMConfig.ts rename to src/server/globalConfig/_deprecated.ts index 2f3b1ac2da4d1..123672f15fa36 100644 --- a/src/server/globalConfig/genServerLLMConfig.ts +++ b/src/server/globalConfig/_deprecated.ts @@ -2,7 +2,7 @@ import { getLLMConfig } from '@/config/llm'; import * as ProviderCards from '@/config/modelProviders'; import { ModelProvider } from '@/libs/agent-runtime'; import { ModelProviderCard } from '@/types/llm'; -import { extractEnabledModels, transformToChatModelCards } from '@/utils/parseModels'; +import { extractEnabledModels, transformToChatModelCards } from '@/utils/_deprecated/parseModels'; export const genServerLLMConfig = (specificConfig: Record) => { const llmConfig = getLLMConfig() as Record; diff --git a/src/server/globalConfig/genServerAiProviderConfig.ts b/src/server/globalConfig/genServerAiProviderConfig.ts new file mode 100644 index 0000000000000..f02a4bdf96f99 --- /dev/null +++ b/src/server/globalConfig/genServerAiProviderConfig.ts @@ -0,0 +1,42 @@ +import * as AiModels from '@/config/aiModels'; +import { getLLMConfig } from '@/config/llm'; +import { ModelProvider } from '@/libs/agent-runtime'; +import { AiFullModelCard } from '@/types/aiModel'; +import { ProviderConfig } from '@/types/user/settings'; +import { extractEnabledModels, transformToAiChatModelList } from '@/utils/parseModels'; + +export const genServerAiProvidersConfig = (specificConfig: Record) => { + const llmConfig = getLLMConfig() as Record; + + return Object.values(ModelProvider).reduce( + (config, provider) => { + const providerUpperCase = provider.toUpperCase(); + const providerCard = AiModels[provider] as AiFullModelCard[]; + const providerConfig = specificConfig[provider as keyof typeof specificConfig] || {}; + const providerModelList = + process.env[providerConfig.modelListKey ?? 
`${providerUpperCase}_MODEL_LIST`]; + + const defaultChatModels = providerCard.filter((c) => c.type === 'chat'); + + config[provider] = { + enabled: llmConfig[providerConfig.enabledKey || `ENABLED_${providerUpperCase}`], + enabledModels: extractEnabledModels( + providerModelList, + providerConfig.withDeploymentName || false, + ), + serverModelLists: transformToAiChatModelList({ + defaultChatModels: defaultChatModels || [], + modelString: providerModelList, + providerId: provider, + withDeploymentName: providerConfig.withDeploymentName || false, + }), + ...(providerConfig.fetchOnClient !== undefined && { + fetchOnClient: providerConfig.fetchOnClient, + }), + }; + + return config; + }, + {} as Record, + ); +}; diff --git a/src/server/globalConfig/index.ts b/src/server/globalConfig/index.ts index bd46080f18b37..b2af9f6163157 100644 --- a/src/server/globalConfig/index.ts +++ b/src/server/globalConfig/index.ts @@ -6,19 +6,41 @@ import { enableNextAuth } from '@/const/auth'; import { parseSystemAgent } from '@/server/globalConfig/parseSystemAgent'; import { GlobalServerConfig } from '@/types/serverConfig'; -import { genServerLLMConfig } from './genServerLLMConfig'; +import { genServerLLMConfig } from './_deprecated'; +import { genServerAiProvidersConfig } from './genServerAiProviderConfig'; import { parseAgentConfig } from './parseDefaultAgent'; export const getServerGlobalConfig = () => { const { ACCESS_CODES, DEFAULT_AGENT_CONFIG } = getAppConfig(); const config: GlobalServerConfig = { + aiProvider: genServerAiProvidersConfig({ + azure: { + enabledKey: 'ENABLED_AZURE_OPENAI', + withDeploymentName: true, + }, + bedrock: { + enabledKey: 'ENABLED_AWS_BEDROCK', + modelListKey: 'AWS_BEDROCK_MODEL_LIST', + }, + giteeai: { + enabledKey: 'ENABLED_GITEE_AI', + modelListKey: 'GITEE_AI_MODEL_LIST', + }, + ollama: { + fetchOnClient: !process.env.OLLAMA_PROXY_URL, + }, + }), defaultAgent: { config: parseAgentConfig(DEFAULT_AGENT_CONFIG), }, enableUploadFileToServer: !!fileEnv.S3_SECRET_ACCESS_KEY, enabledAccessCode: ACCESS_CODES?.length > 0, + enabledOAuthSSO: enableNextAuth, + /** + * @deprecated + */ languageModel: genServerLLMConfig({ azure: { enabledKey: 'ENABLED_AZURE_OPENAI', diff --git a/src/server/routers/lambda/aiModel.ts b/src/server/routers/lambda/aiModel.ts index e777c7c05fdf3..e8fbfc5872c3f 100644 --- a/src/server/routers/lambda/aiModel.ts +++ b/src/server/routers/lambda/aiModel.ts @@ -19,14 +19,14 @@ const aiModelProcedure = authedProcedure.use(async (opts) => { const { ctx } = opts; const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const { languageModel } = getServerGlobalConfig(); + const { aiProvider } = getServerGlobalConfig(); return opts.next({ ctx: { aiInfraRepos: new AiInfraRepos( serverDB, ctx.userId, - languageModel as Record, + aiProvider as Record, ), aiModelModel: new AiModelModel(serverDB, ctx.userId), gateKeeper, diff --git a/src/server/routers/lambda/aiProvider.ts b/src/server/routers/lambda/aiProvider.ts index f865f2657bb52..3ffdc65678981 100644 --- a/src/server/routers/lambda/aiProvider.ts +++ b/src/server/routers/lambda/aiProvider.ts @@ -18,7 +18,7 @@ import { ProviderConfig } from '@/types/user/settings'; const aiProviderProcedure = authedProcedure.use(async (opts) => { const { ctx } = opts; - const { languageModel } = getServerGlobalConfig(); + const { aiProvider } = getServerGlobalConfig(); const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); return opts.next({ @@ -26,7 +26,7 @@ const aiProviderProcedure = authedProcedure.use(async (opts) => { 
aiInfraRepos: new AiInfraRepos( serverDB, ctx.userId, - languageModel as Record, + aiProvider as Record, ), aiProviderModel: new AiProviderModel(serverDB, ctx.userId), gateKeeper, diff --git a/src/types/aiModel.ts b/src/types/aiModel.ts index 42f65894881d7..e155a025c3635 100644 --- a/src/types/aiModel.ts +++ b/src/types/aiModel.ts @@ -230,6 +230,7 @@ export interface AIRealtimeModelCard extends AIBaseModelCard { export interface AiFullModelCard extends AIBaseModelCard { abilities?: ModelAbilities; + config?: AiModelConfig; contextWindowTokens?: number; displayName?: string; id: string; diff --git a/src/types/serverConfig.ts b/src/types/serverConfig.ts index 99a30df31650e..066d2c9ec3252 100644 --- a/src/types/serverConfig.ts +++ b/src/types/serverConfig.ts @@ -20,6 +20,7 @@ export interface ServerModelProviderConfig { export type ServerLanguageModel = Partial>; export interface GlobalServerConfig { + aiProvider?: ServerLanguageModel; defaultAgent?: DeepPartial; enableUploadFileToServer?: boolean; enabledAccessCode?: boolean; diff --git a/src/types/user/settings/modelProvider.ts b/src/types/user/settings/modelProvider.ts index fabd7f5b13658..476861ea00e9a 100644 --- a/src/types/user/settings/modelProvider.ts +++ b/src/types/user/settings/modelProvider.ts @@ -1,4 +1,5 @@ import { ModelProviderKey } from '@/libs/agent-runtime'; +import { AiFullModelCard } from '@/types/aiModel'; import { ChatModelCard } from '@/types/llm'; export interface ProviderConfig { @@ -27,6 +28,7 @@ export interface ProviderConfig { * fetched models from provider side */ remoteModelCards?: ChatModelCard[]; + serverModelLists?: AiFullModelCard[]; } export type GlobalLLMProviderKey = ModelProviderKey; diff --git a/src/utils/__snapshots__/parseModels.test.ts.snap b/src/utils/__snapshots__/parseModels.test.ts.snap index a96639a72be2a..c59bc28cff535 100644 --- a/src/utils/__snapshots__/parseModels.test.ts.snap +++ b/src/utils/__snapshots__/parseModels.test.ts.snap @@ -4,16 +4,22 @@ exports[`parseModelString > custom deletion, addition, and renaming of models 1` { "add": [ { + "abilities": {}, "displayName": undefined, "id": "llama", + "type": "chat", }, { + "abilities": {}, "displayName": undefined, "id": "claude-2", + "type": "chat", }, { + "abilities": {}, "displayName": "gpt-4-32k", "id": "gpt-4-1106-preview", + "type": "chat", }, ], "removeAll": true, @@ -28,8 +34,10 @@ exports[`parseModelString > duplicate naming model 1`] = ` { "add": [ { + "abilities": {}, "displayName": "gpt-4-32k", "id": "gpt-4-1106-preview", + "type": "chat", }, ], "removeAll": false, @@ -41,12 +49,16 @@ exports[`parseModelString > empty string model 1`] = ` { "add": [ { + "abilities": {}, "displayName": "gpt-4-turbo", "id": "gpt-4-1106-preview", + "type": "chat", }, { + "abilities": {}, "displayName": undefined, "id": "claude-2", + "type": "chat", }, ], "removeAll": false, @@ -58,20 +70,28 @@ exports[`parseModelString > only add the model 1`] = ` { "add": [ { + "abilities": {}, "displayName": undefined, "id": "model1", + "type": "chat", }, { + "abilities": {}, "displayName": undefined, "id": "model2", + "type": "chat", }, { + "abilities": {}, "displayName": undefined, "id": "model3", + "type": "chat", }, { + "abilities": {}, "displayName": undefined, "id": "model4", + "type": "chat", }, ], "removeAll": false, @@ -82,31 +102,43 @@ exports[`parseModelString > only add the model 1`] = ` exports[`transformToChatModelCards > should have file with builtin models like gpt-4-0125-preview 1`] = ` [ { + "abilities": { + "files": true, + "functionCall": 
true, + }, "contextWindowTokens": 128000, "description": "最新的 GPT-4 Turbo 模型具备视觉功能。现在,视觉请求可以使用 JSON 模式和函数调用。 GPT-4 Turbo 是一个增强版本,为多模态任务提供成本效益高的支持。它在准确性和效率之间找到平衡,适合需要进行实时交互的应用程序场景。", "displayName": "ChatGPT-4", "enabled": true, - "files": true, - "functionCall": true, "id": "gpt-4-0125-preview", "pricing": { "input": 10, "output": 30, }, + "providerId": "openai", + "releasedAt": "2024-01-25", + "source": "builtin", + "type": "chat", }, { + "abilities": { + "files": true, + "functionCall": true, + "vision": true, + }, "contextWindowTokens": 128000, "description": "最新的 GPT-4 Turbo 模型具备视觉功能。现在,视觉请求可以使用 JSON 模式和函数调用。 GPT-4 Turbo 是一个增强版本,为多模态任务提供成本效益高的支持。它在准确性和效率之间找到平衡,适合需要进行实时交互的应用程序场景。", "displayName": "ChatGPT-4 Vision", "enabled": true, - "files": true, - "functionCall": true, "id": "gpt-4-turbo-2024-04-09", "pricing": { "input": 10, "output": 30, }, - "vision": true, + "providerId": "openai", + "releasedAt": "2024-04-09", + "source": "builtin", + "type": "chat", }, ] `; diff --git a/src/utils/_deprecated/__snapshots__/parseModels.test.ts.snap b/src/utils/_deprecated/__snapshots__/parseModels.test.ts.snap new file mode 100644 index 0000000000000..a96639a72be2a --- /dev/null +++ b/src/utils/_deprecated/__snapshots__/parseModels.test.ts.snap @@ -0,0 +1,112 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`parseModelString > custom deletion, addition, and renaming of models 1`] = ` +{ + "add": [ + { + "displayName": undefined, + "id": "llama", + }, + { + "displayName": undefined, + "id": "claude-2", + }, + { + "displayName": "gpt-4-32k", + "id": "gpt-4-1106-preview", + }, + ], + "removeAll": true, + "removed": [ + "all", + "gpt-3.5-turbo", + ], +} +`; + +exports[`parseModelString > duplicate naming model 1`] = ` +{ + "add": [ + { + "displayName": "gpt-4-32k", + "id": "gpt-4-1106-preview", + }, + ], + "removeAll": false, + "removed": [], +} +`; + +exports[`parseModelString > empty string model 1`] = ` +{ + "add": [ + { + "displayName": "gpt-4-turbo", + "id": "gpt-4-1106-preview", + }, + { + "displayName": undefined, + "id": "claude-2", + }, + ], + "removeAll": false, + "removed": [], +} +`; + +exports[`parseModelString > only add the model 1`] = ` +{ + "add": [ + { + "displayName": undefined, + "id": "model1", + }, + { + "displayName": undefined, + "id": "model2", + }, + { + "displayName": undefined, + "id": "model3", + }, + { + "displayName": undefined, + "id": "model4", + }, + ], + "removeAll": false, + "removed": [], +} +`; + +exports[`transformToChatModelCards > should have file with builtin models like gpt-4-0125-preview 1`] = ` +[ + { + "contextWindowTokens": 128000, + "description": "最新的 GPT-4 Turbo 模型具备视觉功能。现在,视觉请求可以使用 JSON 模式和函数调用。 GPT-4 Turbo 是一个增强版本,为多模态任务提供成本效益高的支持。它在准确性和效率之间找到平衡,适合需要进行实时交互的应用程序场景。", + "displayName": "ChatGPT-4", + "enabled": true, + "files": true, + "functionCall": true, + "id": "gpt-4-0125-preview", + "pricing": { + "input": 10, + "output": 30, + }, + }, + { + "contextWindowTokens": 128000, + "description": "最新的 GPT-4 Turbo 模型具备视觉功能。现在,视觉请求可以使用 JSON 模式和函数调用。 GPT-4 Turbo 是一个增强版本,为多模态任务提供成本效益高的支持。它在准确性和效率之间找到平衡,适合需要进行实时交互的应用程序场景。", + "displayName": "ChatGPT-4 Vision", + "enabled": true, + "files": true, + "functionCall": true, + "id": "gpt-4-turbo-2024-04-09", + "pricing": { + "input": 10, + "output": 30, + }, + "vision": true, + }, +] +`; diff --git a/src/utils/_deprecated/parseModels.test.ts b/src/utils/_deprecated/parseModels.test.ts new file mode 100644 index 0000000000000..7adfca0925d0e --- /dev/null +++ 
b/src/utils/_deprecated/parseModels.test.ts @@ -0,0 +1,276 @@ +import { describe, expect, it } from 'vitest'; + +import { LOBE_DEFAULT_MODEL_LIST, OpenAIProviderCard } from '@/config/modelProviders'; +import { ChatModelCard } from '@/types/llm'; + +import { parseModelString, transformToChatModelCards } from './parseModels'; + +describe('parseModelString', () => { + it('custom deletion, addition, and renaming of models', () => { + const result = parseModelString( + '-all,+llama,+claude-2,-gpt-3.5-turbo,gpt-4-1106-preview=gpt-4-turbo,gpt-4-1106-preview=gpt-4-32k', + ); + + expect(result).toMatchSnapshot(); + }); + + it('duplicate naming model', () => { + const result = parseModelString('gpt-4-1106-preview=gpt-4-turbo,gpt-4-1106-preview=gpt-4-32k'); + expect(result).toMatchSnapshot(); + }); + + it('only add the model', () => { + const result = parseModelString('model1,model2,model3,model4'); + + expect(result).toMatchSnapshot(); + }); + + it('empty string model', () => { + const result = parseModelString('gpt-4-1106-preview=gpt-4-turbo,, ,\n ,+claude-2'); + expect(result).toMatchSnapshot(); + }); + + describe('extension capabilities', () => { + it('with token', () => { + const result = parseModelString('chatglm-6b=ChatGLM 6B<4096>'); + + expect(result.add[0]).toEqual({ + displayName: 'ChatGLM 6B', + id: 'chatglm-6b', + contextWindowTokens: 4096, + }); + }); + + it('token and function calling', () => { + const result = parseModelString('spark-v3.5=讯飞星火 v3.5<8192:fc>'); + + expect(result.add[0]).toEqual({ + displayName: '讯飞星火 v3.5', + functionCall: true, + id: 'spark-v3.5', + contextWindowTokens: 8192, + }); + }); + + it('multi models', () => { + const result = parseModelString( + 'gemini-1.5-flash-latest=Gemini 1.5 Flash<16000:vision>,gpt-4-all=ChatGPT Plus<128000:fc:vision:file>', + ); + + expect(result.add).toEqual([ + { + displayName: 'Gemini 1.5 Flash', + vision: true, + id: 'gemini-1.5-flash-latest', + contextWindowTokens: 16000, + }, + { + displayName: 'ChatGPT Plus', + vision: true, + functionCall: true, + files: true, + id: 'gpt-4-all', + contextWindowTokens: 128000, + }, + ]); + }); + + it('should have file with builtin models like gpt-4-0125-preview', () => { + const result = parseModelString( + '-all,+gpt-4-0125-preview=ChatGPT-4<128000:fc:file>,+gpt-4-turbo-2024-04-09=ChatGPT-4 Vision<128000:fc:vision:file>', + ); + expect(result.add).toEqual([ + { + displayName: 'ChatGPT-4', + files: true, + functionCall: true, + id: 'gpt-4-0125-preview', + contextWindowTokens: 128000, + }, + { + displayName: 'ChatGPT-4 Vision', + files: true, + functionCall: true, + id: 'gpt-4-turbo-2024-04-09', + contextWindowTokens: 128000, + vision: true, + }, + ]); + }); + + it('should handle empty extension capability value', () => { + const result = parseModelString('model1<1024:>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + }); + + it('should handle empty extension capability name', () => { + const result = parseModelString('model1<1024::file>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024, files: true }); + }); + + it('should handle duplicate extension capabilities', () => { + const result = parseModelString('model1<1024:vision:vision>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024, vision: true }); + }); + + it('should handle case-sensitive extension capability names', () => { + const result = parseModelString('model1<1024:VISION:FC:file>'); + expect(result.add[0]).toEqual({ id: 'model1', 
contextWindowTokens: 1024, files: true }); + }); + + it('should handle case-sensitive extension capability values', () => { + const result = parseModelString('model1<1024:vision:Fc:File>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024, vision: true }); + }); + + it('should handle empty angle brackets', () => { + const result = parseModelString('model1<>'); + expect(result.add[0]).toEqual({ id: 'model1' }); + }); + + it('should handle not close angle brackets', () => { + const result = parseModelString('model1<,model2'); + expect(result.add).toEqual([{ id: 'model1' }, { id: 'model2' }]); + }); + + it('should handle multi close angle brackets', () => { + const result = parseModelString('model1<>>,model2'); + expect(result.add).toEqual([{ id: 'model1' }, { id: 'model2' }]); + }); + + it('should handle only colon inside angle brackets', () => { + const result = parseModelString('model1<:>'); + expect(result.add[0]).toEqual({ id: 'model1' }); + }); + + it('should handle only non-digit characters inside angle brackets', () => { + const result = parseModelString('model1'); + expect(result.add[0]).toEqual({ id: 'model1' }); + }); + + it('should handle non-digit characters followed by digits inside angle brackets', () => { + const result = parseModelString('model1'); + expect(result.add[0]).toEqual({ id: 'model1' }); + }); + + it('should handle digits followed by non-colon characters inside angle brackets', () => { + const result = parseModelString('model1<1024abc>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + }); + + it('should handle digits followed by multiple colons inside angle brackets', () => { + const result = parseModelString('model1<1024::>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + }); + + it('should handle digits followed by a colon and non-letter characters inside angle brackets', () => { + const result = parseModelString('model1<1024:123>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + }); + + it('should handle digits followed by a colon and spaces inside angle brackets', () => { + const result = parseModelString('model1<1024: vision>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + }); + + it('should handle digits followed by multiple colons and spaces inside angle brackets', () => { + const result = parseModelString('model1<1024: : vision>'); + expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + }); + }); + + describe('deployment name', () => { + it('should have same deployment name as id', () => { + const result = parseModelString('model1=Model 1', true); + expect(result.add[0]).toEqual({ + id: 'model1', + displayName: 'Model 1', + deploymentName: 'model1', + }); + }); + + it('should have diff deployment name as id', () => { + const result = parseModelString('gpt-35-turbo->my-deploy=GPT 3.5 Turbo', true); + expect(result.add[0]).toEqual({ + id: 'gpt-35-turbo', + displayName: 'GPT 3.5 Turbo', + deploymentName: 'my-deploy', + }); + }); + }); +}); + +describe('transformToChatModelCards', () => { + const defaultChatModels: ChatModelCard[] = [ + { id: 'model1', displayName: 'Model 1', enabled: true }, + { id: 'model2', displayName: 'Model 2', enabled: false }, + ]; + + it('should return undefined when modelString is empty', () => { + const result = transformToChatModelCards({ + modelString: '', + defaultChatModels, + }); + expect(result).toBeUndefined(); + }); + + it('should remove all 
models when removeAll is true', () => { + const result = transformToChatModelCards({ + modelString: '-all', + defaultChatModels, + }); + expect(result).toEqual([]); + }); + + it('should remove specified models', () => { + const result = transformToChatModelCards({ + modelString: '-model1', + defaultChatModels, + }); + expect(result).toEqual([{ id: 'model2', displayName: 'Model 2', enabled: false }]); + }); + + it('should add a new known model', () => { + const knownModel = LOBE_DEFAULT_MODEL_LIST[0]; + const result = transformToChatModelCards({ + modelString: `${knownModel.id}`, + defaultChatModels, + }); + expect(result).toContainEqual({ + ...knownModel, + displayName: knownModel.displayName || knownModel.id, + enabled: true, + }); + }); + + it('should update an existing known model', () => { + const knownModel = LOBE_DEFAULT_MODEL_LIST[0]; + const result = transformToChatModelCards({ + modelString: `+${knownModel.id}=Updated Model`, + defaultChatModels: [knownModel], + }); + expect(result![0]).toEqual({ ...knownModel, displayName: 'Updated Model', enabled: true }); + }); + + it('should add a new custom model', () => { + const result = transformToChatModelCards({ + modelString: '+custom_model=Custom Model', + defaultChatModels, + }); + expect(result).toContainEqual({ + id: 'custom_model', + displayName: 'Custom Model', + enabled: true, + }); + }); + + it('should have file with builtin models like gpt-4-0125-preview', () => { + const result = transformToChatModelCards({ + modelString: + '-all,+gpt-4-0125-preview=ChatGPT-4<128000:fc:file>,+gpt-4-turbo-2024-04-09=ChatGPT-4 Vision<128000:fc:vision:file>', + defaultChatModels: OpenAIProviderCard.chatModels, + }); + + expect(result).toMatchSnapshot(); + }); +}); diff --git a/src/utils/_deprecated/parseModels.ts b/src/utils/_deprecated/parseModels.ts new file mode 100644 index 0000000000000..7e965902998b8 --- /dev/null +++ b/src/utils/_deprecated/parseModels.ts @@ -0,0 +1,161 @@ +import { produce } from 'immer'; + +import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders'; +import { ChatModelCard } from '@/types/llm'; + +/** + * Parse model string to add or remove models. + */ +export const parseModelString = (modelString: string = '', withDeploymentName = false) => { + let models: ChatModelCard[] = []; + let removeAll = false; + const removedModels: string[] = []; + const modelNames = modelString.split(/[,,]/).filter(Boolean); + + for (const item of modelNames) { + const disable = item.startsWith('-'); + const nameConfig = item.startsWith('+') || item.startsWith('-') ? item.slice(1) : item; + const [idAndDisplayName, ...capabilities] = nameConfig.split('<'); + let [id, displayName] = idAndDisplayName.split('='); + + let deploymentName: string | undefined; + + if (withDeploymentName) { + [id, deploymentName] = id.split('->'); + if (!deploymentName) deploymentName = id; + } + + if (disable) { + // Disable all models. + if (id === 'all') { + removeAll = true; + } + removedModels.push(id); + continue; + } + + // remove empty model name + if (!item.trim().length) { + continue; + } + + // Remove duplicate model entries. 
+ const existingIndex = models.findIndex(({ id: n }) => n === id); + if (existingIndex !== -1) { + models.splice(existingIndex, 1); + } + + const model: ChatModelCard = { + displayName: displayName || undefined, + id, + }; + + if (deploymentName) { + model.deploymentName = deploymentName; + } + + if (capabilities.length > 0) { + const [maxTokenStr, ...capabilityList] = capabilities[0].replace('>', '').split(':'); + model.contextWindowTokens = parseInt(maxTokenStr, 10) || undefined; + + for (const capability of capabilityList) { + switch (capability) { + case 'vision': { + model.vision = true; + break; + } + case 'fc': { + model.functionCall = true; + break; + } + case 'file': { + model.files = true; + break; + } + default: { + console.warn(`Unknown capability: ${capability}`); + } + } + } + } + + models.push(model); + } + + return { + add: models, + removeAll, + removed: removedModels, + }; +}; + +/** + * Extract a special method to process chatModels + */ +export const transformToChatModelCards = ({ + modelString = '', + defaultChatModels, + withDeploymentName = false, +}: { + defaultChatModels: ChatModelCard[]; + modelString?: string; + withDeploymentName?: boolean; +}): ChatModelCard[] | undefined => { + if (!modelString) return undefined; + + const modelConfig = parseModelString(modelString, withDeploymentName); + let chatModels = modelConfig.removeAll ? [] : defaultChatModels; + + // 处理移除逻辑 + if (!modelConfig.removeAll) { + chatModels = chatModels.filter((m) => !modelConfig.removed.includes(m.id)); + } + + return produce(chatModels, (draft) => { + // 处理添加或替换逻辑 + for (const toAddModel of modelConfig.add) { + // first try to find the model in LOBE_DEFAULT_MODEL_LIST to confirm if it is a known model + const knownModel = LOBE_DEFAULT_MODEL_LIST.find((model) => model.id === toAddModel.id); + + // if the model is known, update it based on the known model + if (knownModel) { + const index = draft.findIndex((model) => model.id === toAddModel.id); + const modelInList = draft[index]; + + // if the model is already in chatModels, update it + if (modelInList) { + draft[index] = { + ...modelInList, + ...toAddModel, + displayName: toAddModel.displayName || modelInList.displayName || modelInList.id, + enabled: true, + }; + } else { + // if the model is not in chatModels, add it + draft.push({ + ...knownModel, + ...toAddModel, + displayName: toAddModel.displayName || knownModel.displayName || knownModel.id, + enabled: true, + }); + } + } else { + // if the model is not in LOBE_DEFAULT_MODEL_LIST, add it as a new custom model + draft.push({ + ...toAddModel, + displayName: toAddModel.displayName || toAddModel.id, + enabled: true, + }); + } + } + }); +}; + +export const extractEnabledModels = (modelString: string = '', withDeploymentName = false) => { + const modelConfig = parseModelString(modelString, withDeploymentName); + const list = modelConfig.add.map((m) => m.id); + + if (list.length === 0) return; + + return list; +}; diff --git a/src/utils/parseModels.test.ts b/src/utils/parseModels.test.ts index 7adfca0925d0e..a11a05c50f5b3 100644 --- a/src/utils/parseModels.test.ts +++ b/src/utils/parseModels.test.ts @@ -1,9 +1,10 @@ import { describe, expect, it } from 'vitest'; -import { LOBE_DEFAULT_MODEL_LIST, OpenAIProviderCard } from '@/config/modelProviders'; -import { ChatModelCard } from '@/types/llm'; +import { LOBE_DEFAULT_MODEL_LIST } from '@/config/aiModels'; +import { openaiChatModels } from '@/config/aiModels/openai'; +import { AiFullModelCard } from '@/types/aiModel'; -import { 
parseModelString, transformToChatModelCards } from './parseModels'; +import { parseModelString, transformToAiChatModelList } from './parseModels'; describe('parseModelString', () => { it('custom deletion, addition, and renaming of models', () => { @@ -38,6 +39,8 @@ describe('parseModelString', () => { displayName: 'ChatGLM 6B', id: 'chatglm-6b', contextWindowTokens: 4096, + abilities: {}, + type: 'chat', }); }); @@ -46,9 +49,12 @@ describe('parseModelString', () => { expect(result.add[0]).toEqual({ displayName: '讯飞星火 v3.5', - functionCall: true, + abilities: { + functionCall: true, + }, id: 'spark-v3.5', contextWindowTokens: 8192, + type: 'chat', }); }); @@ -60,15 +66,21 @@ describe('parseModelString', () => { expect(result.add).toEqual([ { displayName: 'Gemini 1.5 Flash', - vision: true, + abilities: { + vision: true, + }, id: 'gemini-1.5-flash-latest', contextWindowTokens: 16000, + type: 'chat', }, { displayName: 'ChatGPT Plus', - vision: true, - functionCall: true, - files: true, + abilities: { + vision: true, + functionCall: true, + files: true, + }, + type: 'chat', id: 'gpt-4-all', contextWindowTokens: 128000, }, @@ -82,100 +94,170 @@ describe('parseModelString', () => { expect(result.add).toEqual([ { displayName: 'ChatGPT-4', - files: true, - functionCall: true, + abilities: { + functionCall: true, + files: true, + }, + type: 'chat', id: 'gpt-4-0125-preview', contextWindowTokens: 128000, }, { displayName: 'ChatGPT-4 Vision', - files: true, - functionCall: true, + abilities: { + functionCall: true, + files: true, + vision: true, + }, + type: 'chat', id: 'gpt-4-turbo-2024-04-09', contextWindowTokens: 128000, - vision: true, }, ]); }); it('should handle empty extension capability value', () => { const result = parseModelString('model1<1024:>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + expect(result.add[0]).toEqual({ + abilities: {}, + type: 'chat', + id: 'model1', + contextWindowTokens: 1024, + }); }); it('should handle empty extension capability name', () => { const result = parseModelString('model1<1024::file>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024, files: true }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: { + files: true, + }, + type: 'chat', + }); }); it('should handle duplicate extension capabilities', () => { const result = parseModelString('model1<1024:vision:vision>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024, vision: true }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: { + vision: true, + }, + type: 'chat', + }); }); it('should handle case-sensitive extension capability names', () => { const result = parseModelString('model1<1024:VISION:FC:file>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024, files: true }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: { + files: true, + }, + type: 'chat', + }); }); it('should handle case-sensitive extension capability values', () => { const result = parseModelString('model1<1024:vision:Fc:File>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024, vision: true }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: { + vision: true, + }, + type: 'chat', + }); }); it('should handle empty angle brackets', () => { const result = parseModelString('model1<>'); - expect(result.add[0]).toEqual({ id: 'model1' 
}); + expect(result.add[0]).toEqual({ id: 'model1', abilities: {}, type: 'chat' }); }); it('should handle not close angle brackets', () => { const result = parseModelString('model1<,model2'); - expect(result.add).toEqual([{ id: 'model1' }, { id: 'model2' }]); + expect(result.add).toEqual([ + { id: 'model1', abilities: {}, type: 'chat' }, + { id: 'model2', abilities: {}, type: 'chat' }, + ]); }); it('should handle multi close angle brackets', () => { const result = parseModelString('model1<>>,model2'); - expect(result.add).toEqual([{ id: 'model1' }, { id: 'model2' }]); + expect(result.add).toEqual([ + { id: 'model1', abilities: {}, type: 'chat' }, + { id: 'model2', abilities: {}, type: 'chat' }, + ]); }); it('should handle only colon inside angle brackets', () => { const result = parseModelString('model1<:>'); - expect(result.add[0]).toEqual({ id: 'model1' }); + expect(result.add[0]).toEqual({ id: 'model1', abilities: {}, type: 'chat' }); }); it('should handle only non-digit characters inside angle brackets', () => { const result = parseModelString('model1'); - expect(result.add[0]).toEqual({ id: 'model1' }); + expect(result.add[0]).toEqual({ id: 'model1', abilities: {}, type: 'chat' }); }); it('should handle non-digit characters followed by digits inside angle brackets', () => { const result = parseModelString('model1'); - expect(result.add[0]).toEqual({ id: 'model1' }); + expect(result.add[0]).toEqual({ id: 'model1', abilities: {}, type: 'chat' }); }); it('should handle digits followed by non-colon characters inside angle brackets', () => { const result = parseModelString('model1<1024abc>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: {}, + type: 'chat', + }); }); it('should handle digits followed by multiple colons inside angle brackets', () => { const result = parseModelString('model1<1024::>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: {}, + type: 'chat', + }); }); it('should handle digits followed by a colon and non-letter characters inside angle brackets', () => { const result = parseModelString('model1<1024:123>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: {}, + type: 'chat', + }); }); it('should handle digits followed by a colon and spaces inside angle brackets', () => { const result = parseModelString('model1<1024: vision>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: {}, + type: 'chat', + }); }); it('should handle digits followed by multiple colons and spaces inside angle brackets', () => { const result = parseModelString('model1<1024: : vision>'); - expect(result.add[0]).toEqual({ id: 'model1', contextWindowTokens: 1024 }); + expect(result.add[0]).toEqual({ + id: 'model1', + contextWindowTokens: 1024, + abilities: {}, + type: 'chat', + }); }); }); @@ -185,7 +267,11 @@ describe('parseModelString', () => { expect(result.add[0]).toEqual({ id: 'model1', displayName: 'Model 1', - deploymentName: 'model1', + abilities: {}, + type: 'chat', + config: { + deploymentName: 'model1', + }, }); }); @@ -194,48 +280,59 @@ describe('parseModelString', () => { expect(result.add[0]).toEqual({ id: 
'gpt-35-turbo', displayName: 'GPT 3.5 Turbo', - deploymentName: 'my-deploy', + abilities: {}, + type: 'chat', + config: { + deploymentName: 'my-deploy', + }, }); }); }); }); describe('transformToChatModelCards', () => { - const defaultChatModels: ChatModelCard[] = [ - { id: 'model1', displayName: 'Model 1', enabled: true }, - { id: 'model2', displayName: 'Model 2', enabled: false }, + const defaultChatModels: AiFullModelCard[] = [ + { id: 'model1', displayName: 'Model 1', enabled: true, type: 'chat' }, + { id: 'model2', displayName: 'Model 2', enabled: false, type: 'chat' }, ]; it('should return undefined when modelString is empty', () => { - const result = transformToChatModelCards({ + const result = transformToAiChatModelList({ modelString: '', defaultChatModels, + providerId: 'openai', }); expect(result).toBeUndefined(); }); it('should remove all models when removeAll is true', () => { - const result = transformToChatModelCards({ + const result = transformToAiChatModelList({ modelString: '-all', defaultChatModels, + providerId: 'openai', }); expect(result).toEqual([]); }); it('should remove specified models', () => { - const result = transformToChatModelCards({ + const result = transformToAiChatModelList({ modelString: '-model1', defaultChatModels, + providerId: 'openai', }); - expect(result).toEqual([{ id: 'model2', displayName: 'Model 2', enabled: false }]); + expect(result).toEqual([ + { id: 'model2', displayName: 'Model 2', enabled: false, type: 'chat' }, + ]); }); it('should add a new known model', () => { - const knownModel = LOBE_DEFAULT_MODEL_LIST[0]; - const result = transformToChatModelCards({ + const knownModel = LOBE_DEFAULT_MODEL_LIST.find((m) => m.providerId === 'ai21')!; + const result = transformToAiChatModelList({ modelString: `${knownModel.id}`, defaultChatModels, + providerId: 'ai21', }); + expect(result).toContainEqual({ ...knownModel, displayName: knownModel.displayName || knownModel.id, @@ -244,31 +341,41 @@ describe('transformToChatModelCards', () => { }); it('should update an existing known model', () => { - const knownModel = LOBE_DEFAULT_MODEL_LIST[0]; - const result = transformToChatModelCards({ + const knownModel = LOBE_DEFAULT_MODEL_LIST.find((m) => m.providerId === 'openai')!; + const result = transformToAiChatModelList({ modelString: `+${knownModel.id}=Updated Model`, defaultChatModels: [knownModel], + providerId: 'openai', + }); + + expect(result).toContainEqual({ + ...knownModel, + displayName: 'Updated Model', + enabled: true, }); - expect(result![0]).toEqual({ ...knownModel, displayName: 'Updated Model', enabled: true }); }); it('should add a new custom model', () => { - const result = transformToChatModelCards({ + const result = transformToAiChatModelList({ modelString: '+custom_model=Custom Model', defaultChatModels, + providerId: 'openai', }); expect(result).toContainEqual({ id: 'custom_model', displayName: 'Custom Model', enabled: true, + abilities: {}, + type: 'chat', }); }); it('should have file with builtin models like gpt-4-0125-preview', () => { - const result = transformToChatModelCards({ + const result = transformToAiChatModelList({ modelString: '-all,+gpt-4-0125-preview=ChatGPT-4<128000:fc:file>,+gpt-4-turbo-2024-04-09=ChatGPT-4 Vision<128000:fc:vision:file>', - defaultChatModels: OpenAIProviderCard.chatModels, + defaultChatModels: openaiChatModels, + providerId: 'openai', }); expect(result).toMatchSnapshot(); diff --git a/src/utils/parseModels.ts b/src/utils/parseModels.ts index 7e965902998b8..8b130394a3bf2 100644 --- 
a/src/utils/parseModels.ts +++ b/src/utils/parseModels.ts @@ -1,13 +1,14 @@ import { produce } from 'immer'; -import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders'; -import { ChatModelCard } from '@/types/llm'; +import { LOBE_DEFAULT_MODEL_LIST } from '@/config/aiModels'; +import { AiFullModelCard } from '@/types/aiModel'; +import { merge } from '@/utils/merge'; /** * Parse model string to add or remove models. */ export const parseModelString = (modelString: string = '', withDeploymentName = false) => { - let models: ChatModelCard[] = []; + let models: AiFullModelCard[] = []; let removeAll = false; const removedModels: string[] = []; const modelNames = modelString.split(/[,,]/).filter(Boolean); @@ -45,13 +46,16 @@ export const parseModelString = (modelString: string = '', withDeploymentName = models.splice(existingIndex, 1); } - const model: ChatModelCard = { + const model: AiFullModelCard = { + abilities: {}, displayName: displayName || undefined, id, + // TODO: 临时写死为 chat ,后续基于元数据迭代成对应的类型 + type: 'chat', }; if (deploymentName) { - model.deploymentName = deploymentName; + model.config = { deploymentName }; } if (capabilities.length > 0) { @@ -61,15 +65,15 @@ export const parseModelString = (modelString: string = '', withDeploymentName = for (const capability of capabilityList) { switch (capability) { case 'vision': { - model.vision = true; + model.abilities!.vision = true; break; } case 'fc': { - model.functionCall = true; + model.abilities!.functionCall = true; break; } case 'file': { - model.files = true; + model.abilities!.files = true; break; } default: { @@ -92,15 +96,17 @@ export const parseModelString = (modelString: string = '', withDeploymentName = /** * Extract a special method to process chatModels */ -export const transformToChatModelCards = ({ +export const transformToAiChatModelList = ({ modelString = '', defaultChatModels, + providerId, withDeploymentName = false, }: { - defaultChatModels: ChatModelCard[]; + defaultChatModels: AiFullModelCard[]; modelString?: string; + providerId: string; withDeploymentName?: boolean; -}): ChatModelCard[] | undefined => { +}): AiFullModelCard[] | undefined => { if (!modelString) return undefined; const modelConfig = parseModelString(modelString, withDeploymentName); @@ -115,7 +121,14 @@ export const transformToChatModelCards = ({ // 处理添加或替换逻辑 for (const toAddModel of modelConfig.add) { // first try to find the model in LOBE_DEFAULT_MODEL_LIST to confirm if it is a known model - const knownModel = LOBE_DEFAULT_MODEL_LIST.find((model) => model.id === toAddModel.id); + let knownModel = LOBE_DEFAULT_MODEL_LIST.find( + (model) => model.id === toAddModel.id && model.providerId === providerId, + ); + + if (!knownModel) { + knownModel = LOBE_DEFAULT_MODEL_LIST.find((model) => model.id === toAddModel.id); + if (knownModel) knownModel.providerId = providerId; + } // if the model is known, update it based on the known model if (knownModel) { @@ -124,20 +137,20 @@ export const transformToChatModelCards = ({ // if the model is already in chatModels, update it if (modelInList) { - draft[index] = { - ...modelInList, + draft[index] = merge(modelInList, { ...toAddModel, displayName: toAddModel.displayName || modelInList.displayName || modelInList.id, enabled: true, - }; + }); } else { // if the model is not in chatModels, add it - draft.push({ - ...knownModel, - ...toAddModel, - displayName: toAddModel.displayName || knownModel.displayName || knownModel.id, - enabled: true, - }); + draft.push( + merge(knownModel, { + ...toAddModel, + 
displayName: toAddModel.displayName || knownModel.displayName || knownModel.id, + enabled: true, + }), + ); } } else { // if the model is not in LOBE_DEFAULT_MODEL_LIST, add it as a new custom model
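
---

For reviewers, a minimal usage sketch of the reworked parser (not part of the patch itself): it uses only the imports already present in the updated tests, and the `OPENAI_MODEL_LIST`-style env value below is an illustrative example lifted from the test cases.

    import { openaiChatModels } from '@/config/aiModels/openai';
    import { parseModelString, transformToAiChatModelList } from '@/utils/parseModels';

    // Hypothetical OPENAI_MODEL_LIST value: drop all builtin models, then add two
    // entries using the `id=DisplayName<contextWindowTokens:capability>` syntax.
    const modelString =
      '-all,+gpt-4-0125-preview=ChatGPT-4<128000:fc:file>,+gpt-4-turbo-2024-04-09=ChatGPT-4 Vision<128000:fc:vision:file>';

    // New-style cards nest capabilities under `abilities` and carry `type: 'chat'`,
    // instead of the flat `functionCall`/`files`/`vision` booleans of the v1 cards.
    const models = transformToAiChatModelList({
      defaultChatModels: openaiChatModels,
      modelString,
      providerId: 'openai',
    });
    // models?.[0] -> { id: 'gpt-4-0125-preview', abilities: { functionCall: true, files: true }, type: 'chat', ... }

    // With withDeploymentName (Azure), an `id->deployment` pair now lands in
    // `config.deploymentName` rather than a top-level `deploymentName` field.
    const azure = parseModelString('gpt-35-turbo->my-deploy=GPT 3.5 Turbo', true);
    // azure.add[0].config -> { deploymentName: 'my-deploy' }

The same string syntax drives the new `genServerAiProvidersConfig`, which resolves the env value per provider as `process.env[providerConfig.modelListKey ?? `${providerUpperCase}_MODEL_LIST`]` (e.g. the bedrock override `AWS_BEDROCK_MODEL_LIST`), which is what makes the `*_MODEL_LIST` variables work for the new provider settings this PR fixes.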