feat: allow google/vertex to send files/images/audio as links in text
chore: add empty message check for private chat
perf: move message checking to sendHandler
adolphnov committed Jan 1, 2025
1 parent 4fcb40e commit 9991ba6
Showing 6 changed files with 78 additions and 17 deletions.
62 changes: 60 additions & 2 deletions src/agent/google.ts
@@ -1,4 +1,4 @@
import type { CoreUserMessage } from 'ai';
import type { CoreUserMessage, FilePart, ImagePart, UserContent } from 'ai';
import type { AgentUserConfig } from '../config/env';
import type { ChatAgent, ChatStreamTextHandler, LLMChatParams, LLMChatRequestParams, ResponseMessage } from './types';
import { createLlmModel, warpLLMParams } from '.';
@@ -24,11 +24,69 @@ export class Google implements ChatAgent {
};

readonly request = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<{ messages: ResponseMessage[]; content: string }> => {
const userMessage = params.messages.at(-1) as CoreUserMessage;
const userMessage = handleUrl(params.messages.at(-1) as CoreUserMessage);
const languageModelV1 = await createLlmModel(this.model(context, userMessage), context);
return requestChatCompletionsV2(await warpLLMParams({
model: languageModelV1,
messages: params.messages,
}, context), onStream);
};
}

export function handleUrl(messages: CoreUserMessage): CoreUserMessage {
if (typeof messages.content === 'string') {
const { data, text } = extractUrls(messages.content);
if (data.length > 0) {
const newMessage: UserContent = [];
newMessage.push({
type: 'text',
text,
});
data.forEach(i => newMessage.push({
type: i.type as 'image' | 'file',
[i.type === 'image' ? 'url' : 'data']: i.url,
mimeType: i.mimeType,
} as unknown as FilePart | ImagePart));
messages.content = newMessage;
}
}
return messages;
}

function extractUrls(str: string): { data: { type: string; url: string; mimeType: string }[]; text: string } {
const urlRegex = /(https?:\/\/\S+)/g;
const matches = str.match(urlRegex) || [];
const supportTypes = {
pdf: 'application/pdf',
mp3: 'audio/mpeg',
aac: 'audio/aac',
flac: 'audio/flac',
ogg: 'audio/ogg',
wav: 'audio/wav',
mp4: 'video/mp4',
jpg: 'image/jpeg',
jpeg: 'image/jpeg',
png: 'image/png',
gif: 'image/gif',
js: 'text/javascript',
python: 'text/x-python',
html: 'text/html',
css: 'text/css',
xml: 'application/xml',
csv: 'text/csv',
rtf: 'text/rtf',
txt: 'text/plain',
md: 'text/markdown',
};
return {
data: matches.map((i) => {
const type = i.split('.').pop()?.replace(/\?.*$/, '') as keyof typeof supportTypes;
return {
mimeType: (type && supportTypes[type]) || 'text/html',
url: i,
type: supportTypes[type]?.startsWith('image') ? 'image' : 'file',
};
}),
text: str.replace(urlRegex, '').trim(),
};
}
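
For reference, a minimal usage sketch of the new conversion (the URL below is a placeholder, not from the commit): a text message containing a PDF link is split into a text part plus a file part before the request is built.

```ts
// Illustrative only: the URL is a placeholder.
const message: CoreUserMessage = {
  role: 'user',
  content: 'Summarize https://example.com/report.pdf please',
};

const converted = handleUrl(message);
// converted.content is now:
// [
//   { type: 'text', text: 'Summarize  please' },
//   { type: 'file', data: 'https://example.com/report.pdf', mimeType: 'application/pdf' },
// ]
```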
8 changes: 5 additions & 3 deletions src/agent/model_middleware.ts
@@ -22,22 +22,24 @@ export function AIMiddleware({ config, activeTools, onStream, toolChoice, messag
let sendToolCall = false;
let step = 0;
let rawSystemPrompt: string | undefined;
const isThinking = chatModel.includes('thinking');
let isThinking = false;
let thinkingEnd = false;
let isThinkingStart = true;
return {
wrapGenerate: async ({ doGenerate, params, model }) => {
log.info('doGenerate called');
warpModel(model, config, activeTools, (params.mode as any).toolChoice, chatModel);
log.info(`modelId: ${model.modelId}`);
isThinking = model.modelId.includes('thinking');
recordModelLog(config, model, activeTools, (params.mode as any).toolChoice);
const result = await doGenerate();
log.debug(`doGenerate result: ${JSON.stringify(result)}`);
return result;
},

wrapStream: async ({ doStream, params, model }) => {
log.info('doStream called');
warpModel(model, config, activeTools, (params.mode as any).toolChoice, chatModel);
log.info(`modelId: ${model.modelId}`);
isThinking = model.modelId.includes('thinking');
recordModelLog(config, model, activeTools, (params.mode as any).toolChoice);
return doStream();
},
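
The change here moves the "thinking" detection from construction time (chatModel.includes('thinking')) into the wrapGenerate/wrapStream hooks, so the flag tracks whatever model warpModel actually selected for the call. A minimal sketch of that per-call pattern, with types loosened to any in the style of the file's own casts:

```ts
// Sketch: recompute the flag on every call instead of once at setup,
// so it follows the model the request was rewritten to use.
function thinkingDetector(onDetect: (thinking: boolean) => void) {
  return {
    wrapGenerate: async ({ doGenerate, model }: any) => {
      onDetect(model.modelId.includes('thinking'));
      return doGenerate();
    },
    wrapStream: async ({ doStream, model }: any) => {
      onDetect(model.modelId.includes('thinking'));
      return doStream();
    },
  };
}
```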
10 changes: 1 addition & 9 deletions src/agent/request.ts
@@ -142,9 +142,6 @@ export async function streamHandler(stream: AsyncIterable<any>, contentExtractor
let lastChunk = '';
const maxLength = 10_000;

const immediatePromise = Promise.resolve('[PROMISE DONE]');
let sendPromise: Promise<any> | null = null;

try {
for await (const part of stream) {
const textPart = contentExtractor(part);
@@ -160,14 +157,9 @@
lastChunk = textPart;

if (lastChunk && lengthDelta > updateStep) {
// a message has already been sent and has not finished sending yet
if (sendPromise && (await Promise.race([sendPromise, immediatePromise]) === '[PROMISE DONE]')) {
continue;
}

lengthDelta = 0;
updateStep = Math.min(updateStep + 40, maxLength);
sendPromise = onStream.send(`${contentFull.trimEnd()}●`);
onStream.send(`${contentFull.trimEnd()}●`);
}
}
contentFull += lastChunk;
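
With the in-flight check removed (it moves to the Telegram send handler, per the commit message), streamHandler keeps only its growing-interval throttle. A standalone sketch of that throttle; the initial step value is an assumption, as the diff does not show it:

```ts
// Flush partial output only when the accumulated delta exceeds a step
// that grows by 40 characters per flush, capped at maxLength.
function makeThrottle(onSend: (text: string) => void, maxLength = 10_000) {
  let updateStep = 50; // assumed initial value; not shown in this diff
  let lengthDelta = 0;
  let contentFull = '';
  return (chunk: string) => {
    contentFull += chunk;
    lengthDelta += chunk.length;
    if (lengthDelta > updateStep) {
      lengthDelta = 0;
      updateStep = Math.min(updateStep + 40, maxLength);
      onSend(`${contentFull.trimEnd()}●`);
    }
  };
}
```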
3 changes: 2 additions & 1 deletion src/agent/vertex.ts
@@ -5,6 +5,7 @@ import { createVertex } from '@ai-sdk/google-vertex';
import { experimental_generateImage as generateImage } from 'ai';
import { createLlmModel, warpLLMParams } from '.';
import { Log } from '../log/logDecortor';
import { handleUrl } from './google';
import { requestChatCompletionsV2 } from './request';

class VertexBase {
@@ -29,7 +30,7 @@ export class Vertex extends VertexBase implements ChatAgent {
readonly modelKey = 'VERTEX_CHAT_MODEL';

readonly request = async (params: LLMChatParams, context: AgentUserConfig, onStream: ChatStreamTextHandler | null): Promise<{ messages: ResponseMessage[]; content: string }> => {
const userMessage = params.messages.at(-1) as CoreUserMessage;
const userMessage = handleUrl(params.messages.at(-1) as CoreUserMessage);
const languageModelV1 = await createLlmModel(this.model(context, userMessage), context);
return requestChatCompletionsV2(await warpLLMParams({
model: languageModelV1,
8 changes: 6 additions & 2 deletions src/telegram/handler/chat.ts
@@ -249,6 +249,8 @@ export function OnStreamHander(sender: MessageSender | ChosenInlineSender, conte
};
};

const immediatePromise = Promise.resolve('[PROMISE DONE]');

const streamSender = {
send: null as ((text: string, isEnd: boolean) => Promise<any>) | null,
end: null as ((text: string) => Promise<any>) | null,
@@ -261,8 +263,10 @@
log.info(`Need await: ${(nextEnableTime || 0) - Date.now()}ms`);
return;
}

await sentPromise;
// prevent two sendPromises from existing at the end
if (sentPromise && (await Promise.race([sentPromise, immediatePromise]) === '[PROMISE DONE]')) {
return;
}

// set the minimum streaming interval
if (sendInterval > 0) {
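
The de-duplication removed from streamHandler resurfaces here: racing the pending send against an already-resolved sentinel reveals, without blocking, whether the previous Telegram edit is still in flight. A self-contained sketch of the pattern:

```ts
// Promise.race against an already-resolved sentinel settles immediately,
// so the result tells us whether `p` had settled yet.
// (Assumes `p` never resolves to the sentinel string itself.)
const SENTINEL = '[PROMISE DONE]';

async function isStillPending(p: Promise<unknown>): Promise<boolean> {
  return (await Promise.race([p, Promise.resolve(SENTINEL)])) === SENTINEL;
}

// Usage, as in the handler above:
//   if (sentPromise && await isStillPending(sentPromise)) return; // skip this update
```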
4 changes: 4 additions & 0 deletions src/telegram/handler/group.ts
@@ -95,6 +95,10 @@ export class GroupMention implements MessageHandler {
// Non-group messages are not checked here; hand them off to the next middleware
if (!isTelegramChatTypeGroup(message.chat.type)) {
this.mergeMessage(false, message);
const noneMessage = await this.noneMessage(message, context);
if (noneMessage instanceof Response) {
return noneMessage;
}
return this.furtherChecker(message, context);
}

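
This wires the new empty-message check (the chore line in the commit message) into the private-chat path: if noneMessage returns a Response, handling stops there. Its body is not part of this diff; a hypothetical sketch of such a guard:

```ts
// Hypothetical: the real noneMessage is not shown in this commit.
// Return a Response to short-circuit when the message has no usable
// content; return undefined to let the pipeline continue.
async function noneMessage(message: any, context: any): Promise<Response | undefined> {
  const hasContent = Boolean(
    message.text?.trim()
    || message.caption?.trim()
    || message.photo
    || message.voice
    || message.document,
  );
  return hasContent ? undefined : new Response('ok'); // acknowledge and stop
}
```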
