From fc5bde003292c16ab16b1e21ab16bdb8b5b1a9ad Mon Sep 17 00:00:00 2001 From: Michael Doyle Date: Fri, 16 Aug 2024 14:49:10 -0400 Subject: [PATCH 1/2] Support multiple imagen models in VertexAI --- docs/models.md | 10 ++--- js/plugins/vertexai/src/imagen.ts | 63 +++++++++++++++++++++++++++++-- js/plugins/vertexai/src/index.ts | 14 ++++++- 3 files changed, 76 insertions(+), 11 deletions(-) diff --git a/docs/models.md b/docs/models.md index d6962b17d..75fcf6372 100644 --- a/docs/models.md +++ b/docs/models.md @@ -43,11 +43,11 @@ import { gemini15Flash } from '@genkit-ai/vertexai'; Genkit provides model support through its plugin system. The following plugins are officially supported: -| Plugin | Models | -| ------------------------- | ------------------------------------------------------------------------ | -| [Google Generative AI][1] | Gemini Pro, Gemini Pro Vision | -| [Google Vertex AI][2] | Gemini Pro, Gemini Pro Vision, Gemini 1.5 Flash, Gemini 1.5 Pro, Imagen2 | -| [Ollama][3] | Many local models, including Gemma, Llama 2, Mistral, and more | +| Plugin | Models | +| ------------------------- | --------------------------------------------------------------------------------- | +| [Google Generative AI][1] | Gemini Pro, Gemini Pro Vision | +| [Google Vertex AI][2] | Gemini Pro, Gemini Pro Vision, Gemini 1.5 Flash, Gemini 1.5 Pro, Imagen2, Imagen3 | +| [Ollama][3] | Many local models, including Gemma, Llama 2, Mistral, and more | [1]: plugins/google-genai.md [2]: plugins/vertex-ai.md diff --git a/js/plugins/vertexai/src/imagen.ts b/js/plugins/vertexai/src/imagen.ts index 3739ff4d5..134017464 100644 --- a/js/plugins/vertexai/src/imagen.ts +++ b/js/plugins/vertexai/src/imagen.ts @@ -21,6 +21,7 @@ import { GenerationCommonConfigSchema, getBasicUsageStats, modelRef, + ModelReference, } from '@genkit-ai/ai/model'; import { GoogleAuth } from 'google-auth-library'; import z from 'zod'; @@ -45,6 +46,11 @@ export const imagen2 = modelRef({ name: 'vertexai/imagen2', info: { label: 'Vertex AI - Imagen2', + versions: [ + 'imagegeneration@006', + 'imagegeneration@005', + 'imagegeneration@002', + ], supports: { media: false, multiturn: false, @@ -53,9 +59,50 @@ export const imagen2 = modelRef({ output: ['media'], }, }, + version: 'imagegeneration@006', configSchema: ImagenConfigSchema, }); +export const imagen3 = modelRef({ + name: 'vertexai/imagen3', + info: { + label: 'Vertex AI - Imagen3', + versions: ['imagen-3.0-generate-001'], + supports: { + media: false, + multiturn: false, + tools: false, + systemRole: false, + output: ['media'], + }, + }, + version: 'imagen-3.0-generate-001', + configSchema: ImagenConfigSchema, +}); + +export const imagen3Fast = modelRef({ + name: 'vertexai/imagen3-fast', + info: { + label: 'Vertex AI - Imagen3 Fast', + versions: ['imagen-3.0-fast-generate-001'], + supports: { + media: false, + multiturn: false, + tools: false, + systemRole: false, + output: ['media'], + }, + }, + version: 'imagen-3.0-fast-generate-001', + configSchema: ImagenConfigSchema, +}); + +export const SUPPORTED_IMAGEN_MODELS = { + imagen2: imagen2, + imagen3: imagen3, + 'imagen3-fast': imagen3Fast, +}; + function extractText(request: GenerateRequest) { return request.messages .at(-1)! @@ -106,7 +153,15 @@ interface ImagenInstance { image?: { bytesBase64Encoded: string }; } -export function imagen2Model(client: GoogleAuth, options: PluginOptions) { +export function imagenModel( + name: string, + client: GoogleAuth, + options: PluginOptions +) { + const modelName = `vertexai/${name}`; + const model: ModelReference = SUPPORTED_IMAGEN_MODELS[name]; + if (!model) throw new Error(`Unsupported model: ${name}`); + const predictClients: Record< string, PredictClient @@ -126,7 +181,7 @@ export function imagen2Model(client: GoogleAuth, options: PluginOptions) { ...options, location: requestLocation, }, - 'imagegeneration@005' + request.config?.version || model.version || name ); } return predictClients[requestLocation]; @@ -134,8 +189,8 @@ export function imagen2Model(client: GoogleAuth, options: PluginOptions) { return defineModel( { - name: imagen2.name, - ...imagen2.info, + name: modelName, + ...model.info, configSchema: ImagenConfigSchema, }, async (request) => { diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index 896bec00c..2231e394a 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -55,7 +55,13 @@ import { geminiPro, geminiProVision, } from './gemini.js'; -import { imagen2, imagen2Model } from './imagen.js'; +import { + SUPPORTED_IMAGEN_MODELS, + imagen2, + imagen3, + imagen3Fast, + imagenModel, +} from './imagen.js'; import { SUPPORTED_OPENAI_FORMAT_MODELS, llama3, @@ -94,6 +100,8 @@ export { geminiPro, geminiProVision, imagen2, + imagen3, + imagen3Fast, llama3, llama31, textEmbedding004, @@ -175,7 +183,9 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin( : []; const models = [ - imagen2Model(authClient, { projectId, location }), + ...Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) => + imagenModel(name, authClient, { projectId, location }) + ), ...Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => geminiModel(name, vertexClientFactory, { projectId, location }) ), From 88e77b9af1212a01847ccfda43a54cb2ea672d8a Mon Sep 17 00:00:00 2001 From: Michael Doyle Date: Fri, 16 Aug 2024 16:43:25 -0400 Subject: [PATCH 2/2] Additional schema params --- js/plugins/vertexai/src/imagen.ts | 41 +++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/js/plugins/vertexai/src/imagen.ts b/js/plugins/vertexai/src/imagen.ts index 134017464..8c9e41c56 100644 --- a/js/plugins/vertexai/src/imagen.ts +++ b/js/plugins/vertexai/src/imagen.ts @@ -34,23 +34,40 @@ const ImagenConfigSchema = GenerationCommonConfigSchema.extend({ .enum(['auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN']) .optional(), /** Desired aspect ratio of output image. */ - aspectRatio: z.enum(['1:1', '9:16', '16:9']).optional(), - /** A negative prompt to help generate the images. For example: "animals" (removes animals), "blurry" (makes the image clearer), "text" (removes text), or "cropped" (removes cropped images). */ + aspectRatio: z.enum(['1:1', '9:16', '16:9', '3:4', '4:3']).optional(), + /** + * A negative prompt to help generate the images. For example: "animals" + * (removes animals), "blurry" (makes the image clearer), "text" (removes + * text), or "cropped" (removes cropped images). + **/ negativePrompt: z.string().optional(), - /** Any non-negative integer you provide to make output images deterministic. Providing the same seed number always results in the same output images. Accepted integer values: 1 - 2147483647. */ + /** + * Any non-negative integer you provide to make output images deterministic. + * Providing the same seed number always results in the same output images. + * Accepted integer values: 1 - 2147483647. + **/ seed: z.number().optional(), + /** Your GCP project's region. e.g.) us-central1, europe-west2, etc. **/ location: z.string().optional(), + /** Allow generation of people by the model. */ + personGeneration: z + .enum(['dont_allow', 'allow_adult', 'allow_all']) + .optional(), + /** Adds a filter level to safety filtering. */ + safetySetting: z + .enum(['block_most', 'block_some', 'block_few', 'block_fewest']) + .optional(), + /** Add an invisible watermark to the generated images. */ + addWatermark: z.boolean().optional(), + /** Cloud Storage URI to store the generated images. **/ + storageUri: z.string().optional(), }); export const imagen2 = modelRef({ name: 'vertexai/imagen2', info: { label: 'Vertex AI - Imagen2', - versions: [ - 'imagegeneration@006', - 'imagegeneration@005', - 'imagegeneration@002', - ], + versions: ['imagegeneration@006', 'imagegeneration@005'], supports: { media: false, multiturn: false, @@ -116,6 +133,10 @@ interface ImagenParameters { negativePrompt?: string; seed?: number; language?: string; + personGeneration?: string; + safetySetting?: string; + addWatermark?: boolean; + storageUri?: string; } function toParameters( @@ -127,6 +148,10 @@ function toParameters( negativePrompt: request.config?.negativePrompt, seed: request.config?.seed, language: request.config?.language, + personGeneration: request.config?.personGeneration, + safetySetting: request.config?.safetySetting, + addWatermark: request.config?.addWatermark, + storageUri: request.config?.storageUri, }; for (const k in out) {