Adds defineSchema to support registered schema use in Dotprompt. #503

Merged 4 commits on Jun 28, 2024
46 changes: 45 additions & 1 deletion docs/dotprompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ conditional portions to your prompt or iterate through structured content. The
file format utilizes YAML frontmatter to provide metadata for a prompt inline
with the template.

## Defining Input/Output Schemas with Picoschema
## Defining Input/Output Schemas

Dotprompt includes a compact, YAML-optimized schema definition format called
Picoschema to make it easy to define the most important attributes of a schema
Expand Down Expand Up @@ -142,6 +142,50 @@ output:
minimum: 20
```

### Leveraging Reusable Schemas

In addition to directly defining schemas in the `.prompt` file, you can reference
a schema registered with `defineSchema` by name. To register a schema:

```ts
import { defineSchema } from '@genkit-ai/core';
import { z } from 'zod';

const MySchema = defineSchema(
  'MySchema',
  z.object({
    field1: z.string(),
    field2: z.number(),
  })
);
```

Within your prompt, you can provide the name of the registered schema:

```yaml
# myPrompt.prompt
---
model: vertexai/gemini-1.5-flash
output:
  schema: MySchema
---
```

The Dotprompt library will automatically resolve the name to the underlying
registered Zod schema. You can then utilize the schema to strongly type the
output of a Dotprompt:

```ts
import { prompt } from "@genkit-ai/dotprompt";

const myPrompt = await prompt("myPrompt");

const result = await myPrompt.generate<typeof MySchema>({...});

// now strongly typed as MySchema
result.output();
```
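
One caveat when reading the example above: as far as the `prompt.ts` changes in this PR show, the type argument passed to `generate` only shapes the static return type. It is not tied to the schema name in the prompt's frontmatter, so keeping the two in agreement appears to be the caller's responsibility. The self-contained sketch below models that behaviour without Genkit or Zod. `Schema`, `Infer`, `TitleSchema`, and `fakeGenerate` are hypothetical names used for illustration only, not Genkit APIs.

```typescript
// Hypothetical stand-in for a Zod schema: parses unknown data into T.
interface Schema<T> {
  parse(data: unknown): T;
}

// Analogue of z.infer: recover T from a Schema<T>.
type Infer<S> = S extends Schema<infer T> ? T : never;

// A small hand-written schema for objects of the form { title: string }.
const TitleSchema: Schema<{ title: string }> = {
  parse(data: unknown): { title: string } {
    if (
      typeof data === 'object' &&
      data !== null &&
      typeof (data as { title?: unknown }).title === 'string'
    ) {
      return { title: (data as { title: string }).title };
    }
    throw new Error('expected an object with a string "title"');
  },
};

// Hypothetical analogue of generate<O>(): the type argument only affects the
// static return type. The raw value is cast, not validated against O.
function fakeGenerate<O extends Schema<unknown>>(
  raw: unknown
): { output(): Infer<O> } {
  return { output: () => raw as Infer<O> };
}

// With a matching value, the output is statically typed as { title: string }.
const typed = fakeGenerate<typeof TitleSchema>({ title: 'Pancakes' });
const title: string = typed.output().title;

// A mismatched value still compiles, because the cast checks nothing at runtime.
const unchecked = fakeGenerate<typeof TitleSchema>({ name: 'oops' });

// Parsing explicitly is what actually detects the mismatch.
let rejected = false;
try {
  TitleSchema.parse(unchecked.output());
} catch (err) {
  rejected = true;
}
```

In this sketch, a wrong type argument is a silent error until something parses the value, which is a reason to pass the same constant that was given to `defineSchema`.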

## Overriding Prompt Metadata

While `.prompt` files allow you to embed metadata such as model configuration in
Expand Down
1 change: 1 addition & 0 deletions js/core/src/index.ts
Expand Up @@ -21,4 +21,5 @@ export * from './action.js';
export * from './config.js';
export { GenkitError } from './error.js';
export * from './flowTypes.js';
export { defineJsonSchema, defineSchema } from './schema.js';
export * from './telemetryTypes.js';
22 changes: 22 additions & 0 deletions js/core/src/registry.ts
Expand Up @@ -20,6 +20,7 @@ import { FlowStateStore } from './flowTypes.js';
import { logger } from './logging.js';
import { PluginProvider } from './plugin.js';
import { startReflectionApi } from './reflectionApi.js';
import { JSONSchema } from './schema.js';
import { TraceStore } from './tracing/types.js';

export type AsyncProvider<T> = () => Promise<T>;
Expand All @@ -28,6 +29,7 @@ const ACTIONS_BY_ID = 'genkit__ACTIONS_BY_ID';
const TRACE_STORES_BY_ENV = 'genkit__TRACE_STORES_BY_ENV';
const FLOW_STATE_STORES_BY_ENV = 'genkit__FLOW_STATE_STORES_BY_ENV';
const PLUGINS_BY_NAME = 'genkit__PLUGINS_BY_NAME';
const SCHEMAS_BY_NAME = 'genkit__SCHEMAS_BY_NAME';

function actionsById(): Record<string, Action<z.ZodTypeAny, z.ZodTypeAny>> {
if (global[ACTIONS_BY_ID] === undefined) {
Expand All @@ -53,6 +55,15 @@ function pluginsByName(): Record<string, PluginProvider> {
}
return global[PLUGINS_BY_NAME];
}
function schemasByName(): Record<
  string,
  { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema }
> {
  if (global[SCHEMAS_BY_NAME] === undefined) {
    global[SCHEMAS_BY_NAME] = {};
  }
  return global[SCHEMAS_BY_NAME];
}

/**
* Type of a runnable action.
Expand Down Expand Up @@ -211,6 +222,17 @@ export async function initializePlugin(name: string) {
return undefined;
}

export function registerSchema(
  name: string,
  data: { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema }
) {
  schemasByName()[name] = data;
}

export function lookupSchema(name: string) {
  return schemasByName()[name];
}

/**
* Development mode only. Starts a Reflection API so that the actions can be called by the Runner.
*/
Expand Down
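
The registry additions above follow the same pattern as the existing action and plugin stores: a map created lazily under a fixed key on the global object. The sketch below reproduces that pattern in isolation to make two consequences visible: registering a name twice replaces the earlier entry, and looking up an unregistered name returns `undefined`. It is a minimal model, not the Genkit source. The key `example__SCHEMAS_BY_NAME` and the `SchemaEntry` alias are assumptions made for this example, and `unknown` stands in for the Zod and JSON Schema types.

```typescript
// Entry shape as declared in the diff: a Zod schema, a JSON Schema, or neither.
// `unknown` stands in for both library types in this sketch.
type SchemaEntry = { schema?: unknown; jsonSchema?: unknown };

// Hypothetical key, chosen so the sketch never touches a real Genkit registry.
const SCHEMAS_BY_NAME = 'example__SCHEMAS_BY_NAME';

// Lazily create the name -> entry map on the global object, so every caller
// in the process shares one registry.
function schemasByName(): Record<string, SchemaEntry> {
  const g = globalThis as unknown as Record<
    string,
    Record<string, SchemaEntry> | undefined
  >;
  let store = g[SCHEMAS_BY_NAME];
  if (store === undefined) {
    store = {};
    g[SCHEMAS_BY_NAME] = store;
  }
  return store;
}

function registerSchema(name: string, data: SchemaEntry): void {
  // Plain assignment: a second registration under the same name wins.
  schemasByName()[name] = data;
}

function lookupSchema(name: string): SchemaEntry | undefined {
  return schemasByName()[name];
}

registerSchema('MyOutput', { jsonSchema: { type: 'boolean' } });
registerSchema('MyOutput', { jsonSchema: { type: 'string' } });

const found = lookupSchema('MyOutput');
// 'string': the second registration replaced the first.
const foundType = (found?.jsonSchema as { type?: string } | undefined)?.type;

// undefined: nothing was registered under this name.
const missing = lookupSchema('NotRegistered');
```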
14 changes: 14 additions & 0 deletions js/core/src/schema.ts
Expand Up @@ -19,6 +19,7 @@ import addFormats from 'ajv-formats';
import { z } from 'zod';
import zodToJsonSchema from 'zod-to-json-schema';
import { GenkitError } from './error.js';
import { registerSchema } from './registry.js';
const ajv = new Ajv();
addFormats(ajv);

Expand Down Expand Up @@ -109,3 +110,16 @@ export function parseSchema<T = unknown>(
if (!valid) throw new ValidationError({ data, errors: errors!, schema });
return data as T;
}

export function defineSchema<T extends z.ZodTypeAny>(
  name: string,
  schema: T
): T {
  registerSchema(name, { schema });
  return schema;
}

export function defineJsonSchema(name: string, jsonSchema: JSONSchema) {
  registerSchema(name, { jsonSchema });
  return jsonSchema;
}
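
`defineSchema` is generic over `T` and returns its argument, which is what lets callers declare and register a schema in one expression without widening its static type. A minimal sketch of that register-and-return pattern, using a plain `Map` and ordinary objects instead of the Genkit registry and Zod (`registry`, `defineNamed`, and `Limits` are hypothetical names for this example):

```typescript
// Hypothetical in-memory registry standing in for the global schema store.
const registry = new Map<string, unknown>();

// Stores the value under `name` and returns it with its static type T intact.
function defineNamed<T>(name: string, value: T): T {
  registry.set(name, value);
  return value;
}

const Limits = defineNamed('Limits', { min: 1, max: 10 });

// `Limits` keeps the inferred type { min: number; max: number }, so this
// arithmetic type-checks without a cast.
const span: number = Limits.max - Limits.min;

// The registry holds the very same object that was returned, not a copy.
const sameObject = registry.get('Limits') === Limits;
```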
31 changes: 24 additions & 7 deletions js/plugins/dotprompt/src/metadata.ts
Expand Up @@ -24,6 +24,7 @@ import {
ModelArgument,
} from '@genkit-ai/ai/model';
import { ToolArgument } from '@genkit-ai/ai/tool';
import { lookupSchema } from '@genkit-ai/core/registry';
import { JSONSchema, parseSchema, toJsonSchema } from '@genkit-ai/core/schema';
import z from 'zod';
import { picoschema } from './picoschema.js';
Expand Down Expand Up @@ -92,8 +93,8 @@ export const PromptFrontmatterSchema = z.object({
config: GenerationCommonConfigSchema.passthrough().optional(),
input: z
.object({
schema: z.unknown(),
default: z.any(),
schema: z.unknown(),
})
.optional(),
output: z
Expand Down Expand Up @@ -122,21 +123,37 @@ function stripUndefinedOrNull(obj: any) {
return obj;
}

function fmSchemaToSchema(fmSchema: any) {
  if (!fmSchema) return {};
  if (typeof fmSchema === 'string') return lookupSchema(fmSchema);
  return { jsonSchema: picoschema(fmSchema) };
}

export function toMetadata(attributes: unknown): Partial<PromptMetadata> {
  const fm = parseSchema<z.infer<typeof PromptFrontmatterSchema>>(attributes, {
    schema: PromptFrontmatterSchema,
  });

  let input: PromptMetadata['input'] | undefined;
  if (fm.input) {
    input = { default: fm.input.default, ...fmSchemaToSchema(fm.input.schema) };
  }

  let output: PromptMetadata['output'] | undefined;
  if (fm.output) {
    output = {
      format: fm.output.format,
      ...fmSchemaToSchema(fm.output.schema),
    };
  }

return stripUndefinedOrNull({
name: fm.name,
variant: fm.variant,
model: fm.model,
config: fm.config,
input: fm.input
? { default: fm.input.default, jsonSchema: picoschema(fm.input.schema) }
: undefined,
output: fm.output
? { format: fm.output.format, jsonSchema: picoschema(fm.output.schema) }
: undefined,
input,
output,
metadata: fm.metadata,
tools: fm.tools,
candidates: fm.candidates,
Expand Down
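
The branching in `fmSchemaToSchema` can be sketched on its own: a string is treated as a registered name, anything else truthy as an inline Picoschema definition. The sketch also shows what appears to happen with an unregistered name. The lookup returns `undefined`, and spreading `undefined` into an object literal adds no keys, so the resulting metadata simply has no schema rather than raising an error. This is a model under stated assumptions, not the plugin code: `compileInline` is a hypothetical stand-in for the Picoschema compiler, and `registry`, `resolveSchema`, and `OutputMeta` are names invented for the example.

```typescript
type JsonSchema = Record<string, unknown>;
type SchemaEntry = { schema?: unknown; jsonSchema?: JsonSchema };

interface OutputMeta {
  format: string;
  schema?: unknown;
  jsonSchema?: JsonSchema;
}

// Hypothetical registry with a single registered name.
const registry: Record<string, SchemaEntry> = {
  Recipe: { jsonSchema: { type: 'object' } },
};

// Hypothetical stand-in for the Picoschema compiler. It only wraps its input
// so that the inline branch is observable in the result.
function compileInline(inline: unknown): JsonSchema {
  return { type: 'object', source: inline };
}

// Mirrors the three branches of fmSchemaToSchema in the diff.
function resolveSchema(fmSchema: unknown): SchemaEntry | undefined {
  if (!fmSchema) return {};
  if (typeof fmSchema === 'string') return registry[fmSchema];
  return { jsonSchema: compileInline(fmSchema) };
}

// Registered name: the stored entry is spread into the metadata.
const byName: OutputMeta = { format: 'json', ...resolveSchema('Recipe') };

// Inline definition: compiled on the spot.
const inline: OutputMeta = {
  format: 'json',
  ...resolveSchema({ food: 'string' }),
};

// Misspelled name: the lookup yields undefined and the spread adds nothing.
const unknownName: OutputMeta = { format: 'json', ...resolveSchema('Recipee') };
const hasAnySchema = 'schema' in unknownName || 'jsonSchema' in unknownName;
```

If silent fallthrough is undesirable, a caller could check the lookup result and throw before building the metadata. That check is not part of this PR.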
20 changes: 11 additions & 9 deletions js/plugins/dotprompt/src/prompt.ts
Expand Up @@ -89,7 +89,7 @@ export class Dotprompt<Variables = unknown> implements PromptMetadata {
throw new GenkitError({
source: 'Dotprompt',
status: 'INVALID_ARGUMENT',
message: `Error parsing YAML frontmatter of '${name}' prompt: ${e.message}`,
message: `Error parsing YAML frontmatter of '${name}' prompt: ${e.stack}`,
});
}
}
Expand Down Expand Up @@ -166,9 +166,9 @@ export class Dotprompt<Variables = unknown> implements PromptMetadata {
);
}

private _generateOptions(
private _generateOptions<O extends z.ZodTypeAny = z.ZodTypeAny>(
options: PromptGenerateOptions<Variables>
): GenerateOptions {
): GenerateOptions<z.ZodTypeAny, O> {
const messages = this.renderMessages(options.input, {
history: options.history,
context: options.context,
Expand All @@ -188,17 +188,19 @@ export class Dotprompt<Variables = unknown> implements PromptMetadata {
tools: (options.tools || []).concat(this.tools || []),
streamingCallback: options.streamingCallback,
returnToolRequests: options.returnToolRequests,
};
} as GenerateOptions<z.ZodTypeAny, O>;
}

render(opt: PromptGenerateOptions<Variables>): GenerateOptions {
return this._generateOptions(opt);
render<O extends z.ZodTypeAny = z.ZodTypeAny>(
opt: PromptGenerateOptions<Variables>
): GenerateOptions<z.ZodTypeAny, O> {
return this._generateOptions<O>(opt);
}

async generate(
async generate<O extends z.ZodTypeAny = z.ZodTypeAny>(
opt: PromptGenerateOptions<Variables>
): Promise<GenerateResponse> {
return generate(this.render(opt));
): Promise<GenerateResponse<z.infer<O>>> {
return generate<z.ZodTypeAny, O>(this.render<O>(opt));
}

async generateStream(
Expand Down
19 changes: 19 additions & 0 deletions js/plugins/dotprompt/tests/prompt_test.ts
Expand Up @@ -21,6 +21,7 @@ import { defineModel } from '@genkit-ai/ai/model';
import { toJsonSchema, ValidationError } from '@genkit-ai/core/schema';
import z from 'zod';
import { registerPluginProvider } from '../../../core/src/registry.js';
import { defineJsonSchema, defineSchema } from '../../../core/src/schema.js';
import { defineDotprompt, Dotprompt, prompt } from '../src/index.js';
import { PromptMetadata } from '../src/metadata.js';

Expand Down Expand Up @@ -200,6 +201,24 @@ output:
},
});
});

  it('should use registered schemas', () => {
    const MyInput = defineSchema('MyInput', z.number());
    defineJsonSchema('MyOutput', { type: 'boolean' });

    const p = Dotprompt.parse(
      'example2',
      `---
input:
  schema: MyInput
output:
  schema: MyOutput
---`
    );

    assert.deepEqual(p.input, { schema: MyInput });
    assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } });
  });
});

describe('defineDotprompt', () => {
Expand Down
1 change: 1 addition & 0 deletions js/testapps/prompt-file/package.json
Expand Up @@ -10,6 +10,7 @@
"@genkit-ai/googleai": "workspace:*",
"zod": "^3.22.4"
},
"main": "lib/index.js",
"scripts": {
"build": "tsc",
"test": "echo \"Error: no test specified\" && exit 1"
Expand Down
7 changes: 1 addition & 6 deletions js/testapps/prompt-file/prompts/recipe.prompt
Expand Up @@ -4,12 +4,7 @@ input:
  schema:
    food: string
output:
  schema:
    title: string, recipe title
    ingredients(array):
      name: string
      quantity: string
    steps(array, the steps required to complete the recipe): string
  schema: Recipe
---

You are a chef famous for making creative recipes that can be prepared in 45 minutes or less.
Expand Down
27 changes: 24 additions & 3 deletions js/testapps/prompt-file/src/index.ts
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { configureGenkit } from '@genkit-ai/core';
import { configureGenkit, defineSchema } from '@genkit-ai/core';
import { dotprompt, prompt } from '@genkit-ai/dotprompt';
import { defineFlow } from '@genkit-ai/flow';
import { googleAI } from '@genkit-ai/googleai';
Expand All @@ -26,6 +26,24 @@ configureGenkit({
logLevel: 'debug',
});

/*
  title: string, recipe title
  ingredients(array):
    name: string
    quantity: string
  steps(array, the steps required to complete the recipe): string
*/
const RecipeSchema = defineSchema(
  'Recipe',
  z.object({
    title: z.string().describe('recipe title'),
    ingredients: z.array(z.object({ name: z.string(), quantity: z.string() })),
    steps: z
      .array(z.string())
      .describe('the steps required to complete the recipe'),
  })
);

// This example demonstrates using prompt files in a flow
// Load the prompt file during initialization.
// If it fails, due to the prompt file being invalid, the process will crash,
Expand All @@ -38,9 +56,12 @@ prompt('recipe').then((recipePrompt) => {
inputSchema: z.object({
food: z.string(),
}),
outputSchema: z.any(),
outputSchema: RecipeSchema,
},
async (input) => (await recipePrompt.generate({ input: input })).output()
async (input) =>
(
await recipePrompt.generate<typeof RecipeSchema>({ input: input })
).output()!
);
});

Expand Down