From 8896fa82421bb775dc9808a85fac9624b2e77a33 Mon Sep 17 00:00:00 2001 From: adolphzhang Date: Fri, 22 Nov 2024 01:15:47 +0800 Subject: [PATCH] feat: allow quick switching of models even if you don't know the detailed name. feat: bring back the audio dialog function; select either transcription or dialog via AUDIO_HANDLE_TYPE. perf: optimize the /set command so that even if a shortcut key is not set, you can still modify the KEY environment variable using -KEY. Switch models by querying similarity through //c or //v. You can modify the RERANK_MODELS variable to customize the base data. The base data is processed through the default Jina rerank model; the default rerank service provider is Jina, with the model jina-colbert-v2. Alternatively, embeddings can be obtained through the OpenAI embedding model, and the best match is then derived indirectly by finding the maximum cosine-similarity value. You need to enable ENABLE_INTELLIGENT_MODEL to use this feature. 
--- src/agent/xai.ts | 2 +- src/config/config.ts | 11 +++++ src/telegram/command/index.ts | 2 +- src/telegram/command/system.ts | 14 +++--- src/telegram/handler/chat.ts | 18 +++++--- src/telegram/handler/handlers.ts | 28 +++++++++++ src/telegram/handler/index.ts | 3 ++ src/utils/data_calculation/classifier.ts | 22 +++++++++ src/utils/data_calculation/embedding.ts | 47 +++++++++++++++++++ src/utils/data_calculation/rerank.ts | 47 +++++++++++++++++++ src/utils/embedding/index.ts | 59 ------------------------ 11 files changed, 178 insertions(+), 75 deletions(-) create mode 100644 src/utils/data_calculation/classifier.ts create mode 100644 src/utils/data_calculation/embedding.ts create mode 100644 src/utils/data_calculation/rerank.ts delete mode 100644 src/utils/embedding/index.ts diff --git a/src/agent/xai.ts b/src/agent/xai.ts index 801df7cb..156375bb 100644 --- a/src/agent/xai.ts +++ b/src/agent/xai.ts @@ -1,9 +1,9 @@ +import type { CoreUserMessage } from 'ai'; import type { AgentUserConfig } from '../config/env'; import type { ChatAgent, ChatStreamTextHandler, LLMChatParams, LLMChatRequestParams, ResponseMessage } from './types'; import { createXai } from '@ai-sdk/xai'; import { warpLLMParams } from '.'; import { requestChatCompletionsV2 } from './request'; -import { CoreUserMessage } from 'ai'; export class XAI implements ChatAgent { readonly name = 'xai'; diff --git a/src/config/config.ts b/src/config/config.ts index 14e85a33..ea1c4f70 100644 --- a/src/config/config.ts +++ b/src/config/config.ts @@ -167,6 +167,8 @@ export class EnvironmentConfig { STORE_MEDIA_MESSAGE: boolean = false; // If true, will store text chunk when message separated to multiple chunks STORE_TEXT_CHUNK_MESSAGE: boolean = false; + // Audio handle type, 'transcribe' or 'chat' + AUDIO_HANDLE_TYPE = 'transcribe'; } // -- 通用配置 -- @@ -200,6 +202,7 @@ export class OpenAIConfig { * @deprecated */ OPENAI_NEED_TRANSFORM_MODEL: string[] = ['o1-mini-all', 'o1-mini-preview-all']; + 
OPENAI_EMBEDDING_MODEL = 'text-embedding-3-small'; } // -- DALLE 配置 -- @@ -370,4 +373,12 @@ export class ExtraUserConfig { MAX_STEPS = 3; // chat agent max retries MAX_RETRIES = 0; + // Rerank Agent, jina or openai(calculate the cosine similarity using embedding models to get the result) + RERANK_AGENT = 'jina'; + // Jina Rerank Model + JINA_RERANK_MODEL = 'jina-colbert-v2'; + // Rerank Models + RERANK_MODELS: string[] = ['gpt-4o-mini', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'chatgpt-4o-latest', 'o1-mini', 'o1-preview', 'claude-3-5-sonnet-20240620', 'claude-3-5-sonnet-20241012', 'gemini-1.5-flash-002', 'gemini-1.5-pro-002', 'gemini-1.5-flash-latest', 'gemini-1.5-pro-latest', 'gemini-exp-1114', 'grok-beta', 'grok-vision-beta', 'claude-3-5-haiku-20241012']; + // Whether to enable intelligent model processing + ENABLE_INTELLIGENT_MODEL = false; } diff --git a/src/telegram/command/index.ts b/src/telegram/command/index.ts index f994fb4f..f0f71b17 100644 --- a/src/telegram/command/index.ts +++ b/src/telegram/command/index.ts @@ -153,7 +153,7 @@ export async function handleCommandMessage(message: Telegram.Message, context: W for (const cmd of SYSTEM_COMMANDS) { if (text === cmd.command || text.startsWith(`${cmd.command} `)) { log.info(`[SYSTEM COMMAND] handle system command: ${cmd.command}`); - return await handleSystemCommand(message, text, cmd, context); + return handleSystemCommand(message, text, cmd, context); } } return null; diff --git a/src/telegram/command/system.ts b/src/telegram/command/system.ts index a0a8a4c1..9e3e5755 100644 --- a/src/telegram/command/system.ts +++ b/src/telegram/command/system.ts @@ -377,7 +377,7 @@ export class SetCommandHandler implements CommandHandler { try { if (!subcommand) { const detailSet = ENV.I18N.command?.detail?.set || 'Have no detailed information in the language'; - return sender.sendRichText(`\`\`\`plaintext\n${detailSet}\n\`\`\``, 'MarkdownV2'); + return sender.sendRichText(`
${detailSet}
`, 'HTML'); } const { keys, values } = this.parseMappings(context); @@ -477,7 +477,7 @@ export class SetCommandHandler implements CommandHandler { context: WorkerContext, sender: MessageSender, ): Promise { - let key = keys[flag]; + let key = keys[flag] || (Object.values(keys).includes(flag.slice(1)) ? flag.slice(1) : null); let mappedValue = values[value] ?? value; if (!key) { @@ -499,11 +499,11 @@ export class SetCommandHandler implements CommandHandler { ? `${context.USER_CONFIG.AI_PROVIDER.toUpperCase()}_${key}` : key; break; - case 'CURRENT_MODE': - if (!Object.keys(context.USER_CONFIG.MODES).includes(value)) { - return sender.sendPlainText(`mode ${value} not found. Support modes: ${Object.keys(context.USER_CONFIG.MODES).join(', ')}`); - } - break; + // case 'CURRENT_MODE': + // if (!Object.keys(context.USER_CONFIG.MODES).includes(value)) { + // return sender.sendPlainText(`mode ${value} not found. Support modes: ${Object.keys(context.USER_CONFIG.MODES).join(', ')}`); + // } + // break; case 'USE_TOOLS': if (value === 'on') { mappedValue = Object.keys(tools); diff --git a/src/telegram/handler/chat.ts b/src/telegram/handler/chat.ts index 4cb03586..97f918ca 100644 --- a/src/telegram/handler/chat.ts +++ b/src/telegram/handler/chat.ts @@ -1,5 +1,5 @@ /* eslint-disable unused-imports/no-unused-vars */ -import type { FilePart, ToolResultPart } from 'ai'; +import type { FilePart, TextPart, ToolResultPart } from 'ai'; import type * as Telegram from 'telegram-bot-api-types'; import type { ChatStreamTextHandler, HistoryModifier, ImageResult, LLMChatRequestParams } from '../../agent/types'; import type { WorkerContext } from '../../config/context'; @@ -10,7 +10,7 @@ import type { MessageHandler } from './types'; import { loadAudioLLM, loadChatLLM, loadImageGen } from '../../agent'; import { loadHistory, requestCompletionsFromLLM } from '../../agent/chat'; import { ENV } from '../../config/env'; -import { getLog, logSingleton } from '../../log/logDecortor'; +import { 
clearLog, getLog, logSingleton } from '../../log/logDecortor'; import { log } from '../../log/logger'; import { sendToolResult } from '../../tools'; import { imageToBase64String, renderBase64DataURI } from '../../utils/image'; @@ -92,7 +92,7 @@ export class ChatHandler implements MessageHandler { console.error('Error:', e); const sender = context.MIDDEL_CONTEXT.sender ?? MessageSender.from(context.SHARE_CONTEXT.botToken, message); const filtered = (e as Error).message.replace(context.SHARE_CONTEXT.botToken, '[REDACTED]'); - return sender.sendPlainText(`Error: ${filtered.substring(0, 4000)}`); + return sender.sendRichText(`
Error:${filtered.substring(0, 4000)}
`, 'HTML'); } }; @@ -116,8 +116,6 @@ export class ChatHandler implements MessageHandler { }; if (type !== 'text' && id) { - // const fileIds = await getStoreMediaIds(context.SHARE_CONTEXT, context.MIDDEL_CONTEXT.originalMessageInfo.media_group_id); - // id.push(...fileIds.filter(i => !id.includes(i))); const api = createTelegramBotAPI(context.SHARE_CONTEXT.botToken); const files = await Promise.all(id.map(i => api.getFileWithReturns({ file_id: i }))); const paths = files.map(f => f.result.file_path).filter(Boolean) as string[]; @@ -340,11 +338,17 @@ async function handleAudioToText( if (!agent) { return sender.sendPlainText('ERROR: Audio agent not found'); } - const url = (params.content as FilePart[])[0].data as string; + const url = (params.content as FilePart[]).at(-1)?.data as string; const audio = await fetch(url).then(b => b.blob()); const result = await agent.request(audio, context.USER_CONFIG); context.MIDDEL_CONTEXT.history.push({ role: 'user', content: result.text || '' }); - return sender.sendRichText(`${getLog(context.USER_CONFIG)}\n> \`${result.text}\``, 'MarkdownV2', 'chat'); + await sender.sendRichText(`${getLog(context.USER_CONFIG, false, false)}\n> \n${result.text}`, 'MarkdownV2', 'chat'); + if (ENV.AUDIO_HANDLE_TYPE === 'chat' && result.text) { + clearLog(context.USER_CONFIG); + const otherText = (params.content as TextPart[]).filter(c => c.type === 'text').map(c => c.text).join('\n').trim(); + return chatWithLLM(message, { role: 'user', content: `[AUDIO]: ${result.text}\n${otherText}` }, context, null); + } + return new Response('audio handle done'); } export async function sendImages(img: ImageResult, SEND_IMAGE_AS_FILE: boolean, sender: MessageSender, config: AgentUserConfig) { diff --git a/src/telegram/handler/handlers.ts b/src/telegram/handler/handlers.ts index 34d33e57..bd45cec7 100644 --- a/src/telegram/handler/handlers.ts +++ b/src/telegram/handler/handlers.ts @@ -8,6 +8,7 @@ import { WorkerContext } from '../../config/context'; import { 
ENV } from '../../config/env'; import { tagMessageIds } from '../../log/logDecortor'; import { log } from '../../log/logger'; +import { Rerank } from '../../utils/data_calculation/rerank'; import { createTelegramBotAPI } from '../api'; import { handleCommandMessage } from '../command'; import { loadChatRoleWithContext } from '../command/auth'; @@ -429,3 +430,30 @@ export class CheckForwarding implements MessageHandler { return null; }; } + +export class IntelligentModelProcess implements MessageHandler { + handle = async (message: Telegram.Message, context: WorkerContext): Promise => { + if (!context.USER_CONFIG.ENABLE_INTELLIGENT_MODEL) { + return null; + } + const regex = /^\s*\/\/(c|v)\s*(\S+)/; + const text = (message.text || message.caption || '').trim().match(regex); + if (text && text[1] && text[2]) { + const rerank = new Rerank(); + try { + const similarityModel = (await rerank.rank(context.USER_CONFIG, [text[2], ...context.USER_CONFIG.RERANK_MODELS], 1))[0].name; + const mode = text[1]; + const textReplace = `/set ${mode === 'c' ? 
'-CHAT_MODEL' : `-VISION_MODEL`} ${similarityModel} `; + if (message.text) { + message.text = textReplace + message.text.slice(text[0].length).trim(); + } else if (message.caption) { + message.caption = textReplace + message.caption.slice(text[0].length).trim(); + } + } catch (error) { + log.error(`[INTELLIGENT MODEL PROCESS] Rerank error: ${error}`); + return null; + } + } + return null; + }; +} diff --git a/src/telegram/handler/index.ts b/src/telegram/handler/index.ts index 74001540..5d3701ce 100644 --- a/src/telegram/handler/index.ts +++ b/src/telegram/handler/index.ts @@ -14,6 +14,7 @@ import { HandlerCallbackQuery, HandlerInlineQuery, InitUserConfig, + IntelligentModelProcess, MessageFilter, OldMessageFilter, SaveLastMessage, @@ -66,6 +67,8 @@ async function handleMessage(token: string, message: Telegram.Message, isForward new SaveLastMessage(), // 初始化用户配置 new InitUserConfig(), + // 动态模型处理 + new IntelligentModelProcess(), // 处理命令消息 new CommandHandler(), // 检查是否是转发消息 diff --git a/src/utils/data_calculation/classifier.ts b/src/utils/data_calculation/classifier.ts new file mode 100644 index 00000000..b428ab18 --- /dev/null +++ b/src/utils/data_calculation/classifier.ts @@ -0,0 +1,22 @@ +import type { AgentUserConfig } from '../../config/env'; + +export class JinaClassifier { + readonly model = 'jina-embeddings-v3'; + readonly api = 'https://api.jina.ai/v1/classify'; + + readonly request = async (context: AgentUserConfig, data: string[], labels: string[]) => { + const body = { + model: this.model, + input: data, + labels, + }; + return fetch(this.api, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${context.JINA_API_KEY}`, + }, + body: JSON.stringify(body), + }).then(res => res.json()).then(res => res.data.map((item: any) => ({ label: item.prediction, value: data[item.index] }))); + }; +} diff --git a/src/utils/data_calculation/embedding.ts b/src/utils/data_calculation/embedding.ts new file mode 100644 index 
00000000..cdb7c92c --- /dev/null +++ b/src/utils/data_calculation/embedding.ts @@ -0,0 +1,47 @@ +import type { AgentUserConfig } from '../../config/env'; +import { createOpenAI } from '@ai-sdk/openai'; +import { embedMany } from 'ai'; +import { OpenAIBase } from '../../agent/openai'; + +export class OpenaiEmbedding extends OpenAIBase { + readonly request = async (context: AgentUserConfig, data: string[]) => { + const { embeddings, values } = await embedMany({ + model: createOpenAI({ + baseURL: context.OPENAI_API_BASE, + apiKey: this.apikey(context), + }).embedding(context.OPENAI_EMBEDDING_MODEL), + values: data, + maxRetries: 0, + }); + return values.map((value, i) => ({ embed: embeddings[i], value })); + }; +} + +export class JinaEmbedding { + readonly task: 'retrieval.query' | 'retrieval.passage' | 'separation' | 'classification' | 'text-matching' | undefined; + readonly model = 'jina-embeddings-v3'; + + constructor(task?: 'retrieval.query' | 'retrieval.passage' | 'separation' | 'classification' | 'text-matching') { + this.task = task; + } + + readonly request = async (context: AgentUserConfig, data: string[]) => { + const url = 'https://api.jina.ai/v1/embeddings'; + const body = { + model: this.model, + task: this.task, + dimensions: 1024, + late_chunking: false, + embedding_type: 'float', + input: data, + }; + return fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${context.JINA_API_KEY}`, + }, + body: JSON.stringify(body), + }).then(res => res.json()).then(res => res.data.map((item: any) => ({ embed: item.embedding, value: data[item.index] }))); + }; +} diff --git a/src/utils/data_calculation/rerank.ts b/src/utils/data_calculation/rerank.ts new file mode 100644 index 00000000..bcb2037d --- /dev/null +++ b/src/utils/data_calculation/rerank.ts @@ -0,0 +1,47 @@ +import type { AgentUserConfig } from '../../config/env'; +import { cosineSimilarity } from 'ai'; +import { OpenaiEmbedding } from 
'./embedding'; + +export class Rerank { + readonly rank = (context: AgentUserConfig, data: string[], topN: number = 1) => { + switch (context.RERANK_AGENT) { + case 'jina': + return this.jina(context, data, topN); + case 'openai': + return this.openai(context, data, topN); + default: + throw new Error('Invalid RERANK_AGENT'); + } + }; + + readonly jina = async (context: AgentUserConfig, data: string[], topN: number) => { + const url = 'https://api.jina.ai/v1/rerank'; + const result = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${context.JINA_API_KEY}`, + }, + body: JSON.stringify({ + model: context.JINA_RERANK_MODEL, + query: data[0], + documents: data.slice(1), + top_n: topN, + }), + }).then(res => res.json()); + if (!result.results) { + throw new Error(`No results found. details: ${JSON.stringify(result)}`); + } + return result.results.map((item: any) => ({ similar: item.relevance_score, name: item.document.text })); + }; + + readonly openai = async (context: any, data: string[], topN: number) => { + const embeddings = await new OpenaiEmbedding().request(context, data); + const inputEmbeddings = embeddings[0].embed; + return embeddings.slice(1) + .map(({ embed, value }) => ({ similar: cosineSimilarity(inputEmbeddings, embed), value })) + .sort((a, b) => b.similar - a.similar) + .slice(0, topN) + .map(i => i.value); + }; +} diff --git a/src/utils/embedding/index.ts b/src/utils/embedding/index.ts deleted file mode 100644 index d1dd06f6..00000000 --- a/src/utils/embedding/index.ts +++ /dev/null @@ -1,59 +0,0 @@ -import type { AgentUserConfig } from '../../config/env'; -import { createOpenAI } from '@ai-sdk/openai'; -import { cosineSimilarity, embedMany } from 'ai'; -import { OpenAIBase } from '../../agent/openai'; - -export class OpenaiEmbedding extends OpenAIBase { - readonly request = async (context: AgentUserConfig, data: string[]) => { - const { embeddings, values } = await embedMany({ - 
model: createOpenAI({ - baseURL: context.OPENAI_API_BASE, - apiKey: this.apikey(context), - }).embedding('text-embedding-3-small'), - values: data, - maxRetries: 0, - }); - return values.map((value, i) => ({ embed: embeddings[i], value })); - }; -} - -export class Rerank { - jina_base_url = 'https://api.jina.ai/v1/rerank'; - rank = (context: any, data: string[], topN: number = 1) => { - switch (context.RERANK_AGENT) { - case 'jina': - return this.jina(context, data, topN); - case 'openai': - return this.openai(context, data, topN); - default: - throw new Error('Invalid RERANK_AGENT'); - } - }; - - private jina = (context: any, data: string[], topN: number) => { - return fetch(this.jina_base_url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${context.JINA_API_KEY}`, - }, - body: JSON.stringify({ - model: 'jina-reranker-v1-tiny-en', - query: data[0], - documents: data.slice(1), - top_n: topN, - }), - }).then(res => res.json()).then(res => res.results - .map((item: any) => ({ similar: item.relevance_score, name: item.document.text }))); - }; - - private openai = async (context: any, data: string[], topN: number) => { - const embeddings = await new OpenaiEmbedding().request(context, data); - const inputEmbeddings = embeddings[0].embed; - return embeddings.slice(1) - .map(({ embed, value }) => ({ similar: cosineSimilarity(inputEmbeddings, embed), value })) - .sort((a, b) => b.similar - a.similar) - .slice(0, topN) - .map(i => i.value); - }; -}