Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support azure deployment name #4930

Merged
merged 6 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/api/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
break;
case ModelProvider.GPT:
default:
if (serverConfig.isAzure) {
if (req.nextUrl.pathname.includes("azure/deployments")) {
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
systemApiKey = serverConfig.azureApiKey;
} else {
systemApiKey = serverConfig.apiKey;
Expand Down
57 changes: 57 additions & 0 deletions app/api/azure/[...path]/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { getServerSideConfig } from "@/app/config/server";
import { ModelProvider } from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import { requestOpenai } from "../../common";

async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[Azure Route] params ", params);

if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}

const subpath = params.path.join("/");

const authResult = auth(req, ModelProvider.GPT);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}

try {
return await requestOpenai(req);
} catch (e) {
console.error("[Azure] ", e);
return NextResponse.json(prettyObject(e));
}
}

export const GET = handle;
export const POST = handle;

export const runtime = "edge";
export const preferredRegion = [
"arn1",
"bom1",
"cdg1",
"cle1",
"cpt1",
"dub1",
"fra1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"lhr1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];
20 changes: 10 additions & 10 deletions app/api/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import {
ServiceProvider,
} from "../constant";
import { isModelAvailableInServer } from "../utils/model";
import { makeAzurePath } from "../azure";

const serverConfig = getServerSideConfig();

export async function requestOpenai(req: NextRequest) {
const controller = new AbortController();

const isAzure = req.nextUrl.pathname.includes("azure/deployments");
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved

var authValue,
authHeaderName = "";
if (serverConfig.isAzure) {
if (isAzure) {
authValue =
req.headers
.get("Authorization")
Expand Down Expand Up @@ -56,14 +57,13 @@ export async function requestOpenai(req: NextRequest) {
10 * 60 * 1000,
);

if (serverConfig.isAzure) {
if (!serverConfig.azureApiVersion) {
return NextResponse.json({
error: true,
message: `missing AZURE_API_VERSION in server env vars`,
});
}
path = makeAzurePath(path, serverConfig.azureApiVersion);
if (isAzure) {
const azureApiVersion = req?.nextUrl?.searchParams?.get("api-version");
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
baseUrl = baseUrl.split("/deployments").shift() as string;
path = `${req.nextUrl.pathname.replaceAll(
"/api/azure/",
"",
)}?api-version=${azureApiVersion}`;
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
}

const fetchUrl = `${baseUrl}/${path}`;
Expand Down
9 changes: 0 additions & 9 deletions app/azure.ts

This file was deleted.

19 changes: 13 additions & 6 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export interface RequestMessage {

export interface LLMConfig {
model: string;
providerName?: string;
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
temperature?: number;
top_p?: number;
stream?: boolean;
Expand All @@ -54,6 +55,7 @@ export interface LLMUsage {

export interface LLMModel {
name: string;
displayName?: string;
available: boolean;
provider: LLMModelProvider;
}
Expand Down Expand Up @@ -160,10 +162,14 @@ export function getHeaders() {
Accept: "application/json",
};
const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
const isGoogle = modelConfig.model.startsWith("gemini");
const isAzure = accessStore.provider === ServiceProvider.Azure;
const isAnthropic = accessStore.provider === ServiceProvider.Anthropic;
const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization";
const isGoogle = modelConfig.providerName == ServiceProvider.Azure;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apiClient生成headers的时候,使用modelConfig.providerName判断ServiceProvider

lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
const isAzure = modelConfig.providerName === ServiceProvider.Azure;
const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
const authHeader = isAzure
? "api-key"
: isAnthropic
? "x-api-key"
: "Authorization";
const apiKey = isGoogle
? accessStore.googleApiKey
: isAzure
Expand All @@ -172,7 +178,8 @@ export function getHeaders() {
? accessStore.anthropicApiKey
: accessStore.openaiApiKey;
const clientConfig = getClientConfig();
const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const makeBearer = (s: string) =>
`${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const validString = (x: string) => x && x.length > 0;

// when using google api in app, not set auth header
Expand All @@ -185,7 +192,7 @@ export function getHeaders() {
validString(accessStore.accessCode)
) {
// access_code must send with header named `Authorization`, will using in auth middleware.
headers['Authorization'] = makeBearer(
headers["Authorization"] = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
Expand Down
51 changes: 39 additions & 12 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"use client";
// azure and openai, using same models. so using same LLMApi.
import {
ApiPath,
DEFAULT_API_HOST,
DEFAULT_MODELS,
OpenaiPath,
Azure,
REQUEST_TIMEOUT_MS,
ServiceProvider,
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model";

import {
ChatOptions,
Expand All @@ -24,7 +27,6 @@ import {
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client";
import { makeAzurePath } from "@/app/azure";
import {
getMessageTextContent,
getMessageImages,
Expand Down Expand Up @@ -62,33 +64,31 @@ export class ChatGPTApi implements LLMApi {

let baseUrl = "";

const isAzure = path.includes("deployments");
if (accessStore.useCustomConfig) {
const isAzure = accessStore.provider === ServiceProvider.Azure;

if (isAzure && !accessStore.isValidAzure()) {
throw Error(
"incomplete azure config, please check it in your settings page",
);
}

if (isAzure) {
path = makeAzurePath(path, accessStore.azureApiVersion);
}

baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl;
}

if (baseUrl.length === 0) {
const isApp = !!getClientConfig()?.isApp;
baseUrl = isApp
? DEFAULT_API_HOST + "/proxy" + ApiPath.OpenAI
: ApiPath.OpenAI;
const apiPath = isAzure ? ApiPath.Azure : ApiPath.OpenAI;
baseUrl = isApp ? DEFAULT_API_HOST + "/proxy" + apiPath : apiPath;
}

if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.OpenAI)) {
if (
!baseUrl.startsWith("http") &&
!isAzure &&
!baseUrl.startsWith(ApiPath.OpenAI)
) {
baseUrl = "https://" + baseUrl;
}

Expand All @@ -113,6 +113,7 @@ export class ChatGPTApi implements LLMApi {
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.config.model,
providerName: options.config.providerName,
},
};

Expand Down Expand Up @@ -140,7 +141,33 @@ export class ChatGPTApi implements LLMApi {
options.onController?.(controller);

try {
const chatPath = this.path(OpenaiPath.ChatPath);
let chatPath = "";
if (modelConfig.providerName == ServiceProvider.Azure) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

azure和openai使用的模型都是GPT,所以都使用ChatGPTApi

这里,按modelConfig.providerName判断是否是Azure。
同时尝试从allModels里面获取displayName作为deploymentName使用。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deploymentName会组合在url里面

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那么这里最新的azure 配置是什么样 需要更新一下readme

// find model, and get displayName as deployName
const { models: configModels, customModels: configCustomModels } =
useAppConfig.getState();
const { defaultModel, customModels: accessCustomModels } =
useAccessStore.getState();

const models = collectModelsWithDefaultModel(
configModels,
[configCustomModels, accessCustomModels].join(","),
defaultModel,
);
const model = models.find(
(model) =>
model.name == modelConfig.model &&
model?.provider?.providerName == ServiceProvider.Azure,
);
chatPath = this.path(
Azure.ChatPath(
(model?.displayName ?? model?.name) as string,
useAccessStore.getState().azureApiVersion,
),
);
} else {
chatPath = this.path(OpenaiPath.ChatPath);
}
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
Expand Down
35 changes: 23 additions & 12 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ import {
Path,
REQUEST_TIMEOUT_MS,
UNFINISHED_INPUT,
ServiceProvider,
} from "../constant";
import { Avatar } from "./emoji";
import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask";
Expand Down Expand Up @@ -448,6 +449,9 @@ export function ChatActions(props: {

// switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model;
const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName ||
ServiceProvider.OpenAI;
const allModels = useAllModels();
const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available);
Expand Down Expand Up @@ -479,13 +483,13 @@ export function ChatActions(props: {
const isUnavaliableModel = !models.some((m) => m.name === currentModel);
if (isUnavaliableModel && models.length > 0) {
// show next model to default model if exist
let nextModel: ModelType = (
models.find((model) => model.isDefault) || models[0]
).name;
chatStore.updateCurrentSession(
(session) => (session.mask.modelConfig.model = nextModel),
);
showToast(nextModel);
let nextModel = models.find((model) => model.isDefault) || models[0];
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = nextModel.name;
session.mask.modelConfig.providerName = nextModel?.provider
?.providerName as ServiceProvider;
});
showToast(nextModel.name);
}
}, [chatStore, currentModel, models]);

Expand Down Expand Up @@ -573,19 +577,26 @@ export function ChatActions(props: {

{showModelSelector && (
<Selector
defaultSelectedValue={currentModel}
defaultSelectedValue={`${currentModel}@${currentProviderName}`}
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
items={models.map((m) => ({
title: m.displayName,
value: m.name,
title: `${m.displayName}${
m?.provider?.providerName
? "(" + m?.provider?.providerName + ")"
: ""
}`,
value: `${m.name}@${m?.provider?.providerName}`,
}))}
onClose={() => setShowModelSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
const [model, providerName] = s[0].split("@");
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.model = s[0] as ModelType;
session.mask.modelConfig.model = model as ModelType;
session.mask.modelConfig.providerName =
providerName as ServiceProvider;
session.mask.syncGlobalConfig = false;
});
showToast(s[0]);
showToast(model);
}}
/>
)}
Expand Down
11 changes: 7 additions & 4 deletions app/components/exporter.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ import { toBlob, toPng } from "html-to-image";
import { DEFAULT_MASK_AVATAR } from "../store/mask";

import { prettyObject } from "../utils/format";
import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant";
import {
EXPORT_MESSAGE_CLASS_NAME,
ModelProvider,
ServiceProvider,
} from "../constant";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { getMessageTextContent } from "../utils";
import { identifyDefaultClaudeModel } from "../utils/checkers";

const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
loading: () => <LoadingIcon />,
Expand Down Expand Up @@ -314,9 +317,9 @@ export function PreviewActions(props: {
setShouldExport(false);

var api: ClientApi;
if (config.modelConfig.model.startsWith("gemini")) {
if (config.modelConfig.providerName == ServiceProvider.Google) {
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
api = new ClientApi(ModelProvider.GeminiPro);
} else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
Expand Down
Loading
Loading