From 1c20137b0e930c8a7171abaabfa857465db77ad0 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 5 Jul 2024 19:59:45 +0800 Subject: [PATCH 1/6] support azure deployment name --- app/api/auth.ts | 2 +- app/api/common.ts | 19 +++++++-------- app/client/api.ts | 19 ++++++++++----- app/client/platforms/openai.ts | 41 ++++++++++++++++++++++++++++++++- app/components/chat.tsx | 35 ++++++++++++++++++---------- app/components/exporter.tsx | 11 +++++---- app/components/home.tsx | 7 +++--- app/components/model-config.tsx | 19 +++++++-------- app/constant.ts | 12 ++++++++++ app/store/access.ts | 7 +++++- app/store/chat.ts | 10 ++++---- app/store/config.ts | 2 ++ app/utils/checkers.ts | 21 ----------------- app/utils/hooks.ts | 7 +++++- next.config.mjs | 5 ++++ 15 files changed, 143 insertions(+), 74 deletions(-) delete mode 100644 app/utils/checkers.ts diff --git a/app/api/auth.ts b/app/api/auth.ts index b750f2d1731..2b4702aedc3 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -75,7 +75,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { break; case ModelProvider.GPT: default: - if (serverConfig.isAzure) { + if (req.nextUrl.pathname.includes("azure/deployments")) { systemApiKey = serverConfig.azureApiKey; } else { systemApiKey = serverConfig.apiKey; diff --git a/app/api/common.ts b/app/api/common.ts index 1454fde2ed1..17b5f916533 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -14,9 +14,11 @@ const serverConfig = getServerSideConfig(); export async function requestOpenai(req: NextRequest) { const controller = new AbortController(); + const isAzure = req.nextUrl.pathname.includes("azure/deployments"); + var authValue, authHeaderName = ""; - if (serverConfig.isAzure) { + if (isAzure) { authValue = req.headers .get("Authorization") @@ -56,14 +58,13 @@ export async function requestOpenai(req: NextRequest) { 10 * 60 * 1000, ); - if (serverConfig.isAzure) { - if (!serverConfig.azureApiVersion) { - return NextResponse.json({ - error: true, - message: `missing AZURE_API_VERSION in server env vars`, - }); - } - path = makeAzurePath(path, serverConfig.azureApiVersion); + if (isAzure) { + const azureApiVersion = req?.nextUrl?.searchParams?.get("api-version"); + baseUrl = baseUrl.split("/deployments").shift(); + path = `${req.nextUrl.pathname.replaceAll( + "/api/azure/", + "", + )}?api-version=${azureApiVersion}`; } const fetchUrl = `${baseUrl}/${path}`; diff --git a/app/client/api.ts b/app/client/api.ts index edee993424a..896880fa3f6 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -30,6 +30,7 @@ export interface RequestMessage { export interface LLMConfig { model: string; + providerName?: string; temperature?: number; top_p?: number; stream?: boolean; @@ -54,6 +55,7 @@ export interface LLMUsage { export interface LLMModel { name: string; + displayName?: string; available: boolean; provider: LLMModelProvider; } @@ -160,10 +162,14 @@ export function getHeaders() { Accept: "application/json", }; const modelConfig = useChatStore.getState().currentSession().mask.modelConfig; - const isGoogle = modelConfig.model.startsWith("gemini"); - const isAzure = accessStore.provider === ServiceProvider.Azure; - const isAnthropic = accessStore.provider === ServiceProvider.Anthropic; - const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization"; + const isGoogle = modelConfig.providerName == ServiceProvider.Azure; + const isAzure = modelConfig.providerName === ServiceProvider.Azure; + const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; + const authHeader = isAzure + ? "api-key" + : isAnthropic + ? "x-api-key" + : "Authorization"; const apiKey = isGoogle ? accessStore.googleApiKey : isAzure @@ -172,7 +178,8 @@ export function getHeaders() { ? accessStore.anthropicApiKey : accessStore.openaiApiKey; const clientConfig = getClientConfig(); - const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`; + const makeBearer = (s: string) => + `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`; const validString = (x: string) => x && x.length > 0; // when using google api in app, not set auth header @@ -185,7 +192,7 @@ export function getHeaders() { validString(accessStore.accessCode) ) { // access_code must send with header named `Authorization`, will using in auth middleware. - headers['Authorization'] = makeBearer( + headers["Authorization"] = makeBearer( ACCESS_CODE_PREFIX + accessStore.accessCode, ); } diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index f3599263023..25097e3baa6 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -1,13 +1,16 @@ "use client"; +// azure and openai, using same models. so using same LLMApi. import { ApiPath, DEFAULT_API_HOST, DEFAULT_MODELS, OpenaiPath, + Azure, REQUEST_TIMEOUT_MS, ServiceProvider, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; +import { collectModelsWithDefaultModel } from "@/app/utils/model"; import { ChatOptions, @@ -97,6 +100,15 @@ export class ChatGPTApi implements LLMApi { return [baseUrl, path].join("/"); } + getBaseUrl(apiPath: string) { + const isApp = !!getClientConfig()?.isApp; + let baseUrl = isApp ? DEFAULT_API_HOST + "/proxy" + apiPath : apiPath; + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + return baseUrl + "/"; + } + extractMessage(res: any) { return res.choices?.at(0)?.message?.content ?? ""; } @@ -113,6 +125,7 @@ export class ChatGPTApi implements LLMApi { ...useChatStore.getState().currentSession().mask.modelConfig, ...{ model: options.config.model, + providerName: options.config.providerName, }, }; @@ -140,7 +153,33 @@ export class ChatGPTApi implements LLMApi { options.onController?.(controller); try { - const chatPath = this.path(OpenaiPath.ChatPath); + let chatPath = ""; + if (modelConfig.providerName == ServiceProvider.Azure) { + // find model, and get displayName as deployName + const { models: configModels, customModels: configCustomModels } = + useAppConfig.getState(); + const { defaultModel, customModels: accessCustomModels } = + useAccessStore.getState(); + + const models = collectModelsWithDefaultModel( + configModels, + [configCustomModels, accessCustomModels].join(","), + defaultModel, + ); + const model = models.find( + (model) => + model.name == modelConfig.model && + model?.provider.providerName == ServiceProvider.Azure, + ); + chatPath = + this.getBaseUrl(ApiPath.Azure) + + Azure.ChatPath( + model?.displayName ?? model.name, + useAccessStore.getState().azureApiVersion, + ); + } else { + chatPath = this.getBaseUrl(ApiPath.OpenAI) + OpenaiPath.ChatPath; + } const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 06119250465..b1bdf757f44 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -88,6 +88,7 @@ import { Path, REQUEST_TIMEOUT_MS, UNFINISHED_INPUT, + ServiceProvider, } from "../constant"; import { Avatar } from "./emoji"; import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask"; @@ -448,6 +449,9 @@ export function ChatActions(props: { // switch model const currentModel = chatStore.currentSession().mask.modelConfig.model; + const currentProviderName = + chatStore.currentSession().mask.modelConfig?.providerName || + ServiceProvider.OpenAI; const allModels = useAllModels(); const models = useMemo(() => { const filteredModels = allModels.filter((m) => m.available); @@ -479,13 +483,13 @@ export function ChatActions(props: { const isUnavaliableModel = !models.some((m) => m.name === currentModel); if (isUnavaliableModel && models.length > 0) { // show next model to default model if exist - let nextModel: ModelType = ( - models.find((model) => model.isDefault) || models[0] - ).name; - chatStore.updateCurrentSession( - (session) => (session.mask.modelConfig.model = nextModel), - ); - showToast(nextModel); + let nextModel = models.find((model) => model.isDefault) || models[0]; + chatStore.updateCurrentSession((session) => { + session.mask.modelConfig.model = nextModel.name; + session.mask.modelConfig.providerName = nextModel?.provider + ?.providerName as ServiceProvider; + }); + showToast(nextModel.name); } }, [chatStore, currentModel, models]); @@ -573,19 +577,26 @@ export function ChatActions(props: { {showModelSelector && ( ({ - title: m.displayName, - value: m.name, + title: `${m.displayName}${ + m?.provider?.providerName + ? "(" + m?.provider?.providerName + ")" + : "" + }`, + value: `${m.name}@${m?.provider?.providerName}`, }))} onClose={() => setShowModelSelector(false)} onSelection={(s) => { if (s.length === 0) return; + const [model, providerName] = s[0].split("@"); chatStore.updateCurrentSession((session) => { - session.mask.modelConfig.model = s[0] as ModelType; + session.mask.modelConfig.model = model as ModelType; + session.mask.modelConfig.providerName = + providerName as ServiceProvider; session.mask.syncGlobalConfig = false; }); - showToast(s[0]); + showToast(model); }} /> )} diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 20e240d93b0..7281fc2f12d 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -36,11 +36,14 @@ import { toBlob, toPng } from "html-to-image"; import { DEFAULT_MASK_AVATAR } from "../store/mask"; import { prettyObject } from "../utils/format"; -import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant"; +import { + EXPORT_MESSAGE_CLASS_NAME, + ModelProvider, + ServiceProvider, +} from "../constant"; import { getClientConfig } from "../config/client"; import { ClientApi } from "../client/api"; import { getMessageTextContent } from "../utils"; -import { identifyDefaultClaudeModel } from "../utils/checkers"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { loading: () => , @@ -314,9 +317,9 @@ export function PreviewActions(props: { setShouldExport(false); var api: ClientApi; - if (config.modelConfig.model.startsWith("gemini")) { + if (config.modelConfig.providerName == ServiceProvider.Google) { api = new ClientApi(ModelProvider.GeminiPro); - } else if (identifyDefaultClaudeModel(config.modelConfig.model)) { + } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); } else { api = new ClientApi(ModelProvider.GPT); diff --git a/app/components/home.tsx b/app/components/home.tsx index ffac64fdac0..addb5e80373 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg"; import { getCSSVar, useMobileScreen } from "../utils"; import dynamic from "next/dynamic"; -import { ModelProvider, Path, SlotID } from "../constant"; +import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant"; import { ErrorBoundary } from "./error"; import { getISOLang, getLang } from "../locales"; @@ -29,7 +29,6 @@ import { AuthPage } from "./auth"; import { getClientConfig } from "../config/client"; import { ClientApi } from "../client/api"; import { useAccessStore } from "../store"; -import { identifyDefaultClaudeModel } from "../utils/checkers"; export function Loading(props: { noLogo?: boolean }) { return ( @@ -172,9 +171,9 @@ export function useLoadData() { const config = useAppConfig(); var api: ClientApi; - if (config.modelConfig.model.startsWith("gemini")) { + if (config.modelConfig.providerName == ServiceProvider.Google) { api = new ClientApi(ModelProvider.GeminiPro); - } else if (identifyDefaultClaudeModel(config.modelConfig.model)) { + } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { api = new ClientApi(ModelProvider.Claude); } else { api = new ClientApi(ModelProvider.GPT); diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index e46a018f463..346fd3a7129 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -1,3 +1,4 @@ +import { ServiceProvider } from "@/app/constant"; import { ModalConfigValidator, ModelConfig } from "../store"; import Locale from "../locales"; @@ -10,25 +11,25 @@ export function ModelConfigList(props: { updateConfig: (updater: (config: ModelConfig) => void) => void; }) { const allModels = useAllModels(); + const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`; return ( <>