Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(js/flows): consolidated defineFlow and defineStreamingFlow #1401

Merged
merged 6 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 41 additions & 26 deletions js/core/src/flow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,32 +91,37 @@ export interface StreamingFlowConfig<
streamSchema?: S;
}

export interface FlowCallOptions {
/** @deprecated use {@link context} instead. */
withLocalAuthContext?: unknown;
context?: unknown;
}

/**
* Non-streaming flow that can be called directly like a function.
*/
export interface CallableFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
(
input?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
): Promise<z.infer<O>>;
(input?: z.infer<I>, opts?: FlowCallOptions): Promise<z.infer<O>>;

stream(input?: z.infer<I>, opts?: FlowCallOptions): StreamingResponse<O, S>;

flow: Flow<I, O, z.ZodVoid>;
}

/**
* Streaming flow that can be called directly like a function.
* @deprecated use {@link CallableFlow}
*/
export interface StreamableFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
(
input?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
): StreamingResponse<O, S>;
(input?: z.infer<I>, opts?: FlowCallOptions): StreamingResponse<O, S>;
flow: Flow<I, O, S>;
}

Expand All @@ -128,7 +133,7 @@ interface StreamingResponse<
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
/** Iterator over the streaming chunks. */
stream: AsyncGenerator<unknown, z.infer<O>, z.infer<S> | undefined>;
stream: AsyncGenerator<z.infer<S>>;
/** Final output of the flow. */
output: Promise<z.infer<O>>;
}
Expand All @@ -144,7 +149,7 @@ export type FlowFn<
/** Input to the flow. */
input: z.infer<I>,
/** Callback for streaming functions only. */
streamingCallback?: StreamingCallback<z.infer<S>>
streamingCallback: StreamingCallback<z.infer<S>>
) => Promise<z.infer<O>> | z.infer<O>;

/**
Expand Down Expand Up @@ -223,7 +228,10 @@ export class Flow<
});
try {
metadata.input = input;
const output = await this.flowFn(input, opts.streamingCallback);
const output = await this.flowFn(
input,
opts.streamingCallback ?? (() => {})
);
metadata.output = JSON.stringify(output);
setCustomMetadataAttribute(flowMetadataPrefix('state'), 'done');
return {
Expand Down Expand Up @@ -252,10 +260,7 @@ export class Flow<
/**
* Runs the flow. This is used when calling a flow from another flow.
*/
async run(
payload?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
): Promise<z.infer<O>> {
async run(payload?: z.infer<I>, opts?: FlowCallOptions): Promise<z.infer<O>> {
const input = this.inputSchema ? this.inputSchema.parse(payload) : payload;
await this.authPolicy?.(opts?.withLocalAuthContext, payload);

Expand All @@ -266,7 +271,7 @@ export class Flow<
}

const result = await this.invoke(input, {
auth: opts?.withLocalAuthContext,
auth: opts?.context || opts?.withLocalAuthContext,
});
return result.result;
}
Expand All @@ -276,7 +281,7 @@ export class Flow<
*/
stream(
payload?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
opts?: FlowCallOptions
): StreamingResponse<O, S> {
let chunkStreamController: ReadableStreamController<z.infer<S>>;
const chunkStream = new ReadableStream<z.infer<S>>({
Expand All @@ -288,7 +293,7 @@ export class Flow<
});

const authPromise =
this.authPolicy?.(opts?.withLocalAuthContext, payload) ??
this.authPolicy?.(opts?.context || opts?.withLocalAuthContext, payload) ??
Promise.resolve();

const invocationPromise = authPromise
Expand All @@ -301,7 +306,7 @@ export class Flow<
}) as S extends z.ZodVoid
? undefined
: StreamingCallback<z.infer<S>>,
auth: opts?.withLocalAuthContext,
auth: opts?.context || opts?.withLocalAuthContext,
}
).then((s) => s.result)
)
Expand Down Expand Up @@ -530,21 +535,31 @@ export class FlowServer {
export function defineFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
config: FlowConfig<I, O> | string,
fn: FlowFn<I, O, z.ZodVoid>
): CallableFlow<I, O> {
config: StreamingFlowConfig<I, O> | string,
fn: FlowFn<I, O, S>
): CallableFlow<I, O, S> {
const resolvedConfig: FlowConfig<I, O> =
typeof config === 'string' ? { name: config } : config;

const flow = new Flow<I, O, z.ZodVoid>(registry, resolvedConfig, fn);
const flow = new Flow<I, O, S>(registry, resolvedConfig, fn);
registerFlowAction(registry, flow);
const callableFlow: CallableFlow<I, O> = async (input, opts) => {
const callableFlow = async (
input: z.infer<I>,
opts: FlowCallOptions
): Promise<z.infer<O>> => {
return flow.run(input, opts);
};
callableFlow.flow = flow;
return callableFlow;
(callableFlow as CallableFlow<I, O, S>).flow = flow;
(callableFlow as CallableFlow<I, O, S>).stream = (
input: z.infer<I>,
opts: FlowCallOptions
): StreamingResponse<O, S> => {
return flow.stream(input, opts);
};
return callableFlow as CallableFlow<I, O, S>;
}

/**
Expand Down
107 changes: 105 additions & 2 deletions js/core/tests/flow_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,19 @@
* limitations under the License.
*/

import { SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { defineFlow, defineStreamingFlow } from '../src/flow.js';
import { getFlowAuth, z } from '../src/index.js';
import { defineFlow, defineStreamingFlow, run } from '../src/flow.js';
import { defineAction, getFlowAuth, z } from '../src/index.js';
import { Registry } from '../src/registry.js';
import { enableTelemetry } from '../src/tracing.js';
import { TestSpanExporter } from './utils.js';

const spanExporter = new TestSpanExporter();
enableTelemetry({
spanProcessors: [new SimpleSpanProcessor(spanExporter)],
});

function createTestFlow(registry: Registry) {
return defineFlow(
Expand Down Expand Up @@ -224,4 +232,99 @@ describe('flow', () => {
assert.deepEqual(gotChunks, [{ count: 0 }, { count: 1 }, { count: 2 }]);
});
});

describe('telemetry', async () => {
beforeEach(() => {
spanExporter.exportedSpans = [];
});

it('should create a trace', async () => {
const testFlow = createTestFlow(registry);

const result = await testFlow('foo');

assert.equal(result, 'bar foo');
assert.strictEqual(spanExporter.exportedSpans.length, 1);
assert.strictEqual(spanExporter.exportedSpans[0].displayName, 'testFlow');
assert.deepStrictEqual(spanExporter.exportedSpans[0].attributes, {
'genkit:input': '"foo"',
'genkit:isRoot': true,
'genkit:metadata:flow:name': 'testFlow',
'genkit:metadata:flow:state': 'done',
'genkit:name': 'testFlow',
'genkit:output': '"bar foo"',
'genkit:path': '/{testFlow,t:flow}',
'genkit:state': 'success',
'genkit:type': 'flow',
});
});

it('records traces of nested actions', async () => {
const testAction = defineAction(
registry,
{
name: 'testAction',
actionType: 'tool',
metadata: { type: 'tool' },
},
async (i) => {
return 'bar';
}
);

const testFlow = defineFlow(
registry,
{
name: 'testFlow',
inputSchema: z.string(),
outputSchema: z.string(),
},
async (input) => {
return run('custom', async () => {
return 'foo ' + (await testAction(undefined));
});
}
);
const result = await testFlow('foo');

assert.equal(result, 'foo bar');
assert.strictEqual(spanExporter.exportedSpans.length, 3);

assert.strictEqual(
spanExporter.exportedSpans[0].displayName,
'testAction'
);
assert.deepStrictEqual(spanExporter.exportedSpans[0].attributes, {
'genkit:metadata:subtype': 'tool',
'genkit:name': 'testAction',
'genkit:output': '"bar"',
'genkit:path':
'/{testFlow,t:flow}/{custom,t:flowStep}/{testAction,t:action,s:tool}',
'genkit:state': 'success',
'genkit:type': 'action',
});

assert.strictEqual(spanExporter.exportedSpans[1].displayName, 'custom');
assert.deepStrictEqual(spanExporter.exportedSpans[1].attributes, {
'genkit:name': 'custom',
'genkit:output': '"foo bar"',
'genkit:path': '/{testFlow,t:flow}/{custom,t:flowStep}',
'genkit:state': 'success',
'genkit:type': 'flowStep',
});

assert.strictEqual(spanExporter.exportedSpans[2].displayName, 'testFlow');
assert.deepStrictEqual(spanExporter.exportedSpans[2].attributes, {
'genkit:input': '"foo"',
'genkit:isRoot': true,
'genkit:metadata:flow:name': 'testFlow',
'genkit:metadata:flow:state': 'done',
'genkit:name': 'testFlow',
'genkit:output': '"foo bar"',
'genkit:path': '/{testFlow,t:flow}',
'genkit:state': 'success',
'genkit:type': 'flow',
});
});
});
});
60 changes: 60 additions & 0 deletions js/core/tests/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { SpanKind } from '@opentelemetry/api';
import { ExportResult } from '@opentelemetry/core';
import { ReadableSpan, SpanExporter } from '@opentelemetry/sdk-trace-base';

export class TestSpanExporter implements SpanExporter {
exportedSpans: any[] = [];

export(
spans: ReadableSpan[],
resultCallback: (result: ExportResult) => void
): void {
this.exportedSpans.push(...spans.map((s) => this._exportInfo(s)));
resultCallback({ code: 0 });
}

shutdown(): Promise<void> {
return this.forceFlush();
}

private _exportInfo(span: ReadableSpan) {
return {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
attributes: { ...span.attributes },
displayName: span.name,
links: span.links,
spanKind: SpanKind[span.kind],
parentSpanId: span.parentSpanId,
sameProcessAsParentSpan: { value: !span.spanContext().isRemote },
status: span.status,
timeEvents: {
timeEvent: span.events.map((e) => ({
annotation: {
attributes: e.attributes ?? {},
description: e.name,
},
})),
},
};
}
forceFlush(): Promise<void> {
return Promise.resolve();
}
}
9 changes: 6 additions & 3 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ import {
defineSchema,
defineStreamingFlow,
Flow,
FlowConfig,
FlowFn,
FlowServer,
FlowServerOptions,
Expand Down Expand Up @@ -203,7 +202,11 @@ export class Genkit {
defineFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
>(config: FlowConfig<I, O> | string, fn: FlowFn<I, O>): CallableFlow<I, O> {
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
config: StreamingFlowConfig<I, O, S> | string,
fn: FlowFn<I, O, S>
): CallableFlow<I, O, S> {
const flow = defineFlow(this.registry, config, fn);
this.registeredFlows.push(flow.flow);
return flow;
Expand All @@ -212,7 +215,7 @@ export class Genkit {
/**
* Defines and registers a streaming flow.
*
* @todo TODO: Improve this documentation (show snippetss, etc).
* @deprecated use {@link defineFlow}
*/
defineStreamingFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
Expand Down
Loading