Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking(go): update flow streaming protocol to SSE #1316

Merged
merged 19 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions go/genkit/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,13 @@ func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) erro
return err
}
var callback streamingCallback[json.RawMessage]
if stream {
if r.Header.Get("Accept") == "text/event-stream" || stream {
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
// Stream results are newline-separated JSON.
// Event Stream results are in JSON format separated by two newline escape sequences
// including the `data` and `message` labels
callback = func(ctx context.Context, msg json.RawMessage) error {
_, err := fmt.Fprintf(w, "%s\n", msg)
_, err := fmt.Fprintf(w, "data: {\"message\": %s}\n\n", msg)
if err != nil {
return err
}
Expand All @@ -417,8 +418,19 @@ func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) erro
// TODO: telemetry
out, err := f.runJSON(r.Context(), r.Header.Get("Authorization"), body.Data, callback)
if err != nil {
if r.Header.Get("Accept") == "text/event-stream" || stream {
_, err = fmt.Fprintf(w, "data: {\"error\": {\"status\": \"INTERNAL\", \"message\": \"stream flow error\", \"details\": \"%v\"}}\n\n", err)
return err
}
return err
}
// Responses for streaming, non-durable flows should be prefixed
// with "data"
if r.Header.Get("Accept") == "text/event-stream" || stream {
_, err = fmt.Fprintf(w, "data: {\"result\": %s}\n\n", out)
return err
}

// Responses for non-streaming, non-durable flows are passed back
// with the flow result stored in a field called "result."
_, err = fmt.Fprintf(w, `{"result": %s}\n`, out)
Expand Down
16 changes: 15 additions & 1 deletion go/samples/flow-sample1/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"encoding/json"
"errors"
"fmt"

"log"
"strconv"

Expand Down Expand Up @@ -106,6 +105,21 @@ func main() {
return fmt.Sprintf("done: %d, streamed: %d times", count, i), nil
})

genkit.DefineStreamingFlow(g, "streamyThrowy", func(ctx context.Context, count int, cb func(context.Context, chunk) error) (string, error) {
i := 0
if cb != nil {
for ; i < count; i++ {
if i == 3 {
return "", errors.New("boom!")
}
if err := cb(ctx, chunk{i}); err != nil {
return "", err
}
}
}
return fmt.Sprintf("done: %d, streamed: %d times", count, i), nil
})

if err := g.Start(context.Background(), nil); err != nil {
log.Fatal(err)
}
Expand Down
3 changes: 1 addition & 2 deletions go/tests/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ type test struct {
const hostPort = "http://localhost:3100"

func TestReflectionAPI(t *testing.T) {
filenames, err := filepath.Glob(filepath.FromSlash("../../tests/*.yaml"))

filenames, err := filepath.Glob(filepath.FromSlash("../../tests/reflection_api_tests.yaml"))
hugoaguirre marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatal(err)
}
Expand Down
25 changes: 24 additions & 1 deletion go/tests/test_app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,23 @@ package main
import (
"context"
"encoding/json"
"fmt"
"log"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
)

func main() {
opts := genkit.StartOptions{
FlowAddr: "127.0.0.1:3400",
}

// used for streamed flows
type chunk struct {
Count int `json:"count"`
}

g, err := genkit.New(nil)
if err != nil {
log.Fatal(err)
Expand All @@ -39,7 +49,20 @@ func main() {
_ = res
return "TBD", nil
})
if err := g.Start(context.Background(), nil); err != nil {

genkit.DefineStreamingFlow(g, "streamy", func(ctx context.Context, count int, cb func(context.Context, chunk) error) (string, error) {
i := 0
if cb != nil {
for ; i < count; i++ {
if err := cb(ctx, chunk{i}); err != nil {
return "", err
}
}
}
return fmt.Sprintf("done %d, streamed: %d times", count, i), nil
})

if err := g.Start(context.Background(), &opts); err != nil {
log.Fatal(err)
}
}
Expand Down
13 changes: 13 additions & 0 deletions tests/flow_server_tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# This file describes the responses to HTTP requests made
# to the flow server

# TODO: add more test cases

app: flow_server
hugoaguirre marked this conversation as resolved.
Show resolved Hide resolved
tests:
- path: streamy
post:
data: 5
response:
message: '{"count":{count}}'
result: 'done {count}, streamed: {count} times'
5 changes: 3 additions & 2 deletions tests/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
"description": "",
"main": "lib/e2e.js",
"scripts": {
"test": "npm-run-all test:reflection_api",
"test": "npm-run-all test:flow_server test:reflection_api",
"test:dev_ui_test": "node --import tsx src/dev_ui_test.ts",
"test:reflection_api": "node --import tsx src/reflection_api_test.ts"
"test:reflection_api": "node --import tsx src/reflection_api_test.ts",
"test:flow_server": "node --import tsx src/flow_server_test.ts"
},
"keywords": [],
"author": "",
Expand Down
4 changes: 4 additions & 0 deletions tests/reflection_api_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -876,3 +876,7 @@ tests:
/flow/testFlow:
key: /flow/testFlow
name: testFlow

/flow/streamy:
key: /flow/streamy
name: streamy
95 changes: 95 additions & 0 deletions tests/src/flow_server_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { readFileSync } from 'fs';
import { streamFlow } from 'genkit/client';
import * as yaml from 'yaml';
import { retriable, runTestsForApp } from './utils.js';

(async () => {
// TODO: Add NodeJS tests
// Run the tests for go test app
await runTestsForApp('../go/tests/test_app', 'go run main.go', async () => {
await testFlowServer();
console.log('Flow server tests done! \\o/');
});
})();

type TestResults = {
message: string;
result: string;
};

async function testFlowServer() {
const url = 'http://localhost:3400';
await retriable(
async () => {
const res = await fetch(`${url}/streamy`, {
method: 'POST',
body: JSON.stringify({
data: 1,
}),
});
if (res.status != 200) {
throw new Error(`timed out waiting for flow server to become healthy`);
}
},
{
maxRetries: 30,
delayMs: 1000,
}
);

const t = yaml.parse(readFileSync('flow_server_tests.yaml', 'utf8'));
for (const test of t.tests) {
let chunkCount = 0;
let expected: string = '';
let want: TestResults = {
message: test.response.message,
result: test.response.result,
};
console.log(`checking stream for: ${test.path}`);
(async () => {
const response = await streamFlow({
url: `${url}/${test.path}`,
input: test.post.data,
});

for await (const chunk of response.stream()) {
expected = want.message.replace('{count}', chunkCount.toString());
let chunkJSON = JSON.stringify(await chunk);
if (chunkJSON != expected) {
throw new Error(
`unexpected chunk data received, got: ${chunkJSON}, want: ${want.message}`
);
}
chunkCount++;
}
if (chunkCount != test.post.data) {
throw new Error(
`unexpected number of stream chunks received: got ${chunkCount}, want: ${test.post.data}`
);
}
let out = await response.output();
want.result = want.result.replace(/\{count\}/g, chunkCount.toString());
if (out != want.result) {
throw new Error(
`unexpected output received, got: ${out}, want: ${want.result}`
);
}
})();
}
}
20 changes: 20 additions & 0 deletions tests/test_js_app/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,23 @@ export const testFlow = ai.defineFlow(
return 'Test flow passed';
}
);

// genkit flow:run streamy 5 -s
export const streamy = ai.defineStreamingFlow(
{
name: 'streamy',
inputSchema: z.number(),
outputSchema: z.string(),
streamSchema: z.object({ count: z.number() }),
},
async (count, streamingCallback) => {
let i = 0;
if (streamingCallback) {
for (; i < count; i++) {
await new Promise((r) => setTimeout(r, 1000));
streamingCallback({ count: i });
}
}
return `done: ${count}, streamed: ${i} times`;
}
);
Loading