feat: add openai service providers tts asr embedding rerank
chore: optimize intelligent model prompts, set command prompts
adolphnov committed Dec 1, 2024
1 parent 211d1b5 · commit 2769352
Showing 13 changed files with 383 additions and 169 deletions.
16 changes: 8 additions & 8 deletions package.json
@@ -47,8 +47,8 @@
     "@ai-sdk/google-vertex": "^1.0.4",
     "@ai-sdk/mistral": "^1.0.3",
     "@ai-sdk/openai": "^1.0.5",
-    "@ai-sdk/xai": "^1.0.3",
-    "ai": "^4.0.6",
+    "@ai-sdk/xai": "^1.0.4",
+    "ai": "^4.0.9",
     "base64-stream": "^1.0.0",
     "cloudflare-worker-adapter": "^1.3.4",
     "fluent-ffmpeg": "^2.1.3",
@@ -62,22 +62,22 @@
     "@ai-sdk/google-vertex": "^1.0.4",
     "@ai-sdk/mistral": "^1.0.3",
     "@ai-sdk/openai": "^1.0.5",
-    "@antfu/eslint-config": "^3.11.0",
-    "@cloudflare/workers-types": "^4.20241112.0",
+    "@antfu/eslint-config": "^3.11.2",
+    "@cloudflare/workers-types": "^4.20241127.0",
     "@google-cloud/vertexai": "^1.9.0",
     "@rollup/plugin-node-resolve": "^15.3.0",
     "@types/base64-stream": "^1.0.5",
     "@types/fluent-ffmpeg": "^2.1.27",
-    "@types/node": "^22.10.0",
+    "@types/node": "^22.10.1",
     "@types/node-cron": "^3.0.11",
     "@types/react": "^18.3.12",
     "@types/react-dom": "^18.3.1",
     "@types/ws": "^8.5.13",
     "@vercel/node": "^3.2.27",
-    "ai": "^4.0.6",
+    "ai": "^4.0.9",
     "base64-stream": "^1.0.0",
-    "eslint": "^9.15.0",
-    "eslint-plugin-format": "^0.1.2",
+    "eslint": "^9.16.0",
+    "eslint-plugin-format": "^0.1.3",
     "fluent-ffmpeg": "^2.1.3",
     "gts": "^6.0.2",
     "openai": "^4.73.1",
4 changes: 2 additions & 2 deletions scripts/plugins/docker/index.ts
@@ -32,8 +32,8 @@ const packageJson = `
     "@ai-sdk/google-vertex": "^1.0.4",
     "@ai-sdk/mistral": "^1.0.3",
     "@ai-sdk/openai": "^1.0.5",
-    "@ai-sdk/xai": "^1.0.3",
-    "ai": "^4.0.6",
+    "@ai-sdk/xai": "^1.0.4",
+    "ai": "^4.0.9",
     "base64-stream": "^1.0.0",
     "cloudflare-worker-adapter": "^1.3.4",
     "fluent-ffmpeg": "^2.1.3",
36 changes: 22 additions & 14 deletions src/agent/index.ts
@@ -1,6 +1,6 @@
 /* eslint-disable no-case-declarations */
 import type { CoreMessage, CoreUserMessage, LanguageModelV1 } from 'ai';
-import type { AudioAgent, ChatAgent, ImageAgent } from './types';
+import type { ASRAgent, ChatAgent, ImageAgent, TTSAgent } from './types';
 import { createAnthropic } from '@ai-sdk/anthropic';
 import { createCohere } from '@ai-sdk/cohere';
 import { createGoogleGenerativeAI } from '@ai-sdk/google';
@@ -15,8 +15,8 @@ import { AzureChatAI, AzureImageAI } from './azure';
 import { Cohere } from './cohere';
 import { Google } from './google';
 import { Mistral } from './mistralai';
-import { Dalle, OpenAI, Transcription } from './openai';
-import { OpenAILike, OpenAILikeImage } from './openailike';
+import { Dalle, OpenAI, OpenAIASR, OpenAITTS } from './openai';
+import { OpenAILike, OpenAILikeASR, OpenAILikeImage, OpenAILikeTTS } from './openailike';
 import { Vertex } from './vertex';
 import { WorkersChat, WorkersImage } from './workersai';
 import { XAI } from './xai';
@@ -73,20 +73,28 @@ export function loadImageGen(context: AgentUserConfig): ImageAgent | null {
     return null;
 }
 
-const AUDIO_AGENTS: AudioAgent[] = [
-    // Currently only OpenAI audio processing is implemented
-    new Transcription(),
+const ASR_AGENTS: ASRAgent[] = [
+    new OpenAIASR(),
+    new OpenAILikeASR(),
 ];
 
-export function loadAudioLLM(context: AgentUserConfig) {
-    for (const llm of AUDIO_AGENTS) {
-        if (llm.name === context.AI_PROVIDER) {
+export function loadASRLLM(context: AgentUserConfig) {
+    for (const llm of ASR_AGENTS) {
+        if (llm.name === context.AI_ASR_PROVIDER) {
             return llm;
         }
     }
-    // If the specified AI is not found, fall back to the first available one
-    for (const llm of AUDIO_AGENTS) {
-        if (llm.enable(context)) {
+    return null;
+}
+
+const TTS_AGENTS: TTSAgent[] = [
+    new OpenAITTS(),
+    new OpenAILikeTTS(),
+];
+
+export function loadTTSLLM(context: AgentUserConfig) {
+    for (const llm of TTS_AGENTS) {
+        if (llm.name === context.AI_TTS_PROVIDER) {
             return llm;
         }
     }
@@ -228,8 +236,8 @@ export async function createLlmModel(model: string, context: AgentUserConfig) {
         default:
             return createOpenAI({
                 name: 'olike',
-                baseURL: context.OPENAILIKE_API_BASE || undefined,
-                apiKey: context.OPENAILIKE_API_KEY || undefined,
+                baseURL: context.OAILIKE_API_BASE || undefined,
+                apiKey: context.OAILIKE_API_KEY || undefined,
             }).languageModel(model_id, undefined);
     }
     // if (model.includes(':')) {
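For context, a minimal sketch of how the two new loaders might be wired together. The `handleVoiceMessage` wrapper and its import paths are hypothetical; only `loadASRLLM`, `loadTTSLLM`, and the `AI_ASR_PROVIDER`/`AI_TTS_PROVIDER` config keys come from this commit:

```ts
import type { AgentUserConfig } from './config/env';
import { loadASRLLM, loadTTSLLM } from './agent'; // paths assume a caller in src/

// Hypothetical flow: transcribe an incoming voice note, then speak a reply.
// Each loader returns the agent whose name matches the configured provider;
// loadASRLLM returns null otherwise, and loadTTSLLM presumably mirrors that.
async function handleVoiceMessage(audio: Blob, context: AgentUserConfig): Promise<Blob | null> {
    const asr = loadASRLLM(context);
    const tts = loadTTSLLM(context);
    if (!asr || !tts) {
        return null;
    }
    const text = await asr.request(audio, context);
    return tts.request(`You said: ${text}`, context);
}
```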
31 changes: 14 additions & 17 deletions src/agent/openai.ts
@@ -1,5 +1,5 @@
 import type { CoreMessage, CoreUserMessage } from 'ai';
-import type { AudioAgent, ChatAgent, ChatStreamTextHandler, ImageAgent, ImageResult, LLMChatParams, LLMChatRequestParams, ResponseMessage } from './types';
+import type { ASRAgent, ChatAgent, ChatStreamTextHandler, ImageAgent, ImageResult, LLMChatParams, LLMChatRequestParams, ResponseMessage, TTSAgent } from './types';
 import { createOpenAI } from '@ai-sdk/openai';
 import { warpLLMParams } from '.';
 import { type AgentUserConfig, ENV } from '../config/env';
@@ -10,6 +10,10 @@ import { requestChatCompletionsV2 } from './request';
 
 export class OpenAIBase {
     readonly name = 'openai';
+    readonly enable = (context: AgentUserConfig): boolean => {
+        return context.OPENAI_API_KEY.length > 0;
+    };
+
     apikey = (context: AgentUserConfig): string => {
         const length = context.OPENAI_API_KEY.length;
         return context.OPENAI_API_KEY[Math.floor(Math.random() * length)];
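The hunk above hoists `enable` into the shared base and keeps the per-request key draw. A standalone sketch of that rotation behavior, assuming `OPENAI_API_KEY` is parsed by the env layer into a `string[]`:

```ts
// Assumed shape: the env parser splits comma-separated keys into a string[].
const OPENAI_API_KEY: string[] = ['sk-key-a', 'sk-key-b', 'sk-key-c'];

// Mirrors OpenAIBase.apikey: a uniform random draw per request spreads
// traffic (and rate limits) across all configured keys.
function apikey(): string {
    return OPENAI_API_KEY[Math.floor(Math.random() * OPENAI_API_KEY.length)];
}
```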
@@ -20,10 +24,6 @@ export class OpenAI extends OpenAIBase implements ChatAgent {
     readonly modelKey = 'OPENAI_CHAT_MODEL';
     static readonly transformModelPerfix = 'TRANSFROM-';
 
-    readonly enable = (context: AgentUserConfig): boolean => {
-        return context.OPENAI_API_KEY.length > 0;
-    };
-
     readonly model = (ctx: AgentUserConfig, params?: LLMChatRequestParams): string => {
         const msgType = Array.isArray(params?.content) ? params.content.at(-1)?.type : 'text';
         switch (msgType) {
@@ -99,10 +99,6 @@ export class OpenAI extends OpenAIBase implements ChatAgent {
 export class Dalle extends OpenAIBase implements ImageAgent {
     readonly modelKey = 'DALL_E_MODEL';
 
-    enable = (context: AgentUserConfig): boolean => {
-        return context.OPENAI_API_KEY.length > 0;
-    };
-
     model = (ctx: AgentUserConfig): string => {
         return ctx.DALL_E_MODEL;
     };
@@ -143,13 +139,9 @@ export class Dalle extends OpenAIBase implements ImageAgent {
     };
 }
 
-export class Transcription extends OpenAIBase implements AudioAgent {
+export class OpenAIASR extends OpenAIBase implements ASRAgent {
     readonly modelKey = 'OPENAI_STT_MODEL';
 
-    enable = (context: AgentUserConfig): boolean => {
-        return context.OPENAI_API_KEY.length > 0;
-    };
-
     model = (ctx: AgentUserConfig): string => {
         return ctx.OPENAI_STT_MODEL;
     };
@@ -163,7 +155,7 @@ export class Transcription extends OpenAIBase implements AudioAgent {
         };
         const formData = new FormData();
         formData.append('file', audio, 'audio.ogg');
-        formData.append('model', this.model(context));
+        formData.append('model', context.OPENAI_STT_MODEL);
         if (context.OPENAI_STT_EXTRA_PARAMS) {
             Object.entries(context.OPENAI_STT_EXTRA_PARAMS as string).forEach(([k, v]) => {
                 formData.append(k, v);
};
}

export class ASR extends OpenAIBase {
export class OpenAITTS extends OpenAIBase implements TTSAgent {
readonly modelKey = 'OPENAI_TTS_MODEL';
hander = (text: string, context: AgentUserConfig): Promise<Blob> => {

model = (ctx: AgentUserConfig): string => {
return ctx.OPENAI_TTS_MODEL;
};

request = (text: string, context: AgentUserConfig): Promise<Blob> => {
const url = `${context.OPENAI_API_BASE}/audio/speech`;
const headers = {
'Authorization': `Bearer ${this.apikey(context)}`,
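A rough round-trip sketch using the renamed classes; `echoAsSpeech` is invented for illustration and assumes a context with `OPENAI_API_KEY`, `OPENAI_API_BASE`, `OPENAI_STT_MODEL`, and `OPENAI_TTS_MODEL` set:

```ts
import type { AgentUserConfig } from '../config/env';
import { OpenAIASR, OpenAITTS } from './openai';

// Hypothetical round trip: speech -> text -> speech.
async function echoAsSpeech(voiceNote: Blob, context: AgentUserConfig): Promise<Blob> {
    const transcript = await new OpenAIASR().request(voiceNote, context); // POST {OPENAI_API_BASE}/audio/transcriptions
    return new OpenAITTS().request(transcript, context);                  // POST {OPENAI_API_BASE}/audio/speech
}
```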
106 changes: 91 additions & 15 deletions src/agent/openailike.ts
@@ -1,33 +1,35 @@
 import type { AgentUserConfig } from '../config/env';
-import type { ChatAgent, ChatStreamTextHandler, ImageAgent, ImageResult, LLMChatParams, ResponseMessage } from './types';
+import type { ASRAgent, ChatAgent, ChatStreamTextHandler, ImageAgent, ImageResult, LLMChatParams, ResponseMessage } from './types';
 import { createOpenAI } from '@ai-sdk/openai';
 import { Log } from '../log/logDecortor';
 import { log } from '../log/logger';
 import { requestText2Image } from './chat';
 import { requestChatCompletionsV2 } from './request';
 
-class OpenAILikeBase {
+export class OpenAILikeBase {
     readonly name = 'olike';
 
     readonly enable = (context: AgentUserConfig): boolean => {
-        return !!context.OPENAILIKE_API_KEY;
+        return !!context.OAILIKE_API_KEY;
     };
 }
 
 export class OpenAILike extends OpenAILikeBase implements ChatAgent {
-    readonly modelKey = 'OPENAILIKE_CHAT_MODEL';
+    readonly modelKey = 'OAILIKE_CHAT_MODEL';
 
     readonly enable = (context: AgentUserConfig): boolean => {
-        return !!context.OPENAILIKE_API_KEY;
+        return !!context.OAILIKE_API_KEY;
     };
 
     readonly model = (ctx: AgentUserConfig): string => {
-        return ctx.OPENAILIKE_CHAT_MODEL;
+        return ctx.OAILIKE_CHAT_MODEL;
     };
 
     readonly request = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<{ messages: ResponseMessage[]; content: string }> => {
         const provider = createOpenAI({
             name: 'openaiLike',
-            baseURL: context.OPENAILIKE_API_BASE || undefined,
-            apiKey: context.OPENAILIKE_API_KEY || undefined,
+            baseURL: context.OAILIKE_API_BASE || undefined,
+            apiKey: context.OAILIKE_API_KEY || undefined,
         });
         const languageModelV1 = provider.languageModel(this.model(context), undefined);
         return requestChatCompletionsV2({
@@ -39,25 +41,25 @@ export class OpenAILike extends OpenAILikeBase implements ChatAgent {
 }
 
 export class OpenAILikeImage extends OpenAILikeBase implements ImageAgent {
-    readonly modelKey = 'OPENAILIKE_IMAGE_MODEL';
+    readonly modelKey = 'OAILIKE_IMAGE_MODEL';
 
     model = (ctx: AgentUserConfig): string => {
-        return ctx.OPENAILIKE_IMAGE_MODEL;
+        return ctx.OAILIKE_IMAGE_MODEL;
     };
 
     request = async (prompt: string, context: AgentUserConfig): Promise<ImageResult> => {
-        const url = `${context.OPENAILIKE_API_BASE}/image/generations`;
+        const url = `${context.OAILIKE_API_BASE}/image/generations`;
         const header = {
             'Content-Type': 'application/json',
-            'Authorization': `Bearer ${context.OPENAILIKE_API_KEY}`,
+            'Authorization': `Bearer ${context.OAILIKE_API_KEY}`,
         };
         const body: any = {
             prompt,
-            image_size: context.OPENAILIKE_IMAGE_SIZE,
-            model: context.OPENAILIKE_IMAGE_MODEL,
+            image_size: context.OAILIKE_IMAGE_SIZE,
+            model: context.OAILIKE_IMAGE_MODEL,
             // num_inference_steps: 10,
             batch_size: 4,
-            ...context.OPENAILIKE_EXTRA_PARAMS,
+            ...context.OAILIKE_EXTRA_PARAMS,
         };
         return requestText2Image(url, header, body, this.render);
     };
@@ -72,3 +74,77 @@ export class OpenAILikeImage extends OpenAILikeBase implements ImageAgent {
         return { type: 'image', url: (await resp?.images)?.map((i: { url: any }) => i?.url) };
     };
 }
+
+export class OpenAILikeASR extends OpenAILikeBase implements ASRAgent {
+    readonly modelKey = 'OLIKE_STT_MODEL';
+
+    model = (ctx: AgentUserConfig): string => {
+        return ctx.OLIKE_STT_MODEL;
+    };
+
+    @Log
+    request = async (audio: Blob, context: AgentUserConfig): Promise<string> => {
+        const url = `${context.OAILIKE_API_BASE}/audio/transcriptions`;
+        const header = {
+            Authorization: `Bearer ${context.OAILIKE_API_KEY}`,
+            Accept: 'application/json',
+        };
+        const formData = new FormData();
+        formData.append('file', audio, 'audio.mp3');
+        formData.append('model', context.OLIKE_STT_MODEL);
+        if (context.OAILIKE_STT_EXTRA_PARAMS) {
+            Object.entries(context.OAILIKE_STT_EXTRA_PARAMS as string).forEach(([k, v]) => {
+                formData.append(k, v);
+            });
+        }
+        formData.append('response_format', 'json');
+        const resp = await fetch(url, {
+            method: 'POST',
+            headers: header,
+            body: formData,
+            redirect: 'follow',
+        }).then(r => r.json());
+
+        if (resp.error?.message) {
+            throw new Error(resp.error.message);
+        }
+        if (resp.text === undefined) {
+            console.error(JSON.stringify(resp));
+            throw new Error(JSON.stringify(resp));
+        }
+        log.info(`Transcription: ${resp.text}`);
+        return resp.text;
+    };
+}
+
+export class OpenAILikeTTS extends OpenAILikeBase {
+    readonly modelKey = 'OAILIKE_TTS_MODEL';
+
+    model = (ctx: AgentUserConfig): string => {
+        return ctx.OAILIKE_TTS_MODEL;
+    };
+
+    readonly request = async (text: string, context: AgentUserConfig): Promise<Blob> => {
+        const url = `${context.OAILIKE_API_BASE}/audio/speech`;
+        const headers = {
+            'Authorization': `Bearer ${context.OAILIKE_API_KEY}`,
+            'Content-Type': 'application/json',
+        };
+        const resp = await fetch(url, {
+            method: 'POST',
+            headers,
+            body: JSON.stringify({
+                model: context.OAILIKE_TTS_MODEL,
+                input: text,
+                voice: context.OAILIKE_TTS_VOICE,
+                response_format: 'opus',
+                speed: 1,
+                stream: false,
+            }),
+        });
+        if (!resp.ok) {
+            throw new Error(await resp.text());
+        }
+        return resp.blob();
+    };
+}
15 changes: 12 additions & 3 deletions src/agent/types.ts
@@ -78,13 +78,22 @@ export interface ImageResult extends UnionData {
     caption?: string;
 }
 
-export type AudioAgentRequest = (audio: Blob, context: AgentUserConfig) => Promise<string>;
+export type ASRAgentRequest = (audio: Blob, context: AgentUserConfig) => Promise<string>;
 
-export interface AudioAgent {
+export interface ASRAgent {
     name: string | string[];
     modelKey: string;
     enable: (context: AgentUserConfig) => boolean;
-    request: AudioAgentRequest;
+    request: ASRAgentRequest;
     model: (ctx: AgentUserConfig) => string;
 }
+export type TTSAgentRequest = (text: string, context: AgentUserConfig) => Promise<Blob>;
+
+export interface TTSAgent {
+    name: string;
+    modelKey: string;
+    enable: (context: AgentUserConfig) => boolean;
+    request: TTSAgentRequest;
+    model: (ctx: AgentUserConfig) => string;
+}
 
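To make the new contracts concrete, a minimal hypothetical `TTSAgent` implementation; the `EchoTTS` provider and its `ECHO_TTS_MODEL` key are invented for this sketch:

```ts
import type { AgentUserConfig } from '../config/env';
import type { TTSAgent } from './types';

// Invented provider showing the minimum surface a TTSAgent must expose.
class EchoTTS implements TTSAgent {
    readonly name = 'echo';
    readonly modelKey = 'ECHO_TTS_MODEL'; // hypothetical config key
    readonly enable = (ctx: AgentUserConfig): boolean => Boolean((ctx as any).ECHO_TTS_MODEL);
    readonly model = (ctx: AgentUserConfig): string => (ctx as any).ECHO_TTS_MODEL ?? 'echo-1';
    // Stand-in "synthesis": wraps the text in a Blob instead of calling a real API.
    readonly request = async (text: string, _context: AgentUserConfig): Promise<Blob> =>
        new Blob([text], { type: 'text/plain' });
}
```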