Skip to content

Commit

Permalink
feat(js): propagate context to sub actions, expose context in prompts (
Browse files Browse the repository at this point in the history
pavelgj authored Jan 27, 2025
1 parent 7c65ece commit 21ffd9b
Showing 18 changed files with 521 additions and 86 deletions.
6 changes: 3 additions & 3 deletions docs/auth.md
Original file line number Diff line number Diff line change
@@ -73,15 +73,15 @@ When running with the Genkit Development UI, you can pass the Auth object by
entering JSON in the "Auth JSON" tab: `{"uid": "abc-def"}`.

You can also retrieve the auth context for the flow at any time within the flow
by calling `getFlowAuth()`, including in functions invoked by the flow:
by calling `ai.currentContext()`, including in functions invoked by the flow:

```ts
import { genkit, z } from 'genkit';

const ai = genkit({ ... });;

async function readDatabase(uid: string) {
const auth = ai.getAuthContext();
const auth = ai.currentContext()?.auth;
if (auth?.admin) {
// Do something special if the user is an admin
} else {
@@ -153,7 +153,7 @@ export const selfSummaryFlow = onFlow(

When using the Firebase Auth plugin, `user` will be returned as a
[DecodedIdToken](https://firebase.google.com/docs/reference/admin/node/firebase-admin.auth.decodedidtoken).
You can always retrieve this object at any time via `getFlowAuth()` as noted
You can always retrieve this object at any time via `ai.currentContext()` as noted
above. When running this flow during development, you would pass the user object
in the same way:

22 changes: 16 additions & 6 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
@@ -16,8 +16,10 @@

import {
Action,
ActionContext,
GenkitError,
StreamingCallback,
runWithContext,
runWithStreamingCallback,
sentinelNoopStreamingCallback,
z,
@@ -113,8 +115,8 @@ export interface GenerateOptions<
* const interrupt = response.interrupts[0];
*
* const resumedResponse = await ai.generate({
* messages: response.messages,
* resume: myInterrupt.reply(interrupt, {note: "this is the reply data"}),
* messages: response.messages,
* resume: myInterrupt.reply(interrupt, {note: "this is the reply data"}),
* });
* ```
*/
@@ -133,6 +135,8 @@ export interface GenerateOptions<
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
/** Middleware to be used with this model call. */
use?: ModelMiddleware[];
/** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */
context?: ActionContext;
}

function applyResumeOption(
@@ -376,10 +380,16 @@ export async function generate<
registry,
stripNoop(resolvedOptions.onChunk ?? resolvedOptions.streamingCallback),
async () => {
const response = await generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
});
const generateFn = () =>
generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
});
const response = await runWithContext(
registry,
resolvedOptions.context,
generateFn
);
const request = await toGenerateRequest(registry, {
...resolvedOptions,
tools,
66 changes: 54 additions & 12 deletions js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
@@ -17,8 +17,10 @@
import {
Action,
ActionAsyncParams,
ActionContext,
defineActionAsync,
GenkitError,
getContext,
JSONSchema7,
stripUndefinedProps,
z,
@@ -117,6 +119,7 @@ export interface PromptConfig<
tools?: ToolArgument[];
toolChoice?: ToolChoice;
use?: ModelMiddleware[];
context?: ActionContext;
}

/**
@@ -179,6 +182,7 @@ export type PartsResolver<I, S = any> = (
input: I,
options: {
state?: S;
context: ActionContext;
}
) => Part[] | Promise<string | Part | Part[]>;

@@ -187,12 +191,14 @@ export type MessagesResolver<I, S = any> = (
options: {
history?: MessageData[];
state?: S;
context: ActionContext;
}
) => MessageData[] | Promise<MessageData[]>;

export type DocsResolver<I, S = any> = (
input: I,
options: {
context: ActionContext;
state?: S;
}
) => DocumentData[] | Promise<DocumentData[]>;
@@ -250,7 +256,8 @@ function definePromptAsync<
input,
messages,
resolvedOptions,
promptCache
promptCache,
renderOptions
);
await renderMessages(
registry,
@@ -267,13 +274,15 @@ function definePromptAsync<
input,
messages,
resolvedOptions,
promptCache
promptCache,
renderOptions
);

let docs: DocumentData[] | undefined;
if (typeof resolvedOptions.docs === 'function') {
docs = await resolvedOptions.docs(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
});
} else {
docs = resolvedOptions.docs;
@@ -287,6 +296,7 @@ function definePromptAsync<
tools: resolvedOptions.tools,
returnToolRequests: resolvedOptions.returnToolRequests,
toolChoice: resolvedOptions.toolChoice,
context: resolvedOptions.context,
output: resolvedOptions.output,
use: resolvedOptions.use,
...stripUndefinedProps(renderOptions),
@@ -442,13 +452,17 @@ async function renderSystemPrompt<
input: z.infer<I>,
messages: MessageData[],
options: PromptConfig<I, O, CustomOptions>,
promptCache: PromptCache
promptCache: PromptCache,
renderOptions: PromptGenerateOptions<O, CustomOptions> | undefined
) {
if (typeof options.system === 'function') {
messages.push({
role: 'system',
content: normalizeParts(
await options.system(input, { state: session?.state })
await options.system(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
})
),
});
} else if (typeof options.system === 'string') {
@@ -458,7 +472,14 @@ async function renderSystemPrompt<
}
messages.push({
role: 'system',
content: await renderDotpromptToParts(promptCache.system, input, session),
content: await renderDotpromptToParts(
registry,
promptCache.system,
input,
session,
options,
renderOptions
),
});
} else if (options.system) {
messages.push({
@@ -486,6 +507,7 @@ async function renderMessages<
messages.push(
...(await options.messages(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
history: renderOptions?.messages,
}))
);
@@ -498,7 +520,10 @@ async function renderMessages<
}
const rendered = await promptCache.messages({
input,
context: { state: session?.state },
context: {
...(renderOptions?.context || getContext(registry)),
state: session?.state,
},
messages: renderOptions?.messages?.map((m) =>
Message.parseData(m)
) as DpMessage[],
@@ -528,13 +553,17 @@ async function renderUserPrompt<
input: z.infer<I>,
messages: MessageData[],
options: PromptConfig<I, O, CustomOptions>,
promptCache: PromptCache
promptCache: PromptCache,
renderOptions: PromptGenerateOptions<O, CustomOptions> | undefined
) {
if (typeof options.prompt === 'function') {
messages.push({
role: 'user',
content: normalizeParts(
await options.prompt(input, { state: session?.state })
await options.prompt(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
})
),
});
} else if (typeof options.prompt === 'string') {
@@ -545,9 +574,12 @@ async function renderUserPrompt<
messages.push({
role: 'user',
content: await renderDotpromptToParts(
registry,
promptCache.userPrompt,
input,
session
session,
options,
renderOptions
),
});
} else if (options.prompt) {
@@ -585,14 +617,24 @@ function normalizeParts(parts: string | Part | Part[]): Part[] {
return [parts as Part];
}

async function renderDotpromptToParts(
async function renderDotpromptToParts<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
promptFn: PromptFunction,
input: any,
session?: Session
session: Session | undefined,
options: PromptConfig<I, O, CustomOptions>,
renderOptions: PromptGenerateOptions<O, CustomOptions> | undefined
): Promise<Part[]> {
const renderred = await promptFn({
input,
context: { state: session?.state },
context: {
...(renderOptions?.context || getContext(registry)),
state: session?.state,
},
});
if (renderred.messages.length !== 1) {
throw new Error('parts tempate must produce only one message');
8 changes: 7 additions & 1 deletion js/ai/src/tool.ts
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@

import {
Action,
ActionContext,
defineAction,
JSONSchema7,
stripUndefinedProps,
@@ -188,6 +189,8 @@ export interface ToolFnOptions {
* getting interrupted (immediately) and tool request returned to the upstream caller.
*/
interrupt: (metadata?: Record<string, any>) => never;

context: ActionContext;
}

export type ToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
@@ -212,9 +215,12 @@ export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
actionType: 'tool',
metadata: { ...(config.metadata || {}), type: 'tool' },
},
(i) =>
(i, { context }) =>
fn(i, {
interrupt: interruptTool,
context: {
...context,
},
})
);
(a as ToolAction<I, O>).reply = (interrupt, replyData, options) => {
300 changes: 288 additions & 12 deletions js/ai/tests/prompt/prompt_test.ts
Original file line number Diff line number Diff line change
@@ -14,20 +14,21 @@
* limitations under the License.
*/

import { z } from '@genkit-ai/core';
import { ActionContext, runWithContext, z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { toJsonSchema } from '../../../core/src/schema';
import { Document } from '../../lib/document';
import { GenerateOptions } from '../../lib/index';
import { Session } from '../../lib/session';
import { ModelAction, defineModel } from '../../src/model.ts';
import { ModelAction, defineModel } from '../../src/model';
import {
PromptConfig,
PromptGenerateOptions,
definePrompt,
} from '../../src/prompt.ts';
import { defineTool } from '../../src/tool.ts';
} from '../../src/prompt';
import { defineTool } from '../../src/tool';

describe('prompt', () => {
let registry;
@@ -55,6 +56,7 @@ describe('prompt', () => {
wantRendered?: GenerateOptions;
state?: any;
only?: boolean;
context?: ActionContext;
}[] = [
{
name: 'renders user prompt',
@@ -79,6 +81,32 @@ describe('prompt', () => {
model: 'echoModel',
},
},
{
name: 'renders user prompt with context',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
prompt: 'hello {{name}} ({{@state.name}}, {{@auth.email}})',
},
input: { name: 'foo' },
state: { name: 'bar' },
inputOptions: { config: { temperature: 11 } },
context: { auth: { email: 'a@b.c' } },
wantTextOutput:
'Echo: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
messages: [
{ content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'user' },
],
model: 'echoModel',
},
},
{
name: 'renders user prompt with explicit messages override',
prompt: {
@@ -229,6 +257,69 @@ describe('prompt', () => {
model: 'echoModel',
},
},
{
name: 'renders user prompt from function with context',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
prompt: async (input, { state, context }) => {
return `hello ${input.name} (${state.name}, ${context.auth?.email})`;
},
},
input: { name: 'foo' },
state: { name: 'bar' },
context: { auth: { email: 'a@b.c' } },
inputOptions: { config: { temperature: 11 } },
wantTextOutput:
'Echo: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
messages: [
{ content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'user' },
],
model: 'echoModel',
},
},
{
name: 'renders user prompt from function with context as render option',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
prompt: async (input, { state, context }) => {
return `hello ${input.name} (${state.name}, ${context.auth?.email})`;
},
},
input: { name: 'foo' },
state: { name: 'bar' },
inputOptions: {
config: { temperature: 11 },
context: { auth: { email: 'a@b.c' } },
},
wantTextOutput:
'Echo: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
context: {
auth: {
email: 'a@b.c',
},
},
messages: [
{ content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'user' },
],
model: 'echoModel',
},
},
{
name: 'renders system prompt',
prompt: {
@@ -252,6 +343,61 @@ describe('prompt', () => {
model: 'echoModel',
},
},
{
name: 'renders system prompt with context',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
system: 'hello {{name}} ({{@state.name}}, {{@auth.email}})',
},
input: { name: 'foo' },
state: { name: 'bar' },
context: { auth: { email: 'a@b.c' } },
inputOptions: { config: { temperature: 11 } },
wantTextOutput:
'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
messages: [
{ content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' },
],
model: 'echoModel',
},
},
{
name: 'renders system prompt with context as render option',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
system: 'hello {{name}} ({{@state.name}}, {{@auth.email}})',
},
input: { name: 'foo' },
state: { name: 'bar' },
inputOptions: {
config: { temperature: 11 },
context: { auth: { email: 'a@b.c' } },
},
wantTextOutput:
'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
context: { auth: { email: 'a@b.c' } },
messages: [
{ content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' },
],
model: 'echoModel',
},
},
{
name: 'renders system prompt from a function',
prompt: {
@@ -277,6 +423,65 @@ describe('prompt', () => {
model: 'echoModel',
},
},
{
name: 'renders system prompt from a function with context',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
system: async (input, { state, context }) => {
return `hello ${input.name} (${state.name}, ${context.auth?.email})`;
},
},
input: { name: 'foo' },
state: { name: 'bar' },
context: { auth: { email: 'a@b.c' } },
inputOptions: { config: { temperature: 11 } },
wantTextOutput:
'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
messages: [
{ content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' },
],
model: 'echoModel',
},
},
{
name: 'renders system prompt from a function with context as render option',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
system: async (input, { state, context }) => {
return `hello ${input.name} (${state.name}, ${context.auth?.email})`;
},
},
input: { name: 'foo' },
state: { name: 'bar' },
inputOptions: {
config: { temperature: 11 },
context: { auth: { email: 'a@b.c' } },
},
wantTextOutput:
'Echo: system: hello foo (bar, a@b.c); config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
context: { auth: { email: 'a@b.c' } },
messages: [
{ content: [{ text: 'hello foo (bar, a@b.c)' }], role: 'system' },
],
model: 'echoModel',
},
},
{
name: 'renders messages from template',
prompt: {
@@ -304,6 +509,34 @@ describe('prompt', () => {
model: 'echoModel',
},
},
{
name: 'renders messages from template with context',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
messages:
'{{role "system"}}system {{name}}{{role "user"}}user {{@state.name}}, {{@auth.email}}',
},
input: { name: 'foo' },
state: { name: 'bar' },
context: { auth: { email: 'a@b.c' } },
inputOptions: { config: { temperature: 11 } },
wantTextOutput:
'Echo: system: system foo,user bar, a@b.c; config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
messages: [
{ role: 'system', content: [{ text: 'system foo' }] },
{ role: 'user', content: [{ text: 'user bar, a@b.c' }] },
],
model: 'echoModel',
},
},
{
name: 'renders messages',
prompt: {
@@ -362,6 +595,39 @@ describe('prompt', () => {
model: 'echoModel',
},
},
{
name: 'renders messages from function with context',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
messages: async (input, { state, context }) => [
{ role: 'system', content: [{ text: `system ${input.name}` }] },
{
role: 'user',
content: [{ text: `user ${state.name}, ${context.auth?.email}` }],
},
],
},
input: { name: 'foo' },
state: { name: 'bar' },
context: { auth: { email: 'a@b.c' } },
inputOptions: { config: { temperature: 11 } },
wantTextOutput:
'Echo: system: system foo,user bar, a@b.c; config: {"banana":"ripe","temperature":11}',
wantRendered: {
config: {
banana: 'ripe',
temperature: 11,
},
messages: [
{ role: 'system', content: [{ text: 'system foo' }] },
{ role: 'user', content: [{ text: 'user bar, a@b.c' }] },
],
model: 'echoModel',
},
},
{
name: 'renders system, message and prompt in the same order',
prompt: {
@@ -504,23 +770,30 @@ describe('prompt', () => {
}
const p = definePrompt(registry, test.prompt);

const { text } = await (session
? session.run(() => p(test.input, test.inputOptions))
: p(test.input, test.inputOptions));
const sessionFn = () =>
session
? session.run(() => p(test.input, test.inputOptions))
: p(test.input, test.inputOptions);

const { text } = await runWithContext(registry, test.context, sessionFn);

assert.strictEqual(text, test.wantTextOutput);

const sessionRenderFn = () =>
session
? session.run(() => p.render(test.input, test.inputOptions))
: p.render(test.input, test.inputOptions);

assert.deepStrictEqual(
stripUndefined(
await (session
? session.run(() => p.render(test.input, test.inputOptions))
: p.render(test.input, test.inputOptions))
await runWithContext(registry, test.context, sessionRenderFn)
),
test.wantRendered
);
});
}

it.skip('respects output schema in the definition', async () => {
it('respects output schema in the definition', async () => {
const schema1 = z.object({
puppyName: z.string({ description: 'A cute name for a puppy' }),
});
@@ -543,7 +816,10 @@ describe('prompt', () => {
const generateRequest = await prompt1.render('poodle', {
model: 'geminiPro',
});
assert.equal(generateRequest.output?.schema, schema1);
assert.deepStrictEqual(
toJsonSchema({ schema: generateRequest.output?.schema }),
toJsonSchema({ schema: schema1 })
);
});
});

24 changes: 16 additions & 8 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@

import { JSONSchema7 } from 'json-schema';
import * as z from 'zod';
import { getContext } from './context.js';
import { ActionContext, getContext, runWithContext } from './context.js';
import { ActionType, Registry } from './registry.js';
import { parseSchema } from './schema.js';
import {
@@ -70,7 +70,7 @@ export interface ActionRunOptions<S> {
/**
* Additional runtime context data (ex. auth context data).
*/
context?: any;
context?: ActionContext;

/**
* Additional span attributes to apply to OT spans.
@@ -90,7 +90,7 @@ export interface ActionFnArg<S> {
/**
* Additional runtime context data (ex. auth context data).
*/
context?: any;
context?: ActionContext;
}

/**
@@ -310,11 +310,19 @@ export function action<
metadata.input = input;

try {
const output = await fn(input, {
// Context can either be explicitly set, or inherited from the parent action.
context: options?.context ?? getContext(registry),
sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback,
});
const actionFn = () =>
fn(input, {
// Context can either be explicitly set, or inherited from the parent action.
context: options?.context ?? getContext(registry),
sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback,
});
// if context is explicitly passed in, we run action with the provided context,
// otherwise we let upstream context carry through.
const output = await runWithContext(
registry,
options?.context,
actionFn
);

metadata.output = JSON.stringify(output);
return output;
42 changes: 19 additions & 23 deletions js/core/src/context.ts
Original file line number Diff line number Diff line change
@@ -14,48 +14,44 @@
* limitations under the License.
*/

import { AsyncLocalStorage } from 'node:async_hooks';
import { runInActionRuntimeContext } from './action.js';
import { HasRegistry, Registry } from './registry.js';

const contextAlsKey = 'core.auth.context';
const legacyContextAsyncLocalStorage = new AsyncLocalStorage<any>();

export interface ActionContext {
/** Information about the currently authenticated user if provided. */
auth?: Record<string, any>;
[additionalContext: string]: any;
}

/**
* Execute the provided function in the runtime context. Call {@link getFlowContext()} anywhere
* within the async call stack to retrieve the context.
* within the async call stack to retrieve the context. If context object is undefined, this function
* is a no op passthrough, the function will be invoked as is.
*/
export function runWithContext<R>(
registry: Registry,
context: any,
context: ActionContext | undefined,
fn: () => R
) {
return legacyContextAsyncLocalStorage.run(context, () =>
registry.asyncStore.run(contextAlsKey, context, () =>
runInActionRuntimeContext(registry, fn)
)
);
}

/**
* Gets the auth object from the current context.
*
* @deprecated use {@link getFlowContext}
*/
export function getFlowAuth(registry?: Registry | HasRegistry): any {
if (!registry) {
return legacyContextAsyncLocalStorage.getStore();
): R {
if (context === undefined) {
return fn();
}
return getContext(registry);
return registry.asyncStore.run(contextAlsKey, context, () =>
runInActionRuntimeContext(registry, fn)
);
}

/**
* Gets the runtime context of the current flow.
*/
export function getContext(registry: Registry | HasRegistry): any {
export function getContext(
registry: Registry | HasRegistry
): ActionContext | undefined {
if ((registry as HasRegistry).registry) {
registry = (registry as HasRegistry).registry;
}
registry = registry as Registry;
return registry.asyncStore.getStore(contextAlsKey);
return registry.asyncStore.getStore<ActionContext>(contextAlsKey);
}
8 changes: 3 additions & 5 deletions js/core/src/flow.ts
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
import { AsyncLocalStorage } from 'node:async_hooks';
import { z } from 'zod';
import { Action, defineAction, StreamingCallback } from './action.js';
import { runWithContext } from './context.js';
import { ActionContext } from './context.js';
import { HasRegistry, Registry } from './registry.js';
import { runInNewSpan, SPAN_TYPE_ATTR } from './tracing.js';

@@ -64,7 +64,7 @@ export interface FlowSideChannel<S> {
/**
* Additional runtime context data (ex. auth context data).
*/
context?: any;
context?: ActionContext;
}

/**
@@ -125,9 +125,7 @@ function defineFlowAction<
const ctx = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).sendChunk = sendChunk;
(ctx as FlowSideChannel<z.infer<S>>).context = context;
return runWithContext(registry, context, () =>
fn(input, ctx as FlowSideChannel<z.infer<S>>)
);
return fn(input, ctx as FlowSideChannel<z.infer<S>>);
});
}
);
2 changes: 1 addition & 1 deletion js/core/src/index.ts
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ export const GENKIT_REFLECTION_API_SPEC_VERSION = 1;

export { z } from 'zod';
export * from './action.js';
export { getFlowAuth } from './context.js';
export { getContext, runWithContext, type ActionContext } from './context.js';
export { GenkitError } from './error.js';
export {
defineFlow,
23 changes: 23 additions & 0 deletions js/core/tests/action_test.ts
Original file line number Diff line number Diff line change
@@ -144,4 +144,27 @@ describe('action', () => {
{ count: 3 },
]);
});

it('should inherit context from parent action invocation', async () => {
const child = defineAction(
registry,
{ name: 'child', actionType: 'custom' },
async (_, { context }) => {
return `hi ${context.auth.email}`;
}
);
const parent = defineAction(
registry,
{ name: 'parent', actionType: 'custom' },
async () => {
return child();
}
);

const response = await parent(undefined, {
context: { auth: { email: 'a@b.c' } },
});

assert.strictEqual(response, 'hi a@b.c');
});
});
16 changes: 11 additions & 5 deletions js/core/tests/flow_test.ts
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ import { SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base';
import * as assert from 'assert';
import { beforeEach, describe, it } from 'node:test';
import { defineFlow, run } from '../src/flow.js';
import { defineAction, getFlowAuth, z } from '../src/index.js';
import { defineAction, getContext, z } from '../src/index.js';
import { Registry } from '../src/registry.js';
import { enableTelemetry } from '../src/tracing.js';
import { TestSpanExporter } from './utils.js';
@@ -114,7 +114,13 @@ describe('flow', () => {
});
});

describe('getFlowAuth', () => {
describe('getContext', () => {
let registry: Registry;

beforeEach(() => {
registry = new Registry();
});

it('should run the flow', async () => {
const testFlow = defineFlow(
registry,
@@ -124,7 +130,7 @@ describe('flow', () => {
outputSchema: z.string(),
},
async (input) => {
return `bar ${input} ${JSON.stringify(getFlowAuth())}`;
return `bar ${input} ${JSON.stringify(getContext(registry))}`;
}
);

@@ -150,7 +156,7 @@ describe('flow', () => {
streamingCallback({ count: i });
}
}
return `bar ${input} ${!!streamingCallback} ${JSON.stringify(getFlowAuth())}`;
return `bar ${input} ${!!streamingCallback} ${JSON.stringify(getContext(registry))}`;
}
);

@@ -178,7 +184,7 @@ describe('flow', () => {
outputSchema: z.string(),
},
async (input) => {
return `bar ${input} ${JSON.stringify(getFlowAuth())}`;
return `bar ${input} ${JSON.stringify(getContext(registry))}`;
}
);

11 changes: 11 additions & 0 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
@@ -100,6 +100,7 @@ import {
import { ToolFn } from '@genkit-ai/ai/tool';
import {
Action,
ActionContext,
FlowConfig,
FlowFn,
JSONSchema,
@@ -108,6 +109,7 @@ import {
defineFlow,
defineJsonSchema,
defineSchema,
getContext,
isDevEnv,
run,
z,
@@ -785,6 +787,15 @@ export class Genkit implements HasRegistry {
return run(name, funcOrInput, this.registry);
}

/**
* Returns current action (or flow) invocation context. Can be used to access things like auth
* data set by HTTP server frameworks. If invoked outside of an action (e.g. flow or tool) will
* return `undefined`.
*/
currentContext(): ActionContext | undefined {
return getContext(this);
}

/**
* Configures the Genkit instance.
*/
2 changes: 1 addition & 1 deletion js/genkit/src/index.ts
Original file line number Diff line number Diff line change
@@ -121,12 +121,12 @@ export {
defineJsonSchema,
defineSchema,
getCurrentEnv,
getFlowAuth,
getStreamingCallback,
isDevEnv,
runWithStreamingCallback,
z,
type Action,
type ActionContext,
type ActionMetadata,
type Flow,
type FlowConfig,
59 changes: 59 additions & 0 deletions js/genkit/tests/generate_test.ts
Original file line number Diff line number Diff line change
@@ -435,6 +435,65 @@ describe('generate', () => {
});
});

it('should propagate context to the tool', async () => {
const schema = z.object({
foo: z.string(),
});

ai.defineTool(
{
name: 'testTool',
description: 'description',
inputSchema: schema,
outputSchema: schema,
},
async (_, { context }) => {
return {
foo: `bar ${context.auth.email}`,
};
}
);

// first response be tools call, the subsequent just text response from agent b.
let reqCounter = 0;
pm.handleResponse = async (req, sc) => {
return {
message: {
role: 'model',
content: [
reqCounter++ === 0
? {
toolRequest: {
name: 'testTool',
input: { foo: 'fromTool' },
ref: 'ref123',
},
}
: {
text: req.messages
.splice(-1)
.map((m) =>
m.content
.map(
(c) =>
c.text || JSON.stringify(c.toolResponse?.output)
)
.join()
)
.join(),
},
],
},
};
};
const { text } = await ai.generate({
prompt: 'call the tool',
tools: ['testTool'],
context: { auth: { email: 'a@b.c' } },
});
assert.strictEqual(text, '{"foo":"bar a@b.c"}');
});

it('streams the tool responses', async () => {
ai.defineTool(
{ name: 'testTool', description: 'description' },
8 changes: 4 additions & 4 deletions js/plugins/express/src/index.ts
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ export interface AuthPolicyContext<
> {
action?: Action<I, O, S>;
input: z.infer<I>;
auth: any | undefined;
auth?: Record<string, any>;
request: RequestWithAuth;
}

@@ -56,7 +56,7 @@ export interface AuthPolicy<
* the flow context.
*/
export interface RequestWithAuth extends express.Request {
auth?: unknown;
auth?: Record<string, any>;
}

/**
@@ -116,7 +116,7 @@ export function expressHandler<
() =>
action.run(input, {
onChunk,
context: auth,
context: { auth },
})
);
response.write(
@@ -140,7 +140,7 @@ export function expressHandler<
}
} else {
try {
const result = await action.run(input, { context: auth });
const result = await action.run(input, { context: { auth } });
response.setHeader('x-genkit-trace-id', result.telemetry.traceId);
response.setHeader('x-genkit-span-id', result.telemetry.spanId);
// Responses for non-streaming flows are passed back with the flow result stored in a field called "result."
4 changes: 2 additions & 2 deletions js/plugins/express/tests/express_test.ts
Original file line number Diff line number Diff line change
@@ -82,7 +82,7 @@ describe('expressHandler', async () => {
inputSchema: z.object({ question: z.string() }),
},
async (input, { context }) => {
return `${input.question} - ${JSON.stringify(context)}`;
return `${input.question} - ${JSON.stringify(context.auth)}`;
}
);

@@ -382,7 +382,7 @@ describe('startFlowServer', async () => {
inputSchema: z.object({ question: z.string() }),
},
async (input, { context }) => {
return `${input.question} - ${JSON.stringify(context)}`;
return `${input.question} - ${JSON.stringify(context.auth)}`;
}
);

4 changes: 2 additions & 2 deletions js/plugins/firebase/tests/functions_test.ts
Original file line number Diff line number Diff line change
@@ -81,7 +81,7 @@ describe('function', () => {
authPolicy: authPolicy,
},
async (input, { context }) => {
return `hi ${input} - ${JSON.stringify(context)}`;
return `hi ${input} - ${JSON.stringify(context?.auth)}`;
}
);

@@ -96,7 +96,7 @@ describe('function', () => {
sendChunk({ chubk: 2 });
sendChunk({ chubk: 3 });

return `hi ${input} - ${JSON.stringify(context)}`;
return `hi ${input} - ${JSON.stringify(context?.auth)}`;
}
);
const app = express();
2 changes: 1 addition & 1 deletion js/testapps/express/src/index.ts
Original file line number Diff line number Diff line change
@@ -79,7 +79,7 @@ app.use(express.json());

const authPolicies: Record<string, AuthPolicy> = {
jokeFlow: ({ auth }) => {
if (auth.username != 'Ali Baba') {
if (auth?.username != 'Ali Baba') {
throw new Error('unauthorized: ' + JSON.stringify(auth));
}
},

0 comments on commit 21ffd9b

Please sign in to comment.