Skip to content

Commit

Permalink
feat: add google embedding
Browse files Browse the repository at this point in the history
perf: simplify rerank logic
  • Loading branch information
adolphnov committed Dec 21, 2024
1 parent 5f9d9f2 commit 6019946
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 27 deletions.
11 changes: 6 additions & 5 deletions src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ export class GeminiConfig {
GOOGLE_CHAT_MODEL = 'gemini-1.5-flash';
// Google Gemini Vision Model
GOOGLE_VISION_MODEL = 'gemini-1.5-flash';
// Google Embedding Model
GOOGLE_EMBEDDING_MODEL = 'text-embedding-004';
}

// -- Mistral 配置 --
Expand Down Expand Up @@ -313,8 +315,6 @@ export class OpenAILikeConfig {
OAILIKE_EMBEDDING_MODEL = 'BAAI/bge-m3';
// oailike rerank model
OAILIKE_RERANK_MODEL = 'BAAI/bge-reranker-v2-m3';
// oailike rerank type, v1 means use embedding model, v2 means use rerank model to rerank
OAILIKE_RERANK_TYPE = 'v2';
// oailike asr model
OAILIKE_STT_MODEL = 'FunAudioLLM/SenseVoiceSmall';
// oailike tts model
Expand Down Expand Up @@ -353,7 +353,7 @@ export class DefineKeys {
}

export class ExtraUserConfig {
MAPPING_KEY = '-p:SYSTEM_INIT_MESSAGE|-n:MAX_HISTORY_LENGTH|-a:AI_PROVIDER|-ai:AI_IMAGE_PROVIDER|-m:CHAT_MODEL|-md:CURRENT_MODE|-v:VISION_MODEL|-t:OPENAI_TTS_MODEL|-ex:OPENAI_API_EXTRA_PARAMS|-mk:MAPPING_KEY|-mv:MAPPING_VALUE|-asap:FUNCTION_REPLY_ASAP|-tm:TOOL_MODEL|-tool:USE_TOOLS|-oli:IMAGE_MODEL|-th:TEXT_HANDLE_TYPE|-to:TEXT_OUTPUT|-ah:AUDIO_HANDLE_TYPE|-ao:AUDIO_OUTPUT|-act:AUDIO_CONTAINS_TEXT|-as:AI_ASR_PROVIDER|-at:AI_TTS_PROVIDER';
MAPPING_KEY = '-p:SYSTEM_INIT_MESSAGE|-n:MAX_HISTORY_LENGTH|-a:AI_PROVIDER|-ai:AI_IMAGE_PROVIDER|-m:CHAT_MODEL|-md:CURRENT_MODE|-v:VISION_MODEL|-t:OPENAI_TTS_MODEL|-ex:OPENAI_API_EXTRA_PARAMS|-mk:MAPPING_KEY|-mv:MAPPING_VALUE|-asap:FUNCTION_REPLY_ASAP|-tm:TOOL_MODEL|-tool:USE_TOOLS|-oli:IMAGE_MODEL|-th:TEXT_HANDLE_TYPE|-to:TEXT_OUTPUT|-ah:AUDIO_HANDLE_TYPE|-ao:AUDIO_OUTPUT|-act:AUDIO_CONTAINS_TEXT|-as:AI_ASR_PROVIDER|-at:AI_TTS_PROVIDER|-ra:RERANK_AGENT';
// /set command mapping value, separated by |, : separates multiple relationships
MAPPING_VALUE = '';
// MAPPING_VALUE = "cson:claude-3-5-sonnet-20240620|haiku:claude-3-haiku-20240307|g4m:gpt-4o-mini|g4:gpt-4o|rp+:command-r-plus";
Expand Down Expand Up @@ -400,8 +400,9 @@ export class ExtraUserConfig {
MAX_STEPS = 3;
// chat agent max retries
MAX_RETRIES = 0;
// Rerank Agent, jina or openai or oailike (calculate the cosine similarity using embedding models to get the result)
RERANK_AGENT = 'oailike';
// Rerank Agent, jina or openai or oailikeV1 or oailikeV2 or google
// oailikeV1 means use embedding model, oailikeV2 means use rerank model to rerank
RERANK_AGENT = 'google';
// Jina Rerank Model
JINA_RERANK_MODEL = 'jina-colbert-v2';
// Rerank Models
Expand Down
15 changes: 15 additions & 0 deletions src/utils/data_calculation/embedding.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { AgentUserConfig } from '../../config/env';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createOpenAI } from '@ai-sdk/openai';
import { embedMany } from 'ai';
import { OpenAIBase } from '../../agent/openai';
Expand Down Expand Up @@ -59,3 +60,17 @@ export class OpenAILikeEmbedding extends OpenAILikeBase {
return values.map((value, i) => ({ embed: embeddings[i], value }));
};
}

export class GoogleEmbedding {
readonly request = async (data: string[], context: AgentUserConfig) => {
const { embeddings, values } = await embedMany({
model: createGoogleGenerativeAI({
baseURL: context.GOOGLE_API_BASE,
apiKey: context.GOOGLE_API_KEY || undefined,
}).textEmbeddingModel(context.GOOGLE_EMBEDDING_MODEL),
values: data,
maxRetries: 0,
});
return values.map((value, i) => ({ embed: embeddings[i], value }));
};
}
44 changes: 22 additions & 22 deletions src/utils/data_calculation/rerank.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,44 @@
import type { AgentUserConfig } from '../../config/env';
import { cosineSimilarity } from 'ai';
import { OpenaiEmbedding, OpenAILikeEmbedding } from './embedding';
import { GoogleEmbedding, OpenaiEmbedding, OpenAILikeEmbedding } from './embedding';

interface RerankResult {
similar: number;
value: string;
}

const generalRerankAgent = {
openai: new OpenaiEmbedding(),
oailikeV1: new OpenAILikeEmbedding(),
oailikeV2: new OpenAILikeEmbedding(),
google: new GoogleEmbedding(),
};

export class Rerank {
readonly rank = async (context: AgentUserConfig, data: string[], topN: number = 1): Promise<RerankResult[]> => {
switch (context.RERANK_AGENT) {
case 'jina':
return this.jina(context, data, topN);
case 'openai':
return this.openai(context, data, topN);
case 'oailike':
return context.OAILIKE_RERANK_TYPE === 'v1' ? this.oailikeV1(context, data, topN) : this.oailikeV2(context, data, topN);
case 'oailikeV1':
case 'google':
return this.generalRerankAgent(context, data, topN);
case 'oailikeV2':
return this.oailikeV2(context, data, topN);
default:
throw new Error('Invalid RERANK_AGENT');
}
};

readonly generalRerankAgent = async (context: AgentUserConfig, data: string[], topN: number): Promise<RerankResult[]> => {
const embeddings = await generalRerankAgent[context.RERANK_AGENT as keyof typeof generalRerankAgent].request(data, context);
const inputEmbeddings = embeddings[0].embed;
return embeddings.slice(1)
.map(({ embed, value }) => ({ similar: cosineSimilarity(inputEmbeddings, embed), value }))
.sort((a, b) => b.similar - a.similar)
.slice(0, topN);
};

readonly jina = async (context: AgentUserConfig, data: string[], topN: number): Promise<RerankResult[]> => {
const url = 'https://api.jina.ai/v1/rerank';
const result = await fetch(url, {
Expand All @@ -42,24 +60,6 @@ export class Rerank {
return result.results.map((item: any) => ({ similar: item.relevance_score, value: item.document.text }));
};

readonly openai = async (context: AgentUserConfig, data: string[], topN: number): Promise<RerankResult[]> => {
const embeddings = await new OpenaiEmbedding().request(data, context);
const inputEmbeddings = embeddings[0].embed;
return embeddings.slice(1)
.map(({ embed, value }) => ({ similar: cosineSimilarity(inputEmbeddings, embed), value }))
.sort((a, b) => b.similar - a.similar)
.slice(0, topN);
};

readonly oailikeV1 = async (context: AgentUserConfig, data: string[], topN: number): Promise<RerankResult[]> => {
const embeddings = await new OpenAILikeEmbedding().request(data, context);
const inputEmbeddings = embeddings[0].embed;
return embeddings.slice(1)
.map(({ embed, value }) => ({ similar: cosineSimilarity(inputEmbeddings, embed), value }))
.sort((a, b) => b.similar - a.similar)
.slice(0, topN);
};

readonly oailikeV2 = async (context: AgentUserConfig, data: string[], topN: number): Promise<RerankResult[]> => {
const url = `${context.OAILIKE_API_BASE}/rerank`;
const result = await fetch(url, {
Expand Down

0 comments on commit 6019946

Please sign in to comment.