Skip to content

Commit

Permalink
feat(js): add support for streaming json output (#484)
Browse files Browse the repository at this point in the history
* feat(js): add support for streaming json output

* refactor: switch to partial-json library

* refactor: merge the two extract methods

* chore: update lockfile
  • Loading branch information
cabljac authored Jul 24, 2024
1 parent f6524e1 commit fe8a957
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ genkit eval:flow pdfQA '"What's a brief description of MapReduce?"'

FYI: `js` and `genkit-tools` are in two separate workspaces.

As you make changes you may want to build an test things by running test apps.
As you make changes you may want to build and test things by running test apps.
You can reduce the scope of what you're building by running a specific build command:

```
Expand Down
4 changes: 3 additions & 1 deletion js/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"build:clean": "rm -rf ./lib",
"build": "npm-run-all build:clean check compile",
"build:watch": "tsup-node --watch",
"test": "node --import tsx --test ./tests/**/*_test.ts"
"test": "node --import tsx --test ./tests/**/*_test.ts",
"test:single": "node --import tsx --test"
},
"repository": {
"type": "git",
Expand All @@ -30,6 +31,7 @@
"@types/node": "^20.11.19",
"json5": "^2.2.3",
"node-fetch": "^3.3.2",
"partial-json": "^0.1.7",
"zod": "^3.22.4"
},
"devDependencies": {
Expand Down
36 changes: 31 additions & 5 deletions js/ai/src/extract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,27 @@
*/

import JSON5 from 'json5';
import { Allow, parse } from 'partial-json';

/**
 * Parses a JSON string that may be incomplete (for example, JSON that is
 * still being streamed) into a value.
 *
 * The input is parsed by `partial-json` with `Allow.ALL`; the result is then
 * serialized with `JSON.stringify` and re-parsed with `JSON5.parse`.
 * The returned value is typed as `T` but is not validated against it.
 *
 * @param jsonString The (possibly truncated) JSON text to parse.
 * @returns The parsed value, typed as `T`.
 */
export function parsePartialJson<T = unknown>(jsonString: string): T {
// NOTE(review): the stringify/re-parse round trip presumably normalizes the
// partial-json result to plain JSON values — confirm before simplifying.
return JSON5.parse<T>(JSON.stringify(parse(jsonString, Allow.ALL)));
}

/**
* Extracts JSON from string with lenient parsing rules to improve likelihood of successful extraction.
*/
export function extractJson<T = unknown>(text: string): T | null {
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: true
): T;
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: false
): T | null;
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: boolean
): T | null {
let openingChar: '{' | '[' | undefined;
let closingChar: '}' | ']' | undefined;
let startPos: number | undefined;
Expand Down Expand Up @@ -48,11 +64,21 @@ export function extractJson<T = unknown>(text: string): T | null {
}

if (startPos !== undefined && nestingCount > 0) {
// If an incomplete JSON structure is detected
try {
return JSON5.parse(text.substring(startPos) + (closingChar || '')) as T;
} catch (e) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
// Parse the incomplete JSON structure using partial-json for lenient parsing
// Note: partial-json automatically handles adding the closing character
return parsePartialJson<T>(text.substring(startPos));
} catch {
// If parsing fails, throw an error
if (throwOnBadJson) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
}
return null; // Return null if the extracted JSON could not be parsed
}
}
throw new Error(`No JSON object or array found in model output: ${text}`);
if (throwOnBadJson) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
}
return null; // Return null if no JSON structure is found
}
37 changes: 33 additions & 4 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export class Message<T = unknown> implements MessageData {
*
* @returns The structured output contained in the message.
*/
output(): T | null {
output(): T {
return this.data() || extractJson<T>(this.text());
}

Expand Down Expand Up @@ -360,11 +360,17 @@ export class GenerateResponseChunk<T = unknown>
content: Part[];
/** Custom model-specific data for this chunk. */
custom?: unknown;
/** Accumulated chunks for partial output extraction. */
accumulatedChunks?: GenerateResponseChunkData[];

constructor(data: GenerateResponseChunkData) {
constructor(
data: GenerateResponseChunkData,
accumulatedChunks?: GenerateResponseChunkData[]
) {
this.index = data.index;
this.content = data.content || [];
this.custom = data.custom;
this.accumulatedChunks = accumulatedChunks;
}

/**
Expand Down Expand Up @@ -402,6 +408,18 @@ export class GenerateResponseChunk<T = unknown>
) as ToolRequestPart[];
}

/**
 * Attempts to extract structured JSON output from the text received so far.
 *
 * The text parts of every accumulated chunk are concatenated in order and the
 * combined string is handed to `extractJson` with `throwOnBadJson` set to
 * `false`.
 *
 * @returns The value extracted from the accumulated text, or `null` when no
 *   accumulated chunks were provided or `extractJson` yields `null`.
 */
output(): T | null {
  if (!this.accumulatedChunks) return null;
  const accumulatedText = this.accumulatedChunks
    .map((chunk) =>
      // Tolerate chunk data without content, matching the `data.content || []`
      // guard applied by the constructor.
      (chunk.content || []).map((part) => part.text || '').join('')
    )
    .join('');
  return extractJson<T>(accumulatedText, false);
}

toJSON(): GenerateResponseChunkData {
return { index: this.index, content: this.content, custom: this.custom };
}
Expand Down Expand Up @@ -586,6 +604,7 @@ export class NoValidCandidatesError extends GenkitError {
* @param options The options for this generation request.
* @returns The generated response based on the provided parameters.
*/

export async function generate<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
Expand All @@ -612,10 +631,20 @@ export async function generate<
}

const request = await toGenerateRequest(resolvedOptions);

const accumulatedChunks: GenerateResponseChunkData[] = [];

const response = await runWithStreamingCallback(
resolvedOptions.streamingCallback
? (chunk: GenerateResponseChunkData) =>
resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk))
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (resolvedOptions.streamingCallback) {
resolvedOptions.streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
);
}
}
: undefined,
async () => new GenerateResponse<z.infer<O>>(await model(request), request)
);
Expand Down
75 changes: 75 additions & 0 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import assert from 'node:assert';
import { describe, it } from 'node:test';
import { z } from 'zod';
import { GenerateResponseChunk } from '../../src/generate';
import {
Candidate,
GenerateOptions,
GenerateResponse,
Message,
toGenerateRequest,
} from '../../src/generate.js';
import { GenerateResponseChunkData } from '../../src/model';
import {
CandidateData,
GenerateRequest,
Expand Down Expand Up @@ -506,3 +508,76 @@ describe('toGenerateRequest', () => {
});
}
});

// Table-driven checks for GenerateResponseChunk#output().
describe('GenerateResponseChunk', () => {
  describe('#output()', () => {
    // Each scenario lists the text of successive chunks and the value that
    // output() is expected to produce for the final chunk.
    const scenarios = [
      {
        should: 'parse ``` correctly',
        accumulatedChunksTexts: ['```'],
        correctJson: null,
      },
      {
        should: 'parse valid json correctly',
        accumulatedChunksTexts: [`{"foo":"bar"}`],
        correctJson: { foo: 'bar' },
      },
      {
        should: 'if json invalid, return null',
        accumulatedChunksTexts: [`invalid json`],
        correctJson: null,
      },
      {
        should: 'handle missing closing brace',
        accumulatedChunksTexts: [`{"foo":"bar"`],
        correctJson: { foo: 'bar' },
      },
      {
        should: 'handle missing closing bracket in nested object',
        accumulatedChunksTexts: [`{"foo": {"bar": "baz"`],
        correctJson: { foo: { bar: 'baz' } },
      },
      {
        should: 'handle multiple chunks',
        accumulatedChunksTexts: [`{"foo": {"bar"`, `: "baz`],
        correctJson: { foo: { bar: 'baz' } },
      },
      {
        should: 'handle multiple chunks with nested objects',
        accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: {"baz": "qux`],
        correctJson: { foo: { bar: { baz: 'qux' } } },
      },
      {
        should: 'handle array nested in object',
        accumulatedChunksTexts: [`{"foo": ["bar`],
        correctJson: { foo: ['bar'] },
      },
      {
        should: 'handle array nested in object with multiple chunks',
        accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: ["baz`],
        correctJson: { foo: { bar: ['baz'] } },
      },
    ];

    scenarios.forEach(({ should, accumulatedChunksTexts, correctJson }) => {
      // Scenarios without a description are not registered.
      if (!should) {
        return;
      }

      it(should, () => {
        // Build one chunk per text fragment, indexed by arrival order.
        const chunks: GenerateResponseChunkData[] = accumulatedChunksTexts.map(
          (text, index) => ({ index, content: [{ text }] })
        );

        // The chunk under test is the most recent one; it also receives the
        // full list of chunks seen so far.
        const latestChunk = chunks[chunks.length - 1];
        const responseChunk: GenerateResponseChunk = new GenerateResponseChunk(
          latestChunk,
          chunks
        );

        assert.deepStrictEqual(responseChunk.output(), correctJson);
      });
    });
  });
});
3 changes: 3 additions & 0 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fe8a957

Please sign in to comment.