Skip to content

Commit

Permalink
feat: allow quick switching of models even if you don't know the deta…
Browse files Browse the repository at this point in the history
…iled name.

feat: now bring back the audio dialog function and select either transcription or dialog via AUDIO_HANDLE_TYPE

perf: optimize the /set command so that even if a shortcut key is not set, you can still modify the KEY environment variable using -KEY.

Switch models via similarity search using the //c or //v prefix. The candidate model list can be customized through the RERANK_MODELS variable. Candidates are ranked by the configured rerank service; the default provider is Jina with the jina-colbert-v2 model. Alternatively, embeddings can be obtained through the OpenAI embedding model, with the best match selected by maximum cosine similarity. This feature requires enabling ENABLE_INTELLIGENT_MODEL.
  • Loading branch information
adolphnov committed Nov 21, 2024
1 parent a806205 commit 8896fa8
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 75 deletions.
2 changes: 1 addition & 1 deletion src/agent/xai.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import type { CoreUserMessage } from 'ai';
import type { AgentUserConfig } from '../config/env';
import type { ChatAgent, ChatStreamTextHandler, LLMChatParams, LLMChatRequestParams, ResponseMessage } from './types';
import { createXai } from '@ai-sdk/xai';
import { warpLLMParams } from '.';
import { requestChatCompletionsV2 } from './request';
import { CoreUserMessage } from 'ai';

export class XAI implements ChatAgent {
readonly name = 'xai';
Expand Down
11 changes: 11 additions & 0 deletions src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ export class EnvironmentConfig {
STORE_MEDIA_MESSAGE: boolean = false;
// If true, will store text chunk when message separated to multiple chunks
STORE_TEXT_CHUNK_MESSAGE: boolean = false;
// Audio handle type, 'transcribe' or 'chat'
AUDIO_HANDLE_TYPE = 'transcribe';
}

// -- 通用配置 --
Expand Down Expand Up @@ -200,6 +202,7 @@ export class OpenAIConfig {
* @deprecated
*/
OPENAI_NEED_TRANSFORM_MODEL: string[] = ['o1-mini-all', 'o1-mini-preview-all'];
OPENAI_EMBEDDING_MODEL = 'text-embedding-3-small';
}

// -- DALLE 配置 --
Expand Down Expand Up @@ -370,4 +373,12 @@ export class ExtraUserConfig {
MAX_STEPS = 3;
// chat agent max retries
MAX_RETRIES = 0;
// Rerank Agent, jina or openai(calculate the cosine similarity using embedding models to get the result)
RERANK_AGENT = 'jina';
// Jina Rerank Model
JINA_RERANK_MODEL = 'jina-colbert-v2';
// Rerank Models
RERANK_MODELS: string[] = ['gpt-4o-mini', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'chatgpt-4o-latest', 'o1-mini', 'o1-preview', 'claude-3-5-sonnet-20240620', 'claude-3-5-sonnet-20241012', 'gemini-1.5-flash-002', 'gemini-1.5-pro-002', 'gemini-1.5-flash-latest', 'gemini-1.5-pro-latest', 'gemini-exp-1114', 'grok-beta', 'grok-vision-beta', 'claude-3-5-haiku-20241012'];
// Whether to enable intelligent model processing
ENABLE_INTELLIGENT_MODEL = false;
}
2 changes: 1 addition & 1 deletion src/telegram/command/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ export async function handleCommandMessage(message: Telegram.Message, context: W
for (const cmd of SYSTEM_COMMANDS) {
if (text === cmd.command || text.startsWith(`${cmd.command} `)) {
log.info(`[SYSTEM COMMAND] handle system command: ${cmd.command}`);
return await handleSystemCommand(message, text, cmd, context);
return handleSystemCommand(message, text, cmd, context);
}
}
return null;
Expand Down
14 changes: 7 additions & 7 deletions src/telegram/command/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ export class SetCommandHandler implements CommandHandler {
try {
if (!subcommand) {
const detailSet = ENV.I18N.command?.detail?.set || 'Have no detailed information in the language';
return sender.sendRichText(`\`\`\`plaintext\n${detailSet}\n\`\`\``, 'MarkdownV2');
return sender.sendRichText(`<pre>${detailSet}</pre>`, 'HTML');
}

const { keys, values } = this.parseMappings(context);
Expand Down Expand Up @@ -477,7 +477,7 @@ export class SetCommandHandler implements CommandHandler {
context: WorkerContext,
sender: MessageSender,
): Promise<string | Response> {
let key = keys[flag];
let key = keys[flag] || (Object.values(keys).includes(flag.slice(1)) ? flag.slice(1) : null);
let mappedValue = values[value] ?? value;

if (!key) {
Expand All @@ -499,11 +499,11 @@ export class SetCommandHandler implements CommandHandler {
? `${context.USER_CONFIG.AI_PROVIDER.toUpperCase()}_${key}`
: key;
break;
case 'CURRENT_MODE':
if (!Object.keys(context.USER_CONFIG.MODES).includes(value)) {
return sender.sendPlainText(`mode ${value} not found. Support modes: ${Object.keys(context.USER_CONFIG.MODES).join(', ')}`);
}
break;
// case 'CURRENT_MODE':
// if (!Object.keys(context.USER_CONFIG.MODES).includes(value)) {
// return sender.sendPlainText(`mode ${value} not found. Support modes: ${Object.keys(context.USER_CONFIG.MODES).join(', ')}`);
// }
// break;
case 'USE_TOOLS':
if (value === 'on') {
mappedValue = Object.keys(tools);
Expand Down
18 changes: 11 additions & 7 deletions src/telegram/handler/chat.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable unused-imports/no-unused-vars */
import type { FilePart, ToolResultPart } from 'ai';
import type { FilePart, TextPart, ToolResultPart } from 'ai';
import type * as Telegram from 'telegram-bot-api-types';
import type { ChatStreamTextHandler, HistoryModifier, ImageResult, LLMChatRequestParams } from '../../agent/types';
import type { WorkerContext } from '../../config/context';
Expand All @@ -10,7 +10,7 @@ import type { MessageHandler } from './types';
import { loadAudioLLM, loadChatLLM, loadImageGen } from '../../agent';
import { loadHistory, requestCompletionsFromLLM } from '../../agent/chat';
import { ENV } from '../../config/env';
import { getLog, logSingleton } from '../../log/logDecortor';
import { clearLog, getLog, logSingleton } from '../../log/logDecortor';
import { log } from '../../log/logger';
import { sendToolResult } from '../../tools';
import { imageToBase64String, renderBase64DataURI } from '../../utils/image';
Expand Down Expand Up @@ -92,7 +92,7 @@ export class ChatHandler implements MessageHandler<WorkerContext> {
console.error('Error:', e);
const sender = context.MIDDEL_CONTEXT.sender ?? MessageSender.from(context.SHARE_CONTEXT.botToken, message);
const filtered = (e as Error).message.replace(context.SHARE_CONTEXT.botToken, '[REDACTED]');
return sender.sendPlainText(`Error: ${filtered.substring(0, 4000)}`);
return sender.sendRichText(`<pre>Error:${filtered.substring(0, 4000)}</pre>`, 'HTML');
}
};

Expand All @@ -116,8 +116,6 @@ export class ChatHandler implements MessageHandler<WorkerContext> {
};

if (type !== 'text' && id) {
// const fileIds = await getStoreMediaIds(context.SHARE_CONTEXT, context.MIDDEL_CONTEXT.originalMessageInfo.media_group_id);
// id.push(...fileIds.filter(i => !id.includes(i)));
const api = createTelegramBotAPI(context.SHARE_CONTEXT.botToken);
const files = await Promise.all(id.map(i => api.getFileWithReturns({ file_id: i })));
const paths = files.map(f => f.result.file_path).filter(Boolean) as string[];
Expand Down Expand Up @@ -340,11 +338,17 @@ async function handleAudioToText(
if (!agent) {
return sender.sendPlainText('ERROR: Audio agent not found');
}
const url = (params.content as FilePart[])[0].data as string;
const url = (params.content as FilePart[]).at(-1)?.data as string;
const audio = await fetch(url).then(b => b.blob());
const result = await agent.request(audio, context.USER_CONFIG);
context.MIDDEL_CONTEXT.history.push({ role: 'user', content: result.text || '' });
return sender.sendRichText(`${getLog(context.USER_CONFIG)}\n> \`${result.text}\``, 'MarkdownV2', 'chat');
await sender.sendRichText(`${getLog(context.USER_CONFIG, false, false)}\n> \n${result.text}`, 'MarkdownV2', 'chat');
if (ENV.AUDIO_HANDLE_TYPE === 'chat' && result.text) {
clearLog(context.USER_CONFIG);
const otherText = (params.content as TextPart[]).filter(c => c.type === 'text').map(c => c.text).join('\n').trim();
return chatWithLLM(message, { role: 'user', content: `[AUDIO]: ${result.text}\n${otherText}` }, context, null);
}
return new Response('audio handle done');
}

export async function sendImages(img: ImageResult, SEND_IMAGE_AS_FILE: boolean, sender: MessageSender, config: AgentUserConfig) {
Expand Down
28 changes: 28 additions & 0 deletions src/telegram/handler/handlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { WorkerContext } from '../../config/context';
import { ENV } from '../../config/env';
import { tagMessageIds } from '../../log/logDecortor';
import { log } from '../../log/logger';
import { Rerank } from '../../utils/data_calculation/rerank';
import { createTelegramBotAPI } from '../api';
import { handleCommandMessage } from '../command';
import { loadChatRoleWithContext } from '../command/auth';
Expand Down Expand Up @@ -429,3 +430,30 @@ export class CheckForwarding implements MessageHandler<WorkerContext> {
return null;
};
}

export class IntelligentModelProcess implements MessageHandler<WorkerContext> {
    // Intercepts messages beginning with //c (chat model) or //v (vision model),
    // resolves the fuzzy model name against RERANK_MODELS via Rerank, and
    // rewrites the message into an equivalent /set command for later handlers.
    // Always returns null so the pipeline continues.
    handle = async (message: Telegram.Message, context: WorkerContext): Promise<Response | null> => {
        if (!context.USER_CONFIG.ENABLE_INTELLIGENT_MODEL) {
            return null;
        }
        const pattern = /^\s*\/\/(c|v)\s*(\S+)/;
        const source = (message.text || message.caption || '').trim();
        const matched = source.match(pattern);
        if (!matched || !matched[1] || !matched[2]) {
            return null;
        }
        const rerank = new Rerank();
        try {
            // matched[2] is the user's fuzzy query; rank it against the model pool.
            const candidates = [matched[2], ...context.USER_CONFIG.RERANK_MODELS];
            const bestModel = (await rerank.rank(context.USER_CONFIG, candidates, 1))[0].name;
            const targetKey = matched[1] === 'c' ? '-CHAT_MODEL' : `-VISION_MODEL`;
            const prefix = `/set ${targetKey} ${bestModel} `;
            if (message.text) {
                message.text = prefix + message.text.slice(matched[0].length).trim();
            } else if (message.caption) {
                message.caption = prefix + message.caption.slice(matched[0].length).trim();
            }
        } catch (error) {
            // Rerank failure is non-fatal: log and let the message pass unchanged.
            log.error(`[INTELLIGENT MODEL PROCESS] Rerank error: ${error}`);
        }
        return null;
    };
}
3 changes: 3 additions & 0 deletions src/telegram/handler/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
HandlerCallbackQuery,
HandlerInlineQuery,
InitUserConfig,
IntelligentModelProcess,
MessageFilter,
OldMessageFilter,
SaveLastMessage,
Expand Down Expand Up @@ -66,6 +67,8 @@ async function handleMessage(token: string, message: Telegram.Message, isForward
new SaveLastMessage(),
// 初始化用户配置
new InitUserConfig(),
// 动态模型处理
new IntelligentModelProcess(),
// 处理命令消息
new CommandHandler(),
// 检查是否是转发消息
Expand Down
22 changes: 22 additions & 0 deletions src/utils/data_calculation/classifier.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import type { AgentUserConfig } from '../../config/env';

export class JinaClassifier {
    // Embedding model used by the Jina classification endpoint.
    readonly model = 'jina-embeddings-v3';
    readonly api = 'https://api.jina.ai/v1/classify';

    /**
     * Classify each input string into one of the supplied labels via the
     * Jina classification API.
     *
     * @param context user config providing JINA_API_KEY
     * @param data strings to classify
     * @param labels candidate labels
     * @returns one `{ label, value }` pair per input, where `label` is the
     *          predicted label and `value` is the original input string
     * @throws Error when the API response carries no `data` array — the
     *         previous code would instead fail with an opaque TypeError on
     *         `.map`, hiding the API error payload
     */
    readonly request = async (context: AgentUserConfig, data: string[], labels: string[]) => {
        const body = {
            model: this.model,
            input: data,
            labels,
        };
        const res = await fetch(this.api, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                'Authorization': `Bearer ${context.JINA_API_KEY}`,
            },
            body: JSON.stringify(body),
        }).then(r => r.json());
        // Surface API errors explicitly, matching the Rerank.jina guard style.
        if (!res.data) {
            throw new Error(`Classification failed. details: ${JSON.stringify(res)}`);
        }
        return res.data.map((item: any) => ({ label: item.prediction, value: data[item.index] }));
    };
}
47 changes: 47 additions & 0 deletions src/utils/data_calculation/embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import type { AgentUserConfig } from '../../config/env';
import { createOpenAI } from '@ai-sdk/openai';
import { embedMany } from 'ai';
import { OpenAIBase } from '../../agent/openai';

export class OpenaiEmbedding extends OpenAIBase {
    // Embed every input string with the configured OpenAI embedding model
    // and pair each resulting vector with its source text.
    readonly request = async (context: AgentUserConfig, data: string[]) => {
        const provider = createOpenAI({
            baseURL: context.OPENAI_API_BASE,
            apiKey: this.apikey(context),
        });
        const embeddingModel = provider.embedding(context.OPENAI_EMBEDDING_MODEL);
        const { embeddings, values } = await embedMany({
            model: embeddingModel,
            values: data,
            maxRetries: 0,
        });
        return values.map((value, index) => ({ embed: embeddings[index], value }));
    };
}

export class JinaEmbedding {
    // Optional task hint forwarded verbatim to the Jina embeddings API.
    readonly task: 'retrieval.query' | 'retrieval.passage' | 'separation' | 'classification' | 'text-matching' | undefined;
    readonly model = 'jina-embeddings-v3';

    constructor(task?: 'retrieval.query' | 'retrieval.passage' | 'separation' | 'classification' | 'text-matching') {
        this.task = task;
    }

    // Request embeddings from the Jina API; returns one { embed, value }
    // pair per input, matching OpenaiEmbedding's result shape.
    readonly request = async (context: AgentUserConfig, data: string[]) => {
        const endpoint = 'https://api.jina.ai/v1/embeddings';
        const payload = {
            model: this.model,
            task: this.task,
            dimensions: 1024,
            late_chunking: false,
            embedding_type: 'float',
            input: data,
        };
        const response = await fetch(endpoint, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                'Authorization': `Bearer ${context.JINA_API_KEY}`,
            },
            body: JSON.stringify(payload),
        });
        const parsed = await response.json();
        return parsed.data.map((item: any) => ({ embed: item.embedding, value: data[item.index] }));
    };
}
47 changes: 47 additions & 0 deletions src/utils/data_calculation/rerank.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import type { AgentUserConfig } from '../../config/env';
import { cosineSimilarity } from 'ai';
import { OpenaiEmbedding } from './embedding';

export class Rerank {
    /**
     * Rank candidate strings by similarity to a query.
     *
     * @param context user config; RERANK_AGENT selects the backend ('jina' | 'openai')
     * @param data data[0] is the query, data[1..] are the candidates
     * @param topN number of best matches to return (default 1)
     * @returns array of `{ similar, name }` sorted by descending similarity,
     *          regardless of backend
     * @throws Error when RERANK_AGENT is not a supported value
     */
    readonly rank = (context: AgentUserConfig, data: string[], topN: number = 1) => {
        switch (context.RERANK_AGENT) {
            case 'jina':
                return this.jina(context, data, topN);
            case 'openai':
                return this.openai(context, data, topN);
            default:
                throw new Error('Invalid RERANK_AGENT');
        }
    };

    // Rerank via the Jina rerank API; throws with the raw payload when the
    // response carries no results.
    readonly jina = async (context: AgentUserConfig, data: string[], topN: number) => {
        const url = 'https://api.jina.ai/v1/rerank';
        const result = await fetch(url, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
                'Authorization': `Bearer ${context.JINA_API_KEY}`,
            },
            body: JSON.stringify({
                model: context.JINA_RERANK_MODEL,
                query: data[0],
                documents: data.slice(1),
                top_n: topN,
            }),
        }).then(res => res.json());
        if (!result.results) {
            throw new Error(`No results found. details: ${JSON.stringify(result)}`);
        }
        return result.results.map((item: any) => ({ similar: item.relevance_score, name: item.document.text }));
    };

    // Rerank by cosine similarity of OpenAI embeddings.
    // FIX: this previously returned plain strings while jina() returned
    // { similar, name } objects, so callers reading `[0].name` got undefined
    // when RERANK_AGENT was 'openai'. Both backends now share the same shape.
    // Also tightened `context: any` to AgentUserConfig.
    readonly openai = async (context: AgentUserConfig, data: string[], topN: number) => {
        const embeddings = await new OpenaiEmbedding().request(context, data);
        const queryEmbedding = embeddings[0].embed;
        return embeddings.slice(1)
            .map(({ embed, value }) => ({ similar: cosineSimilarity(queryEmbedding, embed), name: value }))
            .sort((a, b) => b.similar - a.similar)
            .slice(0, topN);
    };
}
59 changes: 0 additions & 59 deletions src/utils/embedding/index.ts

This file was deleted.

0 comments on commit 8896fa8

Please sign in to comment.