From f97b337a4ad1b5c87b7643cf667e7a84dba3c86d Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 13 Feb 2025 07:59:59 -0500 Subject: [PATCH 1/4] feat(js/ai): allow disabling native constrained generation and enabling custom instruction --- js/ai/src/formats/json.ts | 15 +++++- js/ai/src/formats/types.ts | 4 +- js/ai/src/generate.ts | 19 ++++---- js/ai/src/generate/action.ts | 27 ++++++++++- js/ai/src/model.ts | 18 ++++---- js/ai/tests/model/middleware_test.ts | 68 +++++++++++++++++++++++++++- js/core/src/action.ts | 4 +- js/core/src/error.ts | 2 +- js/genkit/src/registry.ts | 6 +-- js/genkit/tests/prompts_test.ts | 4 +- 10 files changed, 138 insertions(+), 29 deletions(-) diff --git a/js/ai/src/formats/json.ts b/js/ai/src/formats/json.ts index cc4beecd0..24c618995 100644 --- a/js/ai/src/formats/json.ts +++ b/js/ai/src/formats/json.ts @@ -23,8 +23,19 @@ export const jsonFormatter: Formatter = { format: 'json', contentType: 'application/json', constrained: true, + defaultInstruction: false, }, - handler: () => { + handler: (schema) => { + let instructions: string | undefined; + + if (schema) { + instructions = `Output should be in JSON format and conform to the following schema: + +\`\`\` +${JSON.stringify(schema)} +\`\`\` +`; + } return { parseChunk: (chunk) => { return extractJson(chunk.accumulatedText); @@ -33,6 +44,8 @@ export const jsonFormatter: Formatter = { parseMessage: (message) => { return extractJson(message.text); }, + + instructions, }; }, }; diff --git a/js/ai/src/formats/types.ts b/js/ai/src/formats/types.ts index 7f0c9fbd5..2c447a679 100644 --- a/js/ai/src/formats/types.ts +++ b/js/ai/src/formats/types.ts @@ -23,7 +23,9 @@ export type OutputContentTypes = 'application/json' | 'text/plain'; export interface Formatter { name: string; - config: ModelRequest['output']; + config: ModelRequest['output'] & { + defaultInstruction?: false; + }; handler: (schema?: JSONSchema) => { parseMessage(message: Message): O; parseChunk?: (chunk: GenerateResponseChunk) => CO; diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index d0b3ebef1..a428949f1 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -33,7 +33,10 @@ import { resolveFormat, resolveInstructions, } from './formats/index.js'; -import { generateHelper } from './generate/action.js'; +import { + generateHelper, + shouldInjectFormatInstruction, +} from './generate/action.js'; import { GenerateResponseChunk } from './generate/chunk.js'; import { GenerateResponse } from './generate/response.js'; import { Message } from './message.js'; @@ -211,7 +214,12 @@ export async function toGenerateRequest( ); const out = { - messages: injectInstructions(messages, instructions), + messages: shouldInjectFormatInstruction( + resolvedFormat?.config, + options.output + ) + ? injectInstructions(messages, instructions) + : messages, config: options.config, docs: options.docs, tools: tools?.map(toToolDefinition) || [], @@ -343,16 +351,11 @@ export async function generate< resolvedOptions.output.format = 'json'; } const resolvedFormat = await resolveFormat(registry, resolvedOptions.output); - const instructions = resolveInstructions( - resolvedFormat, - resolvedSchema, - resolvedOptions?.output?.instructions - ); const params: GenerateActionOptions = { model: resolvedModel.modelAction.__action.name, docs: resolvedOptions.docs, - messages: injectInstructions(messages, instructions), + messages: messages, tools, toolChoice: resolvedOptions.toolChoice, config: { diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 74f69ad07..337d8820e 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -41,6 +41,7 @@ import { import { GenerateActionOptions, GenerateActionOptionsSchema, + GenerateActionOutputConfig, GenerateRequest, GenerateRequestSchema, GenerateResponseChunkData, @@ -171,8 +172,22 @@ function applyFormat( outRequest?.output?.instructions ); + console.log( + '- - - - -- - ', + JSON.stringify(resolvedFormat?.config, undefined, 2), + '\n', + JSON.stringify(rawRequest.output, undefined, 2) + ); + if (resolvedFormat) { - outRequest.messages = injectInstructions(outRequest.messages, instructions); + if ( + shouldInjectFormatInstruction(resolvedFormat.config, rawRequest?.output) + ) { + outRequest.messages = injectInstructions( + outRequest.messages, + instructions + ); + } outRequest.output = { // use output config from the format ...resolvedFormat.config, @@ -184,6 +199,16 @@ function applyFormat( return outRequest; } +export function shouldInjectFormatInstruction( + formatConfig?: Formatter['config'], + rawRequestConfig?: z.infer +) { + return ( + formatConfig?.defaultInstruction !== false || + rawRequestConfig?.instructions === true + ); +} + function applyTransferPreamble( rawRequest: GenerateActionOptions, transferPreamble?: GenerateActionOptions diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 3f7e2ca57..368d48a24 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -699,6 +699,14 @@ export async function resolveModel( return out; } +export const GenerateActionOutputConfig = z.object({ + format: z.string().optional(), + contentType: z.string().optional(), + instructions: z.union([z.boolean(), z.string()]).optional(), + jsonSchema: z.any().optional(), + constrained: z.boolean().optional(), +}); + export const GenerateActionOptionsSchema = z.object({ /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ model: z.string(), @@ -713,15 +721,7 @@ export const GenerateActionOptionsSchema = z.object({ /** Configuration for the generation request. */ config: z.any().optional(), /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ - output: z - .object({ - format: z.string().optional(), - contentType: z.string().optional(), - instructions: z.union([z.boolean(), z.string()]).optional(), - jsonSchema: z.any().optional(), - constrained: z.boolean().optional(), - }) - .optional(), + output: GenerateActionOutputConfig.optional(), /** Options for resuming an interrupted generation. */ resume: z .object({ diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 49cd624fd..cdd70d1ca 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -396,7 +396,7 @@ describe('augmentWithContext', () => { }); }); -describe('simulateConstrainedGeneration', () => { +describe.only('simulateConstrainedGeneration', () => { let registry: Registry; beforeEach(() => { @@ -555,4 +555,70 @@ describe('simulateConstrainedGeneration', () => { tools: [], }); }); + + it('uses format instructions when instructions is explicitly set to true', async () => { + let pm = defineProgrammableModel(registry, { + supports: { constrained: 'all' }, + }); + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [{ text: '```\n{"foo": "bar"}\n```' }], + }, + }; + }; + + const { output } = await generate(registry, { + model: 'programmableModel', + prompt: 'generate json', + output: { + instructions: true, + constrained: false, + schema: z.object({ + foo: z.string(), + }), + }, + }); + assert.deepEqual(output, { foo: 'bar' }); + assert.deepStrictEqual(pm.lastRequest, { + config: {}, + messages: [ + { + role: 'user', + content: [ + { text: 'generate json' }, + { + metadata: { + purpose: 'output', + }, + text: + 'Output should be in JSON format and conform to the following schema:\n' + + '\n' + + '```\n' + + '{"type":"object","properties":{"foo":{"type":"string"}},"required":["foo"],"additionalProperties":true,"$schema":"http://json-schema.org/draft-07/schema#"}\n' + + '```\n', + }, + ], + }, + ], + output: { + constrained: false, + contentType: 'application/json', + format: 'json', + schema: { + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: true, + properties: { + foo: { + type: 'string', + }, + }, + required: ['foo'], + type: 'object', + }, + }, + tools: [], + }); + }); }); diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 4fee1628c..4b9441f19 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { JSONSchema7 } from 'json-schema'; +import { type JSONSchema7 } from 'json-schema'; import * as z from 'zod'; import { lazy } from './async.js'; import { ActionContext, getContext, runWithContext } from './context.js'; @@ -26,7 +26,7 @@ import { setCustomMetadataAttributes, } from './tracing.js'; -export { Status, StatusCodes, StatusSchema } from './statusTypes.js'; +export { StatusCodes, StatusSchema, type Status } from './statusTypes.js'; export { JSONSchema7 }; /** diff --git a/js/core/src/error.ts b/js/core/src/error.ts index bda27d595..305220a48 100644 --- a/js/core/src/error.ts +++ b/js/core/src/error.ts @@ -15,7 +15,7 @@ */ import { Registry } from './registry.js'; -import { httpStatusCode, StatusName } from './statusTypes.js'; +import { httpStatusCode, type StatusName } from './statusTypes.js'; export { StatusName }; diff --git a/js/genkit/src/registry.ts b/js/genkit/src/registry.ts index 8c45c10d4..367d697b5 100644 --- a/js/genkit/src/registry.ts +++ b/js/genkit/src/registry.ts @@ -15,8 +15,8 @@ */ export { - ActionType, - AsyncProvider, Registry, - Schema, + type ActionType, + type AsyncProvider, + type Schema, } from '@genkit-ai/core/registry'; diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index 14d547c81..3f4ea276e 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -173,7 +173,7 @@ describe('definePrompt', () => { }); }); -describe.only('definePrompt', () => { +describe('definePrompt', () => { describe('default model', () => { let ai: GenkitBeta; @@ -310,7 +310,7 @@ describe.only('definePrompt', () => { }); }); - describe.only('default model ref', () => { + describe('default model ref', () => { let ai: GenkitBeta; beforeEach(() => { From a04fffb8c0164100096c01057091ad064b2cf1bb Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 13 Feb 2025 08:01:41 -0500 Subject: [PATCH 2/4] cleanup --- js/ai/src/generate/action.ts | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 337d8820e..4354a0c65 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -172,13 +172,6 @@ function applyFormat( outRequest?.output?.instructions ); - console.log( - '- - - - -- - ', - JSON.stringify(resolvedFormat?.config, undefined, 2), - '\n', - JSON.stringify(rawRequest.output, undefined, 2) - ); - if (resolvedFormat) { if ( shouldInjectFormatInstruction(resolvedFormat.config, rawRequest?.output) @@ -205,7 +198,7 @@ export function shouldInjectFormatInstruction( ) { return ( formatConfig?.defaultInstruction !== false || - rawRequestConfig?.instructions === true + rawRequestConfig?.instructions ); } From 6f6d0abf1298c655327af99df9ab8b53aa206b49 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 13 Feb 2025 08:06:18 -0500 Subject: [PATCH 3/4] format --- js/ai/src/generate/action.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 4354a0c65..3f66f0707 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -197,8 +197,7 @@ export function shouldInjectFormatInstruction( rawRequestConfig?: z.infer ) { return ( - formatConfig?.defaultInstruction !== false || - rawRequestConfig?.instructions + formatConfig?.defaultInstruction !== false || rawRequestConfig?.instructions ); } From 708792fdca646a3422b03e25016b3fc2ad8c81cb Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 17 Feb 2025 21:16:54 -0500 Subject: [PATCH 4/4] feedback, testing --- js/ai/src/formats/json.ts | 2 +- js/ai/src/formats/types.ts | 2 +- js/ai/src/generate.ts | 6 +- js/ai/src/generate/action.ts | 7 ++- js/genkit/tests/formats_test.ts | 2 +- js/testapps/flow-simple-ai/src/index.ts | 82 +++++++++++++++++++++++++ js/testapps/format-tester/src/index.ts | 1 + 7 files changed, 93 insertions(+), 9 deletions(-) diff --git a/js/ai/src/formats/json.ts b/js/ai/src/formats/json.ts index 24c618995..5d7465546 100644 --- a/js/ai/src/formats/json.ts +++ b/js/ai/src/formats/json.ts @@ -23,7 +23,7 @@ export const jsonFormatter: Formatter = { format: 'json', contentType: 'application/json', constrained: true, - defaultInstruction: false, + defaultInstructions: false, }, handler: (schema) => { let instructions: string | undefined; diff --git a/js/ai/src/formats/types.ts b/js/ai/src/formats/types.ts index 2c447a679..3cd17e3d3 100644 --- a/js/ai/src/formats/types.ts +++ b/js/ai/src/formats/types.ts @@ -24,7 +24,7 @@ export type OutputContentTypes = 'application/json' | 'text/plain'; export interface Formatter { name: string; config: ModelRequest['output'] & { - defaultInstruction?: false; + defaultInstructions?: false; }; handler: (schema?: JSONSchema) => { parseMessage(message: Message): O; diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index a428949f1..8d9a1766f 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -35,7 +35,7 @@ import { } from './formats/index.js'; import { generateHelper, - shouldInjectFormatInstruction, + shouldInjectFormatInstructions, } from './generate/action.js'; import { GenerateResponseChunk } from './generate/chunk.js'; import { GenerateResponse } from './generate/response.js'; @@ -214,7 +214,7 @@ export async function toGenerateRequest( ); const out = { - messages: shouldInjectFormatInstruction( + messages: shouldInjectFormatInstructions( resolvedFormat?.config, options.output ) @@ -225,8 +225,8 @@ export async function toGenerateRequest( tools: tools?.map(toToolDefinition) || [], output: { ...(resolvedFormat?.config || {}), - schema: resolvedSchema, ...options.output, + schema: resolvedSchema, }, } as GenerateRequest; if (!out?.output?.schema) delete out?.output?.schema; diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 3f66f0707..2dc71bf8a 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -174,7 +174,7 @@ function applyFormat( if (resolvedFormat) { if ( - shouldInjectFormatInstruction(resolvedFormat.config, rawRequest?.output) + shouldInjectFormatInstructions(resolvedFormat.config, rawRequest?.output) ) { outRequest.messages = injectInstructions( outRequest.messages, @@ -192,12 +192,13 @@ function applyFormat( return outRequest; } -export function shouldInjectFormatInstruction( +export function shouldInjectFormatInstructions( formatConfig?: Formatter['config'], rawRequestConfig?: z.infer ) { return ( - formatConfig?.defaultInstruction !== false || rawRequestConfig?.instructions + formatConfig?.defaultInstructions !== false || + rawRequestConfig?.instructions ); } diff --git a/js/genkit/tests/formats_test.ts b/js/genkit/tests/formats_test.ts index 774bbd47e..b23ef4967 100644 --- a/js/genkit/tests/formats_test.ts +++ b/js/genkit/tests/formats_test.ts @@ -93,7 +93,7 @@ describe('formats', () => { }); it('lets you define and use a custom output format with simulated constrained generation', async () => { - defineEchoModel(ai, { supports: { constrained: false } }); + defineEchoModel(ai, { supports: { constrained: 'none' } }); const { output } = await ai.generate({ model: 'echoModel', diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 79da26e11..fea7b141f 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -654,3 +654,85 @@ ai.defineFlow('blockingMiddleware', async () => { }); return text; }); + +ai.defineFlow('formatJson', async (input, { sendChunk }) => { + const { output, text } = await ai.generate({ + prompt: `generate an RPG game character of type ${input || 'archer'}`, + output: { + constrained: false, + instructions: true, + schema: z + .object({ + name: z.string(), + weapon: z.string(), + }) + .strict(), + }, + onChunk: (c) => sendChunk(c.output), + }); + return { output, text }; +}); + +ai.defineFlow('formatJsonManualSchema', async (input, { sendChunk }) => { + const { output, text } = await ai.generate({ + prompt: `generate one RPG game character of type ${input || 'archer'} and generated JSON must match this interface + + \`\`\`typescript + interface Character { + name: string; + weapon: string; + } + \`\`\` + `, + output: { + constrained: true, + instructions: false, + schema: z + .object({ + name: z.string(), + weapon: z.string(), + }) + .strict(), + }, + onChunk: (c) => sendChunk(c.output), + }); + return { output, text }; +}); + +ai.defineFlow('testArray', async (input, { sendChunk }) => { + const { output } = await ai.generate({ + prompt: `10 different weapons for ${input}`, + output: { + format: 'array', + schema: z.array(z.string()), + }, + onChunk: (c) => sendChunk(c.output), + }); + return output; +}); + +ai.defineFlow('formatEnum', async (input, { sendChunk }) => { + const { output } = await ai.generate({ + prompt: `classify the denger level of sky diving`, + output: { + format: 'enum', + schema: z.enum(['safe', 'dangerous', 'medium']), + }, + onChunk: (c) => sendChunk(c.output), + }); + return output; +}); + +ai.defineFlow('formatJsonl', async (input, { sendChunk }) => { + const { output } = await ai.generate({ + prompt: `generate 5 randon persons`, + output: { + format: 'jsonl', + schema: z.array( + z.object({ name: z.string(), surname: z.string() }).strict() + ), + }, + onChunk: (c) => sendChunk(c.output), + }); + return output; +}); diff --git a/js/testapps/format-tester/src/index.ts b/js/testapps/format-tester/src/index.ts index 84dcfb590..59716c162 100644 --- a/js/testapps/format-tester/src/index.ts +++ b/js/testapps/format-tester/src/index.ts @@ -154,6 +154,7 @@ if (!models.length) { 'vertexai/gemini-1.5-flash', 'googleai/gemini-1.5-pro', 'googleai/gemini-1.5-flash', + 'googleai/gemini-2.0-flash', ]; }