Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(js): add support for streaming json output #484

Merged
merged 4 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ genkit eval:flow pdfQA '"What's a brief description of MapReduce?"'

FYI: `js` and `genkit-tools` are in two separate workspaces.

As you make changes you may want to build an test things by running test apps.
As you make changes you may want to build and test things by running test apps.
You can reduce the scope of what you're building by running a specific build command:

```
Expand Down
4 changes: 3 additions & 1 deletion js/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"build:clean": "rm -rf ./lib",
"build": "npm-run-all build:clean check compile",
"build:watch": "tsup-node --watch",
"test": "node --import tsx --test ./tests/**/*_test.ts"
"test": "node --import tsx --test ./tests/**/*_test.ts",
"test:single": "node --import tsx --test"
},
"repository": {
"type": "git",
Expand All @@ -30,6 +31,7 @@
"@types/node": "^20.11.19",
"json5": "^2.2.3",
"node-fetch": "^3.3.2",
"partial-json": "^0.1.7",
"zod": "^3.22.4"
},
"devDependencies": {
Expand Down
36 changes: 31 additions & 5 deletions js/ai/src/extract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,27 @@
*/

import JSON5 from 'json5';
import { Allow, parse } from 'partial-json';

export function parsePartialJson<T = unknown>(jsonString: string): T {
return JSON5.parse<T>(JSON.stringify(parse(jsonString, Allow.ALL)));
}

/**
* Extracts JSON from string with lenient parsing rules to improve likelihood of successful extraction.
*/
export function extractJson<T = unknown>(text: string): T | null {
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: true
): T;
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: false
): T | null;
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: boolean
): T | null {
let openingChar: '{' | '[' | undefined;
let closingChar: '}' | ']' | undefined;
let startPos: number | undefined;
Expand Down Expand Up @@ -48,11 +64,21 @@ export function extractJson<T = unknown>(text: string): T | null {
}

if (startPos !== undefined && nestingCount > 0) {
// If an incomplete JSON structure is detected
try {
return JSON5.parse(text.substring(startPos) + (closingChar || '')) as T;
} catch (e) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
// Parse the incomplete JSON structure using partial-json for lenient parsing
// Note: partial-json automatically handles adding the closing character
return parsePartialJson<T>(text.substring(startPos));
} catch {
// If parsing fails, throw an error
if (throwOnBadJson) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
}
return null; // Return null if no JSON structure is found }
}
}
throw new Error(`No JSON object or array found in model output: ${text}`);
if (throwOnBadJson) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
}
return null; // Return null if no JSON structure is found
}
37 changes: 33 additions & 4 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export class Message<T = unknown> implements MessageData {
*
* @returns The structured output contained in the message.
*/
output(): T | null {
output(): T {
return this.data() || extractJson<T>(this.text());
}

Expand Down Expand Up @@ -360,11 +360,17 @@ export class GenerateResponseChunk<T = unknown>
content: Part[];
/** Custom model-specific data for this chunk. */
custom?: unknown;
/** Accumulated chunks for partial output extraction. */
accumulatedChunks?: GenerateResponseChunkData[];

constructor(data: GenerateResponseChunkData) {
constructor(
data: GenerateResponseChunkData,
accumulatedChunks?: GenerateResponseChunkData[]
) {
this.index = data.index;
this.content = data.content || [];
this.custom = data.custom;
this.accumulatedChunks = accumulatedChunks;
}

/**
Expand Down Expand Up @@ -402,6 +408,18 @@ export class GenerateResponseChunk<T = unknown>
) as ToolRequestPart[];
}

/**
* Attempts to extract the longest valid JSON substring from the accumulated chunks.
* @returns The longest valid JSON substring found in the accumulated chunks.
*/
output(): T | null {
if (!this.accumulatedChunks) return null;
const accumulatedText = this.accumulatedChunks
.map((chunk) => chunk.content.map((part) => part.text || '').join(''))
.join('');
return extractJson<T>(accumulatedText, false);
}

toJSON(): GenerateResponseChunkData {
return { index: this.index, content: this.content, custom: this.custom };
}
Expand Down Expand Up @@ -586,6 +604,7 @@ export class NoValidCandidatesError extends GenkitError {
* @param options The options for this generation request.
* @returns The generated response based on the provided parameters.
*/

export async function generate<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
Expand All @@ -612,10 +631,20 @@ export async function generate<
}

const request = await toGenerateRequest(resolvedOptions);

const accumulatedChunks: GenerateResponseChunkData[] = [];

const response = await runWithStreamingCallback(
resolvedOptions.streamingCallback
? (chunk: GenerateResponseChunkData) =>
resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk))
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (resolvedOptions.streamingCallback) {
resolvedOptions.streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
);
}
}
: undefined,
async () => new GenerateResponse<z.infer<O>>(await model(request), request)
);
Expand Down
75 changes: 75 additions & 0 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import assert from 'node:assert';
import { describe, it } from 'node:test';
import { z } from 'zod';
import { GenerateResponseChunk } from '../../src/generate';
import {
Candidate,
GenerateOptions,
GenerateResponse,
Message,
toGenerateRequest,
} from '../../src/generate.js';
import { GenerateResponseChunkData } from '../../src/model';
import {
CandidateData,
GenerateRequest,
Expand Down Expand Up @@ -506,3 +508,76 @@ describe('toGenerateRequest', () => {
});
}
});

describe('GenerateResponseChunk', () => {
describe('#output()', () => {
const testCases = [
{
should: 'parse ``` correctly',
accumulatedChunksTexts: ['```'],
correctJson: null,
},
{
should: 'parse valid json correctly',
accumulatedChunksTexts: [`{"foo":"bar"}`],
correctJson: { foo: 'bar' },
},
{
should: 'if json invalid, return null',
accumulatedChunksTexts: [`invalid json`],
correctJson: null,
},
{
should: 'handle missing closing brace',
accumulatedChunksTexts: [`{"foo":"bar"`],
correctJson: { foo: 'bar' },
},
{
should: 'handle missing closing bracket in nested object',
accumulatedChunksTexts: [`{"foo": {"bar": "baz"`],
correctJson: { foo: { bar: 'baz' } },
},
{
should: 'handle multiple chunks',
accumulatedChunksTexts: [`{"foo": {"bar"`, `: "baz`],
correctJson: { foo: { bar: 'baz' } },
},
{
should: 'handle multiple chunks with nested objects',
accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: {"baz": "qux`],
correctJson: { foo: { bar: { baz: 'qux' } } },
},
{
should: 'handle array nested in object',
accumulatedChunksTexts: [`{"foo": ["bar`],
correctJson: { foo: ['bar'] },
},
{
should: 'handle array nested in object with multiple chunks',
accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: ["baz`],
correctJson: { foo: { bar: ['baz'] } },
},
];

for (const test of testCases) {
if (test.should) {
it(test.should, () => {
const accumulatedChunks: GenerateResponseChunkData[] =
test.accumulatedChunksTexts.map((text, index) => ({
index,
content: [{ text }],
}));

const chunkData = accumulatedChunks[accumulatedChunks.length - 1];

const responseChunk: GenerateResponseChunk =
new GenerateResponseChunk(chunkData, accumulatedChunks);

const output = responseChunk.output();

assert.deepStrictEqual(output, test.correctJson);
});
}
}
});
});
3 changes: 3 additions & 0 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading