add dalle3 model #5173

Merged
merged 5 commits on Aug 5, 2024
Changes from 4 commits
3 changes: 2 additions & 1 deletion app/client/api.ts
@@ -6,7 +6,7 @@ import {
ServiceProvider,
} from "../constant";
import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
import { ChatGPTApi } from "./platforms/openai";
import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
import { GeminiProApi } from "./platforms/google";
import { ClaudeApi } from "./platforms/anthropic";
import { ErnieApi } from "./platforms/baidu";
@@ -42,6 +42,7 @@ export interface LLMConfig {
stream?: boolean;
presence_penalty?: number;
frequency_penalty?: number;
size?: DalleRequestPayload["size"];
}

export interface ChatOptions {
113 changes: 83 additions & 30 deletions app/client/platforms/openai.ts
@@ -11,8 +11,13 @@ import {
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model";
import { preProcessImageContent } from "@/app/utils/chat";
import {
preProcessImageContent,
uploadImage,
base64Image2Blob,
} from "@/app/utils/chat";
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
import { DalleSize } from "@/app/typing";

import {
ChatOptions,
@@ -33,6 +38,7 @@ import {
getMessageTextContent,
getMessageImages,
isVisionModel,
isDalle3 as _isDalle3,
} from "@/app/utils";

export interface OpenAIListModelResponse {
@@ -58,6 +64,14 @@ export interface RequestPayload {
max_tokens?: number;
}

export interface DalleRequestPayload {
model: string;
prompt: string;
response_format: "url" | "b64_json";
n: number;
size: DalleSize;
}

export class ChatGPTApi implements LLMApi {
private disableListModels = true;

@@ -100,20 +114,31 @@ export class ChatGPTApi implements LLMApi {
return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
}

extractMessage(res: any) {
async extractMessage(res: any) {
if (res.error) {
return "```\n" + JSON.stringify(res, null, 4) + "\n```";
}
// the dall-e-3 model returns an image URL (or base64 data); use it to create an image message
if (res.data) {
let url = res.data?.at(0)?.url ?? "";
const b64_json = res.data?.at(0)?.b64_json ?? "";
if (!url && b64_json) {
// uploadImage
url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
}
return [
{
type: "image_url",
image_url: {
url,
},
},
];
}
return res.choices?.at(0)?.message?.content ?? "";
}

async chat(options: ChatOptions) {
const visionModel = isVisionModel(options.config.model);
const messages: ChatOptions["messages"] = [];
for (const v of options.messages) {
const content = visionModel
? await preProcessImageContent(v.content)
: getMessageTextContent(v);
messages.push({ role: v.role, content });
}

const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
@@ -123,26 +148,52 @@
},
};

const requestPayload: RequestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
// max_tokens is intentionally omitted; sending it has caused more problems than it is worth.
};
let requestPayload: RequestPayload | DalleRequestPayload;

const isDalle3 = _isDalle3(options.config.model);
if (isDalle3) {
const prompt = getMessageTextContent(
options.messages.slice(-1)?.pop() as any,
);
requestPayload = {
model: options.config.model,
prompt,
// URLs are only valid for 60 minutes after the image has been generated.
response_format: "b64_json", // using b64_json, and save image in CacheStorage
n: 1,
size: options.config?.size ?? "1024x1024",
};
} else {
const visionModel = isVisionModel(options.config.model);
const messages: ChatOptions["messages"] = [];
for (const v of options.messages) {
const content = visionModel
? await preProcessImageContent(v.content)
: getMessageTextContent(v);
messages.push({ role: v.role, content });
}

// add max_tokens to vision model
if (visionModel && modelConfig.model.includes("preview")) {
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
requestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
// max_tokens is intentionally omitted; sending it has caused more problems than it is worth.
};

// add max_tokens to vision model
if (visionModel && modelConfig.model.includes("preview")) {
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
}
}

console.log("[Request] openai payload: ", requestPayload);

const shouldStream = !!options.config.stream;
const shouldStream = !isDalle3 && !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);

@@ -168,13 +219,15 @@
model?.provider?.providerName === ServiceProvider.Azure,
);
chatPath = this.path(
Azure.ChatPath(
(isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
(model?.displayName ?? model?.name) as string,
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
),
);
} else {
chatPath = this.path(OpenaiPath.ChatPath);
chatPath = this.path(
isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
);
}
const chatPayload = {
method: "POST",
@@ -186,7 +239,7 @@
// make a fetch request
const requestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dall-e-3 with b64_json responses is slow, so double the timeout
);

if (shouldStream) {
@@ -317,7 +370,7 @@
clearTimeout(requestTimeoutId);

const resJson = await res.json();
const message = this.extractMessage(resJson);
const message = await this.extractMessage(resJson);
options.onFinish(message);
}
} catch (e) {
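
Taken together, the openai.ts changes route dall-e-3 requests to the images endpoint, skip streaming, and turn the base64 response into a multimodal image message. Below is a minimal, self-contained sketch of the resulting payload and response handling, assuming the shapes shown in the diff; the prompt text and the small helper are illustrative only, not part of the PR.

```ts
// Illustrative sketch, not part of the diff: the shape of a dall-e-3 request
// and how a b64_json response becomes an image_url message part.
type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";

interface DalleRequestPayload {
  model: string;
  prompt: string;
  response_format: "url" | "b64_json";
  n: number;
  size: DalleSize;
}

const payload: DalleRequestPayload = {
  model: "dall-e-3",
  prompt: "a watercolor fox in the snow", // text of the last user message
  response_format: "b64_json", // hosted URLs expire ~60 minutes after generation
  n: 1,
  size: "1024x1024",
};

// The images endpoint responds with { data: [{ url?, b64_json? }] }.
// extractMessage decodes and uploads the image (CacheStorage), then returns
// a message part like the one built here:
function toImageMessage(localUrl: string) {
  return [{ type: "image_url", image_url: { url: localUrl } }];
}
```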
35 changes: 35 additions & 0 deletions app/components/chat.tsx
@@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
import BottomIcon from "../icons/bottom.svg";
import StopIcon from "../icons/pause.svg";
import RobotIcon from "../icons/robot.svg";
import SizeIcon from "../icons/size.svg";
import PluginIcon from "../icons/plugin.svg";

import {
@@ -60,13 +61,15 @@
getMessageTextContent,
getMessageImages,
isVisionModel,
isDalle3,
} from "../utils";

import { uploadImage as uploadImageRemote } from "@/app/utils/chat";

import dynamic from "next/dynamic";

import { ChatControllerPool } from "../client/controller";
import { DalleSize } from "../typing";
import { Prompt, usePromptStore } from "../store/prompt";
import Locale from "../locales";

@@ -481,6 +484,11 @@ export function ChatActions(props: {
const [showPluginSelector, setShowPluginSelector] = useState(false);
const [showUploadImage, setShowUploadImage] = useState(false);

const [showSizeSelector, setShowSizeSelector] = useState(false);
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
const currentSize =
chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";

useEffect(() => {
const show = isVisionModel(currentModel);
setShowUploadImage(show);
@@ -624,6 +632,33 @@
/>
)}

{isDalle3(currentModel) && (
<ChatAction
onClick={() => setShowSizeSelector(true)}
text={currentSize}
icon={<SizeIcon />}
/>
)}

{showSizeSelector && (
<Selector
defaultSelectedValue={currentSize}
items={dalle3Sizes.map((m) => ({
title: m,
value: m,
}))}
onClose={() => setShowSizeSelector(false)}
onSelection={(s) => {
if (s.length === 0) return;
const size = s[0];
chatStore.updateCurrentSession((session) => {
session.mask.modelConfig.size = size;
});
showToast(size);
}}
/>
)}

<ChatAction
onClick={() => setShowPluginSelector(true)}
text={Locale.Plugin.Name}
7 changes: 6 additions & 1 deletion app/constant.ts
@@ -146,6 +146,7 @@ export const Anthropic = {

export const OpenaiPath = {
ChatPath: "v1/chat/completions",
ImagePath: "v1/images/generations",
UsagePath: "dashboard/billing/usage",
SubsPath: "dashboard/billing/subscription",
ListModelPath: "v1/models",
@@ -154,7 +155,10 @@ export const OpenaiPath = {
export const Azure = {
ChatPath: (deployName: string, apiVersion: string) =>
`deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
// https://<your_resource_name>.openai.azure.com/openai/deployments/<your_deployment_name>/images/generations?api-version=<api_version>
ImagePath: (deployName: string, apiVersion: string) =>
`deployments/${deployName}/images/generations?api-version=${apiVersion}`,
ExampleEndpoint: "https://{resource-url}/openai",
};

export const Google = {
@@ -256,6 +260,7 @@ const openaiModels = [
"gpt-4-vision-preview",
"gpt-4-turbo-2024-04-09",
"gpt-4-1106-preview",
"dall-e-3",
];

const googleModels = [
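
For reference, here is a sketch of the full URLs the new ImagePath entries resolve to once joined with a base URL; the resource name, deployment name, and API version below are placeholders, not values from the diff.

```ts
// Illustrative only: how the new image-generation paths compose into full URLs.
const openaiImagePath = "v1/images/generations";
const azureImagePath = (deployName: string, apiVersion: string) =>
  `deployments/${deployName}/images/generations?api-version=${apiVersion}`;

// OpenAI: https://api.openai.com/v1/images/generations
console.log(`https://api.openai.com/${openaiImagePath}`);

// Azure (base URL per ExampleEndpoint, i.e. https://{resource-url}/openai):
// https://my-resource.openai.azure.com/openai/deployments/my-dalle3/images/generations?api-version=2024-02-01
console.log(
  `https://my-resource.openai.azure.com/openai/${azureImagePath("my-dalle3", "2024-02-01")}`,
);
```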
1 change: 1 addition & 0 deletions app/icons/size.svg
(new SVG icon; not rendered in the diff view)
5 changes: 5 additions & 0 deletions app/store/chat.ts
@@ -26,6 +26,7 @@ import { nanoid } from "nanoid";
import { createPersistStore } from "../utils/store";
import { collectModelsWithDefaultModel } from "../utils/model";
import { useAccessStore } from "./access";
import { isDalle3 } from "../utils";

export type ChatMessage = RequestMessage & {
date: string;
@@ -541,6 +542,10 @@
const config = useAppConfig.getState();
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
// skip session summarization when using dall-e-3
if (isDalle3(modelConfig.model)) {
return;
}

const api: ClientApi = getClientApi(modelConfig.providerName);

2 changes: 2 additions & 0 deletions app/store/config.ts
@@ -1,4 +1,5 @@
import { LLMModel } from "../client/api";
import { DalleSize } from "../typing";
import { getClientConfig } from "../config/client";
import {
DEFAULT_INPUT_TEMPLATE,
@@ -60,6 +61,7 @@ export const DEFAULT_CONFIG = {
compressMessageLengthThreshold: 1000,
enableInjectSystemPrompts: true,
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
size: "1024x1024" as DalleSize,
},
};

2 changes: 2 additions & 0 deletions app/typing.ts
@@ -7,3 +7,5 @@ export interface RequestMessage {
role: MessageRole;
content: string;
}

export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
4 changes: 4 additions & 0 deletions app/utils.ts
@@ -265,3 +265,7 @@ export function isVisionModel(model: string) {
visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
);
}

export function isDalle3(model: string) {
return "dall-e-3" === model;
}
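
A brief usage sketch of the new helper follows; note that the strict equality means only the exact model name "dall-e-3" matches, so provider-prefixed or versioned names would not. The gating shown mirrors the call sites added in openai.ts and store/chat.ts, but the snippet itself is illustrative, not code from the PR.

```ts
// Illustrative usage of isDalle3.
function isDalle3(model: string) {
  return "dall-e-3" === model;
}

const model = "dall-e-3";
const shouldStream = !isDalle3(model); // dall-e-3 responses are never streamed
const skipSummarize = isDalle3(model); // store/chat.ts skips session summarization
console.log({ shouldStream, skipSummarize });
```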