From 21ffd9b088c54eefa7fb5753cf4cda172ff628e9 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 08:18:17 -0500 Subject: [PATCH] feat(js): propagate context to sub actions, expose context in prompts (#1663) --- docs/auth.md | 6 +- js/ai/src/generate.ts | 22 +- js/ai/src/prompt.ts | 66 ++++- js/ai/src/tool.ts | 8 +- js/ai/tests/prompt/prompt_test.ts | 300 +++++++++++++++++++- js/core/src/action.ts | 24 +- js/core/src/context.ts | 42 ++- js/core/src/flow.ts | 8 +- js/core/src/index.ts | 2 +- js/core/tests/action_test.ts | 23 ++ js/core/tests/flow_test.ts | 16 +- js/genkit/src/genkit.ts | 11 + js/genkit/src/index.ts | 2 +- js/genkit/tests/generate_test.ts | 59 ++++ js/plugins/express/src/index.ts | 8 +- js/plugins/express/tests/express_test.ts | 4 +- js/plugins/firebase/tests/functions_test.ts | 4 +- js/testapps/express/src/index.ts | 2 +- 18 files changed, 521 insertions(+), 86 deletions(-) diff --git a/docs/auth.md b/docs/auth.md index 000c34785..aedeb4f8d 100644 --- a/docs/auth.md +++ b/docs/auth.md @@ -73,7 +73,7 @@ When running with the Genkit Development UI, you can pass the Auth object by entering JSON in the "Auth JSON" tab: `{"uid": "abc-def"}`. You can also retrieve the auth context for the flow at any time within the flow -by calling `getFlowAuth()`, including in functions invoked by the flow: +by calling `ai.currentContext()`, including in functions invoked by the flow: ```ts import { genkit, z } from 'genkit'; @@ -81,7 +81,7 @@ import { genkit, z } from 'genkit'; const ai = genkit({ ... });; async function readDatabase(uid: string) { - const auth = ai.getAuthContext(); + const auth = ai.currentContext()?.auth; if (auth?.admin) { // Do something special if the user is an admin } else { @@ -153,7 +153,7 @@ export const selfSummaryFlow = onFlow( When using the Firebase Auth plugin, `user` will be returned as a [DecodedIdToken](https://firebase.google.com/docs/reference/admin/node/firebase-admin.auth.decodedidtoken). -You can always retrieve this object at any time via `getFlowAuth()` as noted +You can always retrieve this object at any time via `ai.currentContext()` as noted above. When running this flow during development, you would pass the user object in the same way: diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 1081f9a07..8579142c3 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -16,8 +16,10 @@ import { Action, + ActionContext, GenkitError, StreamingCallback, + runWithContext, runWithStreamingCallback, sentinelNoopStreamingCallback, z, @@ -113,8 +115,8 @@ export interface GenerateOptions< * const interrupt = response.interrupts[0]; * * const resumedResponse = await ai.generate({ - * messages: response.messages, - * resume: myInterrupt.reply(interrupt, {note: "this is the reply data"}), + * messages: response.messages, + * resume: myInterrupt.reply(interrupt, {note: "this is the reply data"}), * }); * ``` */ @@ -133,6 +135,8 @@ export interface GenerateOptions< streamingCallback?: StreamingCallback; /** Middleware to be used with this model call. */ use?: ModelMiddleware[]; + /** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */ + context?: ActionContext; } function applyResumeOption( @@ -376,10 +380,16 @@ export async function generate< registry, stripNoop(resolvedOptions.onChunk ?? resolvedOptions.streamingCallback), async () => { - const response = await generateHelper(registry, { - rawRequest: params, - middleware: resolvedOptions.use, - }); + const generateFn = () => + generateHelper(registry, { + rawRequest: params, + middleware: resolvedOptions.use, + }); + const response = await runWithContext( + registry, + resolvedOptions.context, + generateFn + ); const request = await toGenerateRequest(registry, { ...resolvedOptions, tools, diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index fb8bddd7c..abfe82e6d 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -17,8 +17,10 @@ import { Action, ActionAsyncParams, + ActionContext, defineActionAsync, GenkitError, + getContext, JSONSchema7, stripUndefinedProps, z, @@ -117,6 +119,7 @@ export interface PromptConfig< tools?: ToolArgument[]; toolChoice?: ToolChoice; use?: ModelMiddleware[]; + context?: ActionContext; } /** @@ -179,6 +182,7 @@ export type PartsResolver = ( input: I, options: { state?: S; + context: ActionContext; } ) => Part[] | Promise; @@ -187,12 +191,14 @@ export type MessagesResolver = ( options: { history?: MessageData[]; state?: S; + context: ActionContext; } ) => MessageData[] | Promise; export type DocsResolver = ( input: I, options: { + context: ActionContext; state?: S; } ) => DocumentData[] | Promise; @@ -250,7 +256,8 @@ function definePromptAsync< input, messages, resolvedOptions, - promptCache + promptCache, + renderOptions ); await renderMessages( registry, @@ -267,13 +274,15 @@ function definePromptAsync< input, messages, resolvedOptions, - promptCache + promptCache, + renderOptions ); let docs: DocumentData[] | undefined; if (typeof resolvedOptions.docs === 'function') { docs = await resolvedOptions.docs(input, { state: session?.state, + context: renderOptions?.context || getContext(registry) || {}, }); } else { docs = resolvedOptions.docs; @@ -287,6 +296,7 @@ function definePromptAsync< tools: resolvedOptions.tools, returnToolRequests: resolvedOptions.returnToolRequests, toolChoice: resolvedOptions.toolChoice, + context: resolvedOptions.context, output: resolvedOptions.output, use: resolvedOptions.use, ...stripUndefinedProps(renderOptions), @@ -442,13 +452,17 @@ async function renderSystemPrompt< input: z.infer, messages: MessageData[], options: PromptConfig, - promptCache: PromptCache + promptCache: PromptCache, + renderOptions: PromptGenerateOptions | undefined ) { if (typeof options.system === 'function') { messages.push({ role: 'system', content: normalizeParts( - await options.system(input, { state: session?.state }) + await options.system(input, { + state: session?.state, + context: renderOptions?.context || getContext(registry) || {}, + }) ), }); } else if (typeof options.system === 'string') { @@ -458,7 +472,14 @@ async function renderSystemPrompt< } messages.push({ role: 'system', - content: await renderDotpromptToParts(promptCache.system, input, session), + content: await renderDotpromptToParts( + registry, + promptCache.system, + input, + session, + options, + renderOptions + ), }); } else if (options.system) { messages.push({ @@ -486,6 +507,7 @@ async function renderMessages< messages.push( ...(await options.messages(input, { state: session?.state, + context: renderOptions?.context || getContext(registry) || {}, history: renderOptions?.messages, })) ); @@ -498,7 +520,10 @@ async function renderMessages< } const rendered = await promptCache.messages({ input, - context: { state: session?.state }, + context: { + ...(renderOptions?.context || getContext(registry)), + state: session?.state, + }, messages: renderOptions?.messages?.map((m) => Message.parseData(m) ) as DpMessage[], @@ -528,13 +553,17 @@ async function renderUserPrompt< input: z.infer, messages: MessageData[], options: PromptConfig, - promptCache: PromptCache + promptCache: PromptCache, + renderOptions: PromptGenerateOptions | undefined ) { if (typeof options.prompt === 'function') { messages.push({ role: 'user', content: normalizeParts( - await options.prompt(input, { state: session?.state }) + await options.prompt(input, { + state: session?.state, + context: renderOptions?.context || getContext(registry) || {}, + }) ), }); } else if (typeof options.prompt === 'string') { @@ -545,9 +574,12 @@ async function renderUserPrompt< messages.push({ role: 'user', content: await renderDotpromptToParts( + registry, promptCache.userPrompt, input, - session + session, + options, + renderOptions ), }); } else if (options.prompt) { @@ -585,14 +617,24 @@ function normalizeParts(parts: string | Part | Part[]): Part[] { return [parts as Part]; } -async function renderDotpromptToParts( +async function renderDotpromptToParts< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>( + registry: Registry, promptFn: PromptFunction, input: any, - session?: Session + session: Session | undefined, + options: PromptConfig, + renderOptions: PromptGenerateOptions | undefined ): Promise { const renderred = await promptFn({ input, - context: { state: session?.state }, + context: { + ...(renderOptions?.context || getContext(registry)), + state: session?.state, + }, }); if (renderred.messages.length !== 1) { throw new Error('parts tempate must produce only one message'); diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index d0d416ecd..da9133560 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -16,6 +16,7 @@ import { Action, + ActionContext, defineAction, JSONSchema7, stripUndefinedProps, @@ -188,6 +189,8 @@ export interface ToolFnOptions { * getting interrupted (immediately) and tool request returned to the upstream caller. */ interrupt: (metadata?: Record) => never; + + context: ActionContext; } export type ToolFn = ( @@ -212,9 +215,12 @@ export function defineTool( actionType: 'tool', metadata: { ...(config.metadata || {}), type: 'tool' }, }, - (i) => + (i, { context }) => fn(i, { interrupt: interruptTool, + context: { + ...context, + }, }) ); (a as ToolAction).reply = (interrupt, replyData, options) => { diff --git a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts index eaa1d975a..f2db66e3a 100644 --- a/js/ai/tests/prompt/prompt_test.ts +++ b/js/ai/tests/prompt/prompt_test.ts @@ -14,20 +14,21 @@ * limitations under the License. */ -import { z } from '@genkit-ai/core'; +import { ActionContext, runWithContext, z } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; +import { toJsonSchema } from '../../../core/src/schema'; import { Document } from '../../lib/document'; import { GenerateOptions } from '../../lib/index'; import { Session } from '../../lib/session'; -import { ModelAction, defineModel } from '../../src/model.ts'; +import { ModelAction, defineModel } from '../../src/model'; import { PromptConfig, PromptGenerateOptions, definePrompt, -} from '../../src/prompt.ts'; -import { defineTool } from '../../src/tool.ts'; +} from '../../src/prompt'; +import { defineTool } from '../../src/tool'; describe('prompt', () => { let registry; @@ -55,6 +56,7 @@ describe('prompt', () => { wantRendered?: GenerateOptions; state?: any; only?: boolean; + context?: ActionContext; }[] = [ { name: 'renders user prompt', @@ -79,6 +81,32 @@ describe('prompt', () => { model: 'echoModel', }, }, + { + name: 'renders user prompt with context', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + prompt: 'hello {{name}} ({{@state.name}}, {{@auth.email}})', + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + inputOptions: { config: { temperature: 11 } }, + context: { auth: { email: 'a@b.c' } }, + wantTextOutput: + 'Echo: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + messages: [ + { content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'user' }, + ], + model: 'echoModel', + }, + }, { name: 'renders user prompt with explicit messages override', prompt: { @@ -229,6 +257,69 @@ describe('prompt', () => { model: 'echoModel', }, }, + { + name: 'renders user prompt from function with context', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + prompt: async (input, { state, context }) => { + return `hello ${input.name} (${state.name}, ${context.auth?.email})`; + }, + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + context: { auth: { email: 'a@b.c' } }, + inputOptions: { config: { temperature: 11 } }, + wantTextOutput: + 'Echo: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + messages: [ + { content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'user' }, + ], + model: 'echoModel', + }, + }, + { + name: 'renders user prompt from function with context as render option', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + prompt: async (input, { state, context }) => { + return `hello ${input.name} (${state.name}, ${context.auth?.email})`; + }, + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + inputOptions: { + config: { temperature: 11 }, + context: { auth: { email: 'a@b.c' } }, + }, + wantTextOutput: + 'Echo: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + context: { + auth: { + email: 'a@b.c', + }, + }, + messages: [ + { content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'user' }, + ], + model: 'echoModel', + }, + }, { name: 'renders system prompt', prompt: { @@ -252,6 +343,61 @@ describe('prompt', () => { model: 'echoModel', }, }, + { + name: 'renders system prompt with context', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + system: 'hello {{name}} ({{@state.name}}, {{@auth.email}})', + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + context: { auth: { email: 'a@b.c' } }, + inputOptions: { config: { temperature: 11 } }, + wantTextOutput: + 'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + messages: [ + { content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' }, + ], + model: 'echoModel', + }, + }, + { + name: 'renders system prompt with context as render option', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + system: 'hello {{name}} ({{@state.name}}, {{@auth.email}})', + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + inputOptions: { + config: { temperature: 11 }, + context: { auth: { email: 'a@b.c' } }, + }, + wantTextOutput: + 'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + context: { auth: { email: 'a@b.c' } }, + messages: [ + { content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' }, + ], + model: 'echoModel', + }, + }, { name: 'renders system prompt from a function', prompt: { @@ -277,6 +423,65 @@ describe('prompt', () => { model: 'echoModel', }, }, + { + name: 'renders system prompt from a function with context', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + system: async (input, { state, context }) => { + return `hello ${input.name} (${state.name}, ${context.auth?.email})`; + }, + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + context: { auth: { email: 'a@b.c' } }, + inputOptions: { config: { temperature: 11 } }, + wantTextOutput: + 'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + messages: [ + { content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' }, + ], + model: 'echoModel', + }, + }, + { + name: 'renders system prompt from a function with context as render option', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + system: async (input, { state, context }) => { + return `hello ${input.name} (${state.name}, ${context.auth?.email})`; + }, + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + inputOptions: { + config: { temperature: 11 }, + context: { auth: { email: 'a@b.c' } }, + }, + wantTextOutput: + 'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + context: { auth: { email: 'a@b.c' } }, + messages: [ + { content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' }, + ], + model: 'echoModel', + }, + }, { name: 'renders messages from template', prompt: { @@ -304,6 +509,34 @@ describe('prompt', () => { model: 'echoModel', }, }, + { + name: 'renders messages from template with context', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + messages: + '{{role "system"}}system {{name}}{{role "user"}}user {{@state.name}}, {{@auth.email}}', + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + context: { auth: { email: 'a@b.c' } }, + inputOptions: { config: { temperature: 11 } }, + wantTextOutput: + 'Echo: system: system foo,user bar, a@b.c; config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + messages: [ + { role: 'system', content: [{ text: 'system foo' }] }, + { role: 'user', content: [{ text: 'user bar, a@b.c' }] }, + ], + model: 'echoModel', + }, + }, { name: 'renders messages', prompt: { @@ -362,6 +595,39 @@ describe('prompt', () => { model: 'echoModel', }, }, + { + name: 'renders messages from function with context', + prompt: { + model: 'echoModel', + name: 'prompt1', + config: { banana: 'ripe' }, + input: { schema: z.object({ name: z.string() }) }, + messages: async (input, { state, context }) => [ + { role: 'system', content: [{ text: `system ${input.name}` }] }, + { + role: 'user', + content: [{ text: `user ${state.name}, ${context.auth?.email}` }], + }, + ], + }, + input: { name: 'foo' }, + state: { name: 'bar' }, + context: { auth: { email: 'a@b.c' } }, + inputOptions: { config: { temperature: 11 } }, + wantTextOutput: + 'Echo: system: system foo,user bar, a@b.c; config: {"banana":"ripe","temperature":11}', + wantRendered: { + config: { + banana: 'ripe', + temperature: 11, + }, + messages: [ + { role: 'system', content: [{ text: 'system foo' }] }, + { role: 'user', content: [{ text: 'user bar, a@b.c' }] }, + ], + model: 'echoModel', + }, + }, { name: 'renders system, message and prompt in the same order', prompt: { @@ -504,23 +770,30 @@ describe('prompt', () => { } const p = definePrompt(registry, test.prompt); - const { text } = await (session - ? session.run(() => p(test.input, test.inputOptions)) - : p(test.input, test.inputOptions)); + const sessionFn = () => + session + ? session.run(() => p(test.input, test.inputOptions)) + : p(test.input, test.inputOptions); + + const { text } = await runWithContext(registry, test.context, sessionFn); assert.strictEqual(text, test.wantTextOutput); + + const sessionRenderFn = () => + session + ? session.run(() => p.render(test.input, test.inputOptions)) + : p.render(test.input, test.inputOptions); + assert.deepStrictEqual( stripUndefined( - await (session - ? session.run(() => p.render(test.input, test.inputOptions)) - : p.render(test.input, test.inputOptions)) + await runWithContext(registry, test.context, sessionRenderFn) ), test.wantRendered ); }); } - it.skip('respects output schema in the definition', async () => { + it('respects output schema in the definition', async () => { const schema1 = z.object({ puppyName: z.string({ description: 'A cute name for a puppy' }), }); @@ -543,7 +816,10 @@ describe('prompt', () => { const generateRequest = await prompt1.render('poodle', { model: 'geminiPro', }); - assert.equal(generateRequest.output?.schema, schema1); + assert.deepStrictEqual( + toJsonSchema({ schema: generateRequest.output?.schema }), + toJsonSchema({ schema: schema1 }) + ); }); }); diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 116ef9ebf..30e1ac7dd 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -16,7 +16,7 @@ import { JSONSchema7 } from 'json-schema'; import * as z from 'zod'; -import { getContext } from './context.js'; +import { ActionContext, getContext, runWithContext } from './context.js'; import { ActionType, Registry } from './registry.js'; import { parseSchema } from './schema.js'; import { @@ -70,7 +70,7 @@ export interface ActionRunOptions { /** * Additional runtime context data (ex. auth context data). */ - context?: any; + context?: ActionContext; /** * Additional span attributes to apply to OT spans. @@ -90,7 +90,7 @@ export interface ActionFnArg { /** * Additional runtime context data (ex. auth context data). */ - context?: any; + context?: ActionContext; } /** @@ -310,11 +310,19 @@ export function action< metadata.input = input; try { - const output = await fn(input, { - // Context can either be explicitly set, or inherited from the parent action. - context: options?.context ?? getContext(registry), - sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, - }); + const actionFn = () => + fn(input, { + // Context can either be explicitly set, or inherited from the parent action. + context: options?.context ?? getContext(registry), + sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, + }); + // if context is explicitly passed in, we run action with the provided context, + // otherwise we let upstream context carry through. + const output = await runWithContext( + registry, + options?.context, + actionFn + ); metadata.output = JSON.stringify(output); return output; diff --git a/js/core/src/context.ts b/js/core/src/context.ts index 50efb988e..bc68b3f2e 100644 --- a/js/core/src/context.ts +++ b/js/core/src/context.ts @@ -14,48 +14,44 @@ * limitations under the License. */ -import { AsyncLocalStorage } from 'node:async_hooks'; import { runInActionRuntimeContext } from './action.js'; import { HasRegistry, Registry } from './registry.js'; const contextAlsKey = 'core.auth.context'; -const legacyContextAsyncLocalStorage = new AsyncLocalStorage(); + +export interface ActionContext { + /** Information about the currently authenticated user if provided. */ + auth?: Record; + [additionalContext: string]: any; +} /** * Execute the provided function in the runtime context. Call {@link getFlowContext()} anywhere - * within the async call stack to retrieve the context. + * within the async call stack to retrieve the context. If context object is undefined, this function + * is a no op passthrough, the function will be invoked as is. */ export function runWithContext( registry: Registry, - context: any, + context: ActionContext | undefined, fn: () => R -) { - return legacyContextAsyncLocalStorage.run(context, () => - registry.asyncStore.run(contextAlsKey, context, () => - runInActionRuntimeContext(registry, fn) - ) - ); -} - -/** - * Gets the auth object from the current context. - * - * @deprecated use {@link getFlowContext} - */ -export function getFlowAuth(registry?: Registry | HasRegistry): any { - if (!registry) { - return legacyContextAsyncLocalStorage.getStore(); +): R { + if (context === undefined) { + return fn(); } - return getContext(registry); + return registry.asyncStore.run(contextAlsKey, context, () => + runInActionRuntimeContext(registry, fn) + ); } /** * Gets the runtime context of the current flow. */ -export function getContext(registry: Registry | HasRegistry): any { +export function getContext( + registry: Registry | HasRegistry +): ActionContext | undefined { if ((registry as HasRegistry).registry) { registry = (registry as HasRegistry).registry; } registry = registry as Registry; - return registry.asyncStore.getStore(contextAlsKey); + return registry.asyncStore.getStore(contextAlsKey); } diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 7c416f6ec..86dab3b9a 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -17,7 +17,7 @@ import { AsyncLocalStorage } from 'node:async_hooks'; import { z } from 'zod'; import { Action, defineAction, StreamingCallback } from './action.js'; -import { runWithContext } from './context.js'; +import { ActionContext } from './context.js'; import { HasRegistry, Registry } from './registry.js'; import { runInNewSpan, SPAN_TYPE_ATTR } from './tracing.js'; @@ -64,7 +64,7 @@ export interface FlowSideChannel { /** * Additional runtime context data (ex. auth context data). */ - context?: any; + context?: ActionContext; } /** @@ -125,9 +125,7 @@ function defineFlowAction< const ctx = sendChunk; (ctx as FlowSideChannel>).sendChunk = sendChunk; (ctx as FlowSideChannel>).context = context; - return runWithContext(registry, context, () => - fn(input, ctx as FlowSideChannel>) - ); + return fn(input, ctx as FlowSideChannel>); }); } ); diff --git a/js/core/src/index.ts b/js/core/src/index.ts index b5b9000cc..1039a0cc5 100644 --- a/js/core/src/index.ts +++ b/js/core/src/index.ts @@ -29,7 +29,7 @@ export const GENKIT_REFLECTION_API_SPEC_VERSION = 1; export { z } from 'zod'; export * from './action.js'; -export { getFlowAuth } from './context.js'; +export { getContext, runWithContext, type ActionContext } from './context.js'; export { GenkitError } from './error.js'; export { defineFlow, diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index 0b4391b15..76f5c1d84 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -144,4 +144,27 @@ describe('action', () => { { count: 3 }, ]); }); + + it('should inherit context from parent action invocation', async () => { + const child = defineAction( + registry, + { name: 'child', actionType: 'custom' }, + async (_, { context }) => { + return `hi ${context.auth.email}`; + } + ); + const parent = defineAction( + registry, + { name: 'parent', actionType: 'custom' }, + async () => { + return child(); + } + ); + + const response = await parent(undefined, { + context: { auth: { email: 'a@b.c' } }, + }); + + assert.strictEqual(response, 'hi a@b.c'); + }); }); diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index aaeb64ef3..3decd65d7 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -18,7 +18,7 @@ import { SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; import { defineFlow, run } from '../src/flow.js'; -import { defineAction, getFlowAuth, z } from '../src/index.js'; +import { defineAction, getContext, z } from '../src/index.js'; import { Registry } from '../src/registry.js'; import { enableTelemetry } from '../src/tracing.js'; import { TestSpanExporter } from './utils.js'; @@ -114,7 +114,13 @@ describe('flow', () => { }); }); - describe('getFlowAuth', () => { + describe('getContext', () => { + let registry: Registry; + + beforeEach(() => { + registry = new Registry(); + }); + it('should run the flow', async () => { const testFlow = defineFlow( registry, @@ -124,7 +130,7 @@ describe('flow', () => { outputSchema: z.string(), }, async (input) => { - return `bar ${input} ${JSON.stringify(getFlowAuth())}`; + return `bar ${input} ${JSON.stringify(getContext(registry))}`; } ); @@ -150,7 +156,7 @@ describe('flow', () => { streamingCallback({ count: i }); } } - return `bar ${input} ${!!streamingCallback} ${JSON.stringify(getFlowAuth())}`; + return `bar ${input} ${!!streamingCallback} ${JSON.stringify(getContext(registry))}`; } ); @@ -178,7 +184,7 @@ describe('flow', () => { outputSchema: z.string(), }, async (input) => { - return `bar ${input} ${JSON.stringify(getFlowAuth())}`; + return `bar ${input} ${JSON.stringify(getContext(registry))}`; } ); diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 0e2435a27..48afd4bec 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -100,6 +100,7 @@ import { import { ToolFn } from '@genkit-ai/ai/tool'; import { Action, + ActionContext, FlowConfig, FlowFn, JSONSchema, @@ -108,6 +109,7 @@ import { defineFlow, defineJsonSchema, defineSchema, + getContext, isDevEnv, run, z, @@ -785,6 +787,15 @@ export class Genkit implements HasRegistry { return run(name, funcOrInput, this.registry); } + /** + * Returns current action (or flow) invocation context. Can be used to access things like auth + * data set by HTTP server frameworks. If invoked outside of an action (e.g. flow or tool) will + * return `undefined`. + */ + currentContext(): ActionContext | undefined { + return getContext(this); + } + /** * Configures the Genkit instance. */ diff --git a/js/genkit/src/index.ts b/js/genkit/src/index.ts index 7dbb1f9cd..204b60a59 100644 --- a/js/genkit/src/index.ts +++ b/js/genkit/src/index.ts @@ -121,12 +121,12 @@ export { defineJsonSchema, defineSchema, getCurrentEnv, - getFlowAuth, getStreamingCallback, isDevEnv, runWithStreamingCallback, z, type Action, + type ActionContext, type ActionMetadata, type Flow, type FlowConfig, diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 45e3abe55..bc44ca99b 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -435,6 +435,65 @@ describe('generate', () => { }); }); + it('should propagate context to the tool', async () => { + const schema = z.object({ + foo: z.string(), + }); + + ai.defineTool( + { + name: 'testTool', + description: 'description', + inputSchema: schema, + outputSchema: schema, + }, + async (_, { context }) => { + return { + foo: `bar ${context.auth.email}`, + }; + } + ); + + // first response be tools call, the subsequent just text response from agent b. + let reqCounter = 0; + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [ + reqCounter++ === 0 + ? { + toolRequest: { + name: 'testTool', + input: { foo: 'fromTool' }, + ref: 'ref123', + }, + } + : { + text: req.messages + .splice(-1) + .map((m) => + m.content + .map( + (c) => + c.text || JSON.stringify(c.toolResponse?.output) + ) + .join() + ) + .join(), + }, + ], + }, + }; + }; + const { text } = await ai.generate({ + prompt: 'call the tool', + tools: ['testTool'], + context: { auth: { email: 'a@b.c' } }, + }); + assert.strictEqual(text, '{"foo":"bar a@b.c"}'); + }); + it('streams the tool responses', async () => { ai.defineTool( { name: 'testTool', description: 'description' }, diff --git a/js/plugins/express/src/index.ts b/js/plugins/express/src/index.ts index cdc6ada22..d30e26b04 100644 --- a/js/plugins/express/src/index.ts +++ b/js/plugins/express/src/index.ts @@ -34,7 +34,7 @@ export interface AuthPolicyContext< > { action?: Action; input: z.infer; - auth: any | undefined; + auth?: Record; request: RequestWithAuth; } @@ -56,7 +56,7 @@ export interface AuthPolicy< * the flow context. */ export interface RequestWithAuth extends express.Request { - auth?: unknown; + auth?: Record; } /** @@ -116,7 +116,7 @@ export function expressHandler< () => action.run(input, { onChunk, - context: auth, + context: { auth }, }) ); response.write( @@ -140,7 +140,7 @@ export function expressHandler< } } else { try { - const result = await action.run(input, { context: auth }); + const result = await action.run(input, { context: { auth } }); response.setHeader('x-genkit-trace-id', result.telemetry.traceId); response.setHeader('x-genkit-span-id', result.telemetry.spanId); // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." diff --git a/js/plugins/express/tests/express_test.ts b/js/plugins/express/tests/express_test.ts index ad6a9a958..d4c0917fb 100644 --- a/js/plugins/express/tests/express_test.ts +++ b/js/plugins/express/tests/express_test.ts @@ -82,7 +82,7 @@ describe('expressHandler', async () => { inputSchema: z.object({ question: z.string() }), }, async (input, { context }) => { - return `${input.question} - ${JSON.stringify(context)}`; + return `${input.question} - ${JSON.stringify(context.auth)}`; } ); @@ -382,7 +382,7 @@ describe('startFlowServer', async () => { inputSchema: z.object({ question: z.string() }), }, async (input, { context }) => { - return `${input.question} - ${JSON.stringify(context)}`; + return `${input.question} - ${JSON.stringify(context.auth)}`; } ); diff --git a/js/plugins/firebase/tests/functions_test.ts b/js/plugins/firebase/tests/functions_test.ts index 365daec3a..7e50635fa 100644 --- a/js/plugins/firebase/tests/functions_test.ts +++ b/js/plugins/firebase/tests/functions_test.ts @@ -81,7 +81,7 @@ describe('function', () => { authPolicy: authPolicy, }, async (input, { context }) => { - return `hi ${input} - ${JSON.stringify(context)}`; + return `hi ${input} - ${JSON.stringify(context?.auth)}`; } ); @@ -96,7 +96,7 @@ describe('function', () => { sendChunk({ chubk: 2 }); sendChunk({ chubk: 3 }); - return `hi ${input} - ${JSON.stringify(context)}`; + return `hi ${input} - ${JSON.stringify(context?.auth)}`; } ); const app = express(); diff --git a/js/testapps/express/src/index.ts b/js/testapps/express/src/index.ts index 16a3757ed..8e6cc0019 100644 --- a/js/testapps/express/src/index.ts +++ b/js/testapps/express/src/index.ts @@ -79,7 +79,7 @@ app.use(express.json()); const authPolicies: Record = { jokeFlow: ({ auth }) => { - if (auth.username != 'Ali Baba') { + if (auth?.username != 'Ali Baba') { throw new Error('unauthorized: ' + JSON.stringify(auth)); } },