From 4a10ee8968239f0f52ca36d39c5d9366982ac283 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 19 Feb 2025 08:14:32 -0600 Subject: [PATCH] feat(js/plugins): added `experimental_debugTraces` option to googleai and vertexai plugins (#2023) --- js/plugins/googleai/src/gemini.ts | 150 +++++++++++++-------- js/plugins/googleai/src/index.ts | 50 +++---- js/plugins/vertexai/src/common/types.ts | 2 + js/plugins/vertexai/src/gemini.ts | 165 ++++++++++++++--------- js/plugins/vertexai/src/index.ts | 29 ++-- js/testapps/context-caching/src/index.ts | 11 +- js/testapps/flow-simple-ai/src/index.ts | 14 +- 7 files changed, 261 insertions(+), 160 deletions(-) diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 68179a765..047820d6c 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -62,6 +62,7 @@ import { downloadRequestMedia, simulateSystemPrompt, } from 'genkit/model/middleware'; +import { runInNewSpan } from 'genkit/tracing'; import { getApiKeyFromEnvVar } from './common'; import { handleCacheIfNeeded } from './context-caching'; import { extractCacheConfig } from './context-caching/utils'; @@ -633,15 +634,25 @@ export function cleanSchema(schema: JSONSchema): JSONSchema { /** * Defines a new GoogleAI model. */ -export function defineGoogleAIModel( - ai: Genkit, - name: string, - apiKey?: string, - apiVersion?: string, - baseUrl?: string, - info?: ModelInfo, - defaultConfig?: GeminiConfig -): ModelAction { +export function defineGoogleAIModel({ + ai, + name, + apiKey, + apiVersion, + baseUrl, + info, + defaultConfig, + debugTraces, +}: { + ai: Genkit; + name: string; + apiKey?: string; + apiVersion?: string; + baseUrl?: string; + info?: ModelInfo; + defaultConfig?: GeminiConfig; + debugTraces?: boolean; +}): ModelAction { if (!apiKey) { apiKey = getApiKeyFromEnvVar(); } @@ -832,54 +843,83 @@ export function defineGoogleAIModel( ); } - if (sendChunk) { - const result = await genModel - .startChat(updatedChatRequest) - .sendMessageStream(msg.parts, options); - for await (const item of result.stream) { - (item as GenerateContentResponse).candidates?.forEach((candidate) => { - const c = fromJSONModeScopedGeminiCandidate(candidate); - sendChunk({ - index: c.index, - content: c.message.content, + const callGemini = async () => { + if (sendChunk) { + const result = await genModel + .startChat(updatedChatRequest) + .sendMessageStream(msg.parts, options); + for await (const item of result.stream) { + (item as GenerateContentResponse).candidates?.forEach( + (candidate) => { + const c = fromJSONModeScopedGeminiCandidate(candidate); + sendChunk({ + index: c.index, + content: c.message.content, + }); + } + ); + } + const response = await result.response; + const candidates = response.candidates || []; + if (response.candidates?.['undefined']) { + candidates.push(response.candidates['undefined']); + } + if (!candidates.length) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: 'No valid candidates returned.', }); - }); + } + return { + candidates: candidates.map(fromJSONModeScopedGeminiCandidate) || [], + custom: response, + }; + } else { + const result = await genModel + .startChat(updatedChatRequest) + .sendMessage(msg.parts, options); + if (!result.response.candidates?.length) + throw new Error('No valid candidates returned.'); + const responseCandidates = + result.response.candidates.map(fromJSONModeScopedGeminiCandidate) || + []; + return { + candidates: responseCandidates, + custom: result.response, + usage: { 
+              ...getBasicUsageStats(request.messages, responseCandidates),
+              inputTokens: result.response.usageMetadata?.promptTokenCount,
+              outputTokens: result.response.usageMetadata?.candidatesTokenCount,
+              totalTokens: result.response.usageMetadata?.totalTokenCount,
+            },
+          };
+        }
-      const response = await result.response;
-      const candidates = response.candidates || [];
-      if (response.candidates?.['undefined']) {
-        candidates.push(response.candidates['undefined']);
-      }
-      if (!candidates.length) {
-        throw new GenkitError({
-          status: 'FAILED_PRECONDITION',
-          message: 'No valid candidates returned.',
-        });
-      }
-      return {
-        candidates: candidates.map(fromJSONModeScopedGeminiCandidate) || [],
-        custom: response,
-      };
-    } else {
-      const result = await genModel
-        .startChat(updatedChatRequest)
-        .sendMessage(msg.parts, options);
-      if (!result.response.candidates?.length)
-        throw new Error('No valid candidates returned.');
-      const responseCandidates =
-        result.response.candidates.map(fromJSONModeScopedGeminiCandidate) ||
-        [];
-      return {
-        candidates: responseCandidates,
-        custom: result.response,
-        usage: {
-          ...getBasicUsageStats(request.messages, responseCandidates),
-          inputTokens: result.response.usageMetadata?.promptTokenCount,
-          outputTokens: result.response.usageMetadata?.candidatesTokenCount,
-          totalTokens: result.response.usageMetadata?.totalTokenCount,
-        },
-      };
-    }
+      };
+      // If debugTraces is enabled, we wrap the actual model call in a span and
+      // attach the raw API params as the span input.
+      return debugTraces
+        ? await runInNewSpan(
+            ai.registry,
+            {
+              metadata: {
+                name: sendChunk ? 'sendMessageStream' : 'sendMessage',
+              },
+            },
+            async (metadata) => {
+              metadata.input = {
+                sdk: '@google/generative-ai',
+                cache: cache,
+                model: genModel.model,
+                chatOptions: updatedChatRequest,
+                parts: msg.parts,
+                options,
+              };
+              const response = await callGemini();
+              metadata.output = response.custom;
+              return response;
+            }
+          )
+        : await callGemini();
     }
   );
 }
diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts
index d4ea6753d..4a1264f90 100644
--- a/js/plugins/googleai/src/index.ts
+++ b/js/plugins/googleai/src/index.ts
@@ -59,6 +59,7 @@ export interface PluginOptions {
     | ModelReference
     | string
   )[];
+  experimental_debugTraces?: boolean;
 }
 
 /**
@@ -78,33 +79,36 @@ export function googleAI(options?: PluginOptions): GenkitPlugin {
 
     if (apiVersions.includes('v1beta')) {
       Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
-        defineGoogleAIModel(
+        defineGoogleAIModel({
           ai,
           name,
-          options?.apiKey,
-          'v1beta',
-          options?.baseUrl
-        )
+          apiKey: options?.apiKey,
+          apiVersion: 'v1beta',
+          baseUrl: options?.baseUrl,
+          debugTraces: options?.experimental_debugTraces,
+        })
       );
     }
     if (apiVersions.includes('v1')) {
       Object.keys(SUPPORTED_V1_MODELS).forEach((name) =>
-        defineGoogleAIModel(
+        defineGoogleAIModel({
           ai,
           name,
-          options?.apiKey,
-          undefined,
-          options?.baseUrl
-        )
+          apiKey: options?.apiKey,
+          apiVersion: undefined,
+          baseUrl: options?.baseUrl,
+          debugTraces: options?.experimental_debugTraces,
+        })
       );
       Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
-        defineGoogleAIModel(
+        defineGoogleAIModel({
           ai,
           name,
-          options?.apiKey,
-          undefined,
-          options?.baseUrl
-        )
+          apiKey: options?.apiKey,
+          apiVersion: undefined,
+          baseUrl: options?.baseUrl,
+          debugTraces: options?.experimental_debugTraces,
+        })
      );
      Object.keys(EMBEDDER_MODELS).forEach((name) =>
        defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey })
@@ -120,17 +124,17 @@ export function googleAI(options?: PluginOptions): GenkitPlugin { 
modelOrRef.name.split('/')[1]; const modelRef = typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef; - defineGoogleAIModel( + defineGoogleAIModel({ ai, - modelName, - options?.apiKey, - undefined, - options?.baseUrl, - { + name: modelName, + apiKey: options?.apiKey, + baseUrl: options?.baseUrl, + info: { ...modelRef.info, label: `Google AI - ${modelName}`, - } - ); + }, + debugTraces: options?.experimental_debugTraces, + }); } } }); diff --git a/js/plugins/vertexai/src/common/types.ts b/js/plugins/vertexai/src/common/types.ts index 642f5980a..70af9dc9f 100644 --- a/js/plugins/vertexai/src/common/types.ts +++ b/js/plugins/vertexai/src/common/types.ts @@ -26,6 +26,8 @@ export interface CommonPluginOptions { location: string; /** Provide custom authentication configuration for connecting to Vertex AI. */ googleAuth?: GoogleAuthOptions; + /** Enables additional debug traces (e.g. raw model API call details). */ + experimental_debugTraces?: boolean; } /** Combined plugin options, extending common options with subplugin-specific options */ diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 00a605491..46b721156 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -58,6 +58,7 @@ import { downloadRequestMedia, simulateSystemPrompt, } from 'genkit/model/middleware'; +import { runInNewSpan } from 'genkit/tracing'; import { GoogleAuth } from 'google-auth-library'; import { PluginOptions } from './common/types.js'; import { handleCacheIfNeeded } from './context-caching/index.js'; @@ -742,36 +743,47 @@ export function defineGeminiKnownModel( vertexClientFactory: ( request: GenerateRequest ) => VertexAI, - options: PluginOptions + options: PluginOptions, + debugTraces?: boolean ): ModelAction { const modelName = `vertexai/${name}`; const model: ModelReference = SUPPORTED_GEMINI_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); - return defineGeminiModel( + return defineGeminiModel({ ai, modelName, - name, - model?.info, + version: name, + modelInfo: model?.info, vertexClientFactory, - options - ); + options, + debugTraces, + }); } /** * Define a Vertex AI Gemini model. 
*/ -export function defineGeminiModel( - ai: Genkit, - modelName: string, - version: string, - modelInfo: ModelInfo | undefined, +export function defineGeminiModel({ + ai, + modelName, + version, + modelInfo, + vertexClientFactory, + options, + debugTraces, +}: { + ai: Genkit; + modelName: string; + version: string; + modelInfo: ModelInfo | undefined; vertexClientFactory: ( request: GenerateRequest - ) => VertexAI, - options: PluginOptions -): ModelAction { + ) => VertexAI; + options: PluginOptions; + debugTraces?: boolean; +}): ModelAction { const middlewares: ModelMiddleware[] = []; if (SUPPORTED_V1_MODELS[version]) { middlewares.push(simulateSystemPrompt()); @@ -788,7 +800,7 @@ export function defineGeminiModel( configSchema: GeminiConfigSchema, use: middlewares, }, - async (request, streamingCallback) => { + async (request, sendChunk) => { const vertex = vertexClientFactory(request); // Make a copy of messages to avoid side-effects @@ -926,57 +938,86 @@ export function defineGeminiModel( ); } - // Handle streaming and non-streaming responses - if (streamingCallback) { - const result = await genModel - .startChat(updatedChatRequest) - .sendMessageStream(msg.parts); - - for await (const item of result.stream) { - (item as GenerateContentResponse).candidates?.forEach((candidate) => { - const c = fromGeminiCandidate(candidate, jsonMode); - streamingCallback({ - index: c.index, - content: c.message.content, - }); - }); - } + const callGemini = async () => { + // Handle streaming and non-streaming responses + if (sendChunk) { + const result = await genModel + .startChat(updatedChatRequest) + .sendMessageStream(msg.parts); + + for await (const item of result.stream) { + (item as GenerateContentResponse).candidates?.forEach( + (candidate) => { + const c = fromGeminiCandidate(candidate, jsonMode); + sendChunk({ + index: c.index, + content: c.message.content, + }); + } + ); + } - const response = await result.response; - if (!response.candidates?.length) { - throw new Error('No valid candidates returned.'); - } + const response = await result.response; + if (!response.candidates?.length) { + throw new Error('No valid candidates returned.'); + } - return { - candidates: response.candidates.map((c) => - fromGeminiCandidate(c, jsonMode) - ), - custom: response, - }; - } else { - const result = await genModel - .startChat(updatedChatRequest) - .sendMessage(msg.parts); + return { + candidates: response.candidates.map((c) => + fromGeminiCandidate(c, jsonMode) + ), + custom: response, + }; + } else { + const result = await genModel + .startChat(updatedChatRequest) + .sendMessage(msg.parts); + + if (!result?.response.candidates?.length) { + throw new Error('No valid candidates returned.'); + } - if (!result?.response.candidates?.length) { - throw new Error('No valid candidates returned.'); + const responseCandidates = result.response.candidates.map((c) => + fromGeminiCandidate(c, jsonMode) + ); + + return { + candidates: responseCandidates, + custom: result.response, + usage: { + ...getBasicUsageStats(request.messages, responseCandidates), + inputTokens: result.response.usageMetadata?.promptTokenCount, + outputTokens: result.response.usageMetadata?.candidatesTokenCount, + totalTokens: result.response.usageMetadata?.totalTokenCount, + }, + }; } - - const responseCandidates = result.response.candidates.map((c) => - fromGeminiCandidate(c, jsonMode) - ); - - return { - candidates: responseCandidates, - custom: result.response, - usage: { - ...getBasicUsageStats(request.messages, responseCandidates), - 
inputTokens: result.response.usageMetadata?.promptTokenCount,
-          outputTokens: result.response.usageMetadata?.candidatesTokenCount,
-          totalTokens: result.response.usageMetadata?.totalTokenCount,
-        },
-      };
-    }
+      };
+      // If debugTraces is enabled, we wrap the actual model call in a span and
+      // attach the raw API params as the span input.
+      return debugTraces
+        ? await runInNewSpan(
+            ai.registry,
+            {
+              metadata: {
+                name: sendChunk ? 'sendMessageStream' : 'sendMessage',
+              },
+            },
+            async (metadata) => {
+              metadata.input = {
+                sdk: '@google-cloud/vertexai',
+                cache: cache,
+                model: genModel.getModelName(),
+                chatOptions: updatedChatRequest,
+                parts: msg.parts,
+                options,
+              };
+              const response = await callGemini();
+              metadata.output = response.custom;
+              return response;
+            }
+          )
+        : await callGemini();
     }
   );
 }
diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts
index bfd1c74c9..eda1df60a 100644
--- a/js/plugins/vertexai/src/index.ts
+++ b/js/plugins/vertexai/src/index.ts
@@ -87,10 +87,16 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin {
         imagenModel(ai, name, authClient, { projectId, location })
       );
       Object.keys(SUPPORTED_GEMINI_MODELS).map((name) =>
-        defineGeminiKnownModel(ai, name, vertexClientFactory, {
-          projectId,
-          location,
-        })
+        defineGeminiKnownModel(
+          ai,
+          name,
+          vertexClientFactory,
+          {
+            projectId,
+            location,
+          },
+          options?.experimental_debugTraces
+        )
      );
      if (options?.models) {
        for (const modelOrRef of options?.models) {
@@ -101,17 +107,18 @@
           modelOrRef.name.split('/')[1];
         const modelRef =
           typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef;
-        defineGeminiModel(
+        defineGeminiModel({
           ai,
-          modelRef.name,
-          modelName,
-          modelRef.info,
+          modelName: modelRef.name,
+          version: modelName,
+          modelInfo: modelRef.info,
           vertexClientFactory,
-          {
+          options: {
             projectId,
             location,
-          }
-        );
+          },
+          debugTraces: options.experimental_debugTraces,
+        });
       }
     }
diff --git a/js/testapps/context-caching/src/index.ts b/js/testapps/context-caching/src/index.ts
index af5ab6f6f..8c54c378a 100644
--- a/js/testapps/context-caching/src/index.ts
+++ b/js/testapps/context-caching/src/index.ts
@@ -24,7 +24,10 @@ import { genkit, z } from 'genkit'; // Import Genkit framework and Zod for schem
 import { logger } from 'genkit/logging'; // Import logging utility from Genkit
 
 const ai = genkit({
-  plugins: [vertexAI(), googleAI()], // Initialize Genkit with the Google AI plugin
+  plugins: [
+    vertexAI({ experimental_debugTraces: true, location: 'us-central1' }),
+    googleAI({ experimental_debugTraces: true }),
+  ], // Initialize Genkit with the Vertex AI and Google AI plugins
 });
 
 logger.setLogLevel('debug'); // Set the logging level to debug for detailed output
@@ -38,7 +41,7 @@ export const lotrFlowVertex = ai.defineFlow(
     }),
     outputSchema: z.string(), // Define the expected output as a string
   },
-  async ({ query, textFilePath }) => {
+  async ({ query, textFilePath }, { sendChunk }) => {
     const defaultQuery = 'What is the text i provided you with?'; // Default query to use if none is provided
 
     // Read the content from the file if the path is provided
@@ -69,6 +72,7 @@
       },
       model: gemini15Flash, // Specify the model (gemini15Flash) to use for generation
       prompt: query || defaultQuery, // Use the provided query or fall back to the default query
+      onChunk: sendChunk,
     });
 
     return llmResponse.text; // Return the generated text from the model
@@ -84,7 +88,7 @@ export const lotrFlowGoogleAI = 
ai.defineFlow( }), outputSchema: z.string(), // Define the expected output as a string }, - async ({ query, textFilePath }) => { + async ({ query, textFilePath }, { sendChunk }) => { const defaultQuery = 'What is the text i provided you with?'; // Default query to use if none is provided // Read the content from the file if the path is provided @@ -115,6 +119,7 @@ export const lotrFlowGoogleAI = ai.defineFlow( }, model: gemini15FlashGoogleAI, // Specify the model (gemini15Flash) to use for generation prompt: query || defaultQuery, // Use the provided query or fall back to the default query + onChunk: sendChunk, }); return llmResponse.text; // Return the generated text from the model diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 38bc0f153..ab550425f 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -30,7 +30,7 @@ import { GoogleAIFileManager } from '@google/generative-ai/server'; import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base'; import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; -import { GenerateResponseData, MessageSchema, genkit, z } from 'genkit'; +import { GenerateResponseData, MessageSchema, genkit, z } from 'genkit/beta'; import { logger } from 'genkit/logging'; import { ModelMiddleware, simulateConstrainedGeneration } from 'genkit/model'; import { PluginProvider } from 'genkit/plugin'; @@ -56,7 +56,10 @@ enableGoogleCloudTelemetry({ }); const ai = genkit({ - plugins: [googleAI(), vertexAI()], + plugins: [ + googleAI({ experimental_debugTraces: true }), + vertexAI({ location: 'us-central1', experimental_debugTraces: true }), + ], }); const math: PluginProvider = { @@ -452,8 +455,8 @@ const exitTool = ai.defineTool( }), description: 'call this tool when you have the final answer', }, - async (input) => { - throw new Error(`Answer: ${input.answer}`); + async (input, { interrupt }) => { + interrupt(); } ); @@ -461,7 +464,6 @@ export const forcedToolCaller = ai.defineFlow( { name: 'forcedToolCaller', inputSchema: z.number(), - outputSchema: z.string(), streamSchema: z.any(), }, async (input, { sendChunk }) => { @@ -479,7 +481,7 @@ export const forcedToolCaller = ai.defineFlow( sendChunk(chunk); } - return (await response).text; + return await response; } );
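
Usage note: with `experimental_debugTraces: true`, each generate call wraps the
underlying SDK call in an extra `sendMessage` / `sendMessageStream` span whose
input records the raw API parameters (SDK name, model, chat options, parts) and
whose output records the raw response. A minimal consumer-side sketch, assuming
the standard `@genkit-ai/googleai` and `@genkit-ai/vertexai` entry points; the
model choice and prompt are illustrative, not part of this patch:

    import { genkit } from 'genkit';
    import { gemini15Flash, googleAI } from '@genkit-ai/googleai';
    import { vertexAI } from '@genkit-ai/vertexai';

    const ai = genkit({
      plugins: [
        // Opt in to the additional debug spans added by this patch.
        googleAI({ experimental_debugTraces: true }),
        vertexAI({ location: 'us-central1', experimental_debugTraces: true }),
      ],
    });

    async function main() {
      // The trace for this call now contains a `sendMessage` span (or
      // `sendMessageStream` when streaming) carrying the raw request as its
      // input and the raw model response as its output.
      const { text } = await ai.generate({
        model: gemini15Flash,
        prompt: 'Say hello.',
      });
      console.log(text);
    }

    main().catch(console.error);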