[JS] Support multiple imagen models in VertexAI #796

Merged · 2 commits · Aug 28, 2024
10 changes: 5 additions & 5 deletions docs/models.md
@@ -43,11 +43,11 @@ import { gemini15Flash } from '@genkit-ai/vertexai';
Genkit provides model support through its plugin system. The following plugins
are officially supported:

| Plugin | Models |
| ------------------------- | ------------------------------------------------------------------------ |
| [Google Generative AI][1] | Gemini Pro, Gemini Pro Vision |
| [Google Vertex AI][2] | Gemini Pro, Gemini Pro Vision, Gemini 1.5 Flash, Gemini 1.5 Pro, Imagen2 |
| [Ollama][3] | Many local models, including Gemma, Llama 2, Mistral, and more |
| Plugin | Models |
| ------------------------- | --------------------------------------------------------------------------------- |
| [Google Generative AI][1] | Gemini Pro, Gemini Pro Vision |
| [Google Vertex AI][2] | Gemini Pro, Gemini Pro Vision, Gemini 1.5 Flash, Gemini 1.5 Pro, Imagen2, Imagen3 |
| [Ollama][3] | Many local models, including Gemma, Llama 2, Mistral, and more |

[1]: plugins/google-genai.md
[2]: plugins/vertex-ai.md
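For context on the table above: a minimal sketch of generating an image with the newly listed imagen3 model reference, assuming the pre-1.0 Genkit JS API (configureGenkit / generate). The entry points, plugin options, and prompt text here are illustrative assumptions, not taken from this PR.

import { configureGenkit } from '@genkit-ai/core';
import { generate } from '@genkit-ai/ai';
import { vertexAI, imagen3 } from '@genkit-ai/vertexai';

configureGenkit({
  // Assumed plugin options; projectId may also be picked up from the environment.
  plugins: [vertexAI({ location: 'us-central1' })],
});

async function main() {
  const response = await generate({
    model: imagen3,
    prompt: 'A watercolor painting of a lighthouse at dawn', // illustrative prompt
    output: { format: 'media' },
  });
  // The generated image comes back as a media part (data URL or URI).
  console.log(response.media());
}

main();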
94 changes: 87 additions & 7 deletions js/plugins/vertexai/src/imagen.ts
@@ -21,6 +21,7 @@ import {
GenerationCommonConfigSchema,
getBasicUsageStats,
modelRef,
ModelReference,
} from '@genkit-ai/ai/model';
import { GoogleAuth } from 'google-auth-library';
import z from 'zod';
@@ -33,18 +34,57 @@ const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
.enum(['auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN'])
.optional(),
/** Desired aspect ratio of output image. */
aspectRatio: z.enum(['1:1', '9:16', '16:9']).optional(),
/** A negative prompt to help generate the images. For example: "animals" (removes animals), "blurry" (makes the image clearer), "text" (removes text), or "cropped" (removes cropped images). */
aspectRatio: z.enum(['1:1', '9:16', '16:9', '3:4', '4:3']).optional(),
/**
* A negative prompt to help generate the images. For example: "animals"
* (removes animals), "blurry" (makes the image clearer), "text" (removes
* text), or "cropped" (removes cropped images).
**/
negativePrompt: z.string().optional(),
/** Any non-negative integer you provide to make output images deterministic. Providing the same seed number always results in the same output images. Accepted integer values: 1 - 2147483647. */
/**
* Any non-negative integer you provide to make output images deterministic.
* Providing the same seed number always results in the same output images.
* Accepted integer values: 1 - 2147483647.
**/
seed: z.number().optional(),
  /** Your GCP project's region, e.g. us-central1, europe-west2, etc. **/
location: z.string().optional(),
/** Allow generation of people by the model. */
personGeneration: z
.enum(['dont_allow', 'allow_adult', 'allow_all'])
.optional(),
/** Adds a filter level to safety filtering. */
safetySetting: z
.enum(['block_most', 'block_some', 'block_few', 'block_fewest'])
.optional(),
/** Add an invisible watermark to the generated images. */
addWatermark: z.boolean().optional(),
/** Cloud Storage URI to store the generated images. **/
storageUri: z.string().optional(),
});

export const imagen2 = modelRef({
name: 'vertexai/imagen2',
info: {
label: 'Vertex AI - Imagen2',
versions: ['imagegeneration@006', 'imagegeneration@005'],
supports: {
media: false,
multiturn: false,
tools: false,
systemRole: false,
output: ['media'],
},
},
version: 'imagegeneration@006',
configSchema: ImagenConfigSchema,
});

export const imagen3 = modelRef({
name: 'vertexai/imagen3',
info: {
label: 'Vertex AI - Imagen3',
versions: ['imagen-3.0-generate-001'],
supports: {
media: false,
multiturn: false,
@@ -53,9 +93,33 @@ export const imagen2 = modelRef({
output: ['media'],
},
},
version: 'imagen-3.0-generate-001',
configSchema: ImagenConfigSchema,
});

export const imagen3Fast = modelRef({
name: 'vertexai/imagen3-fast',
info: {
label: 'Vertex AI - Imagen3 Fast',
versions: ['imagen-3.0-fast-generate-001'],
supports: {
media: false,
multiturn: false,
tools: false,
systemRole: false,
output: ['media'],
},
},
version: 'imagen-3.0-fast-generate-001',
configSchema: ImagenConfigSchema,
});

export const SUPPORTED_IMAGEN_MODELS = {
imagen2: imagen2,
imagen3: imagen3,
'imagen3-fast': imagen3Fast,
};

function extractText(request: GenerateRequest) {
return request.messages
.at(-1)!
@@ -69,6 +133,10 @@ interface ImagenParameters {
negativePrompt?: string;
seed?: number;
language?: string;
personGeneration?: string;
safetySetting?: string;
addWatermark?: boolean;
storageUri?: string;
}

function toParameters(
@@ -80,6 +148,10 @@ function toParameters(
negativePrompt: request.config?.negativePrompt,
seed: request.config?.seed,
language: request.config?.language,
personGeneration: request.config?.personGeneration,
safetySetting: request.config?.safetySetting,
addWatermark: request.config?.addWatermark,
storageUri: request.config?.storageUri,
};

for (const k in out) {
@@ -106,7 +178,15 @@ interface ImagenInstance {
image?: { bytesBase64Encoded: string };
}

export function imagen2Model(client: GoogleAuth, options: PluginOptions) {
export function imagenModel(
name: string,
client: GoogleAuth,
options: PluginOptions
) {
const modelName = `vertexai/${name}`;
const model: ModelReference<z.ZodTypeAny> = SUPPORTED_IMAGEN_MODELS[name];
if (!model) throw new Error(`Unsupported model: ${name}`);

const predictClients: Record<
string,
PredictClient<ImagenInstance, ImagenPrediction, ImagenParameters>
@@ -126,16 +206,16 @@ export function imagen2Model(client: GoogleAuth, options: PluginOptions) {
...options,
location: requestLocation,
},
'imagegeneration@005'
request.config?.version || model.version || name
);
}
return predictClients[requestLocation];
};

return defineModel(
{
name: imagen2.name,
...imagen2.info,
name: modelName,
...model.info,
configSchema: ImagenConfigSchema,
},
async (request) => {
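The refactor above replaces the single imagen2Model factory with a name-keyed imagenModel factory. A self-contained sketch of the two patterns it relies on, a per-location client cache and a version fallback chain, using hypothetical stand-in types (RegionClient, makeRegionClient) rather than the plugin's real PredictClient:

// Hypothetical stand-in for the plugin's predict client.
interface RegionClient {
  predict(instances: unknown[], parameters: unknown): Promise<unknown>;
}

const clients: Record<string, RegionClient> = {};

function clientFor(
  location: string,
  makeRegionClient: (loc: string) => RegionClient
): RegionClient {
  // Build one client per region on first use, then reuse it for later requests.
  if (!clients[location]) {
    clients[location] = makeRegionClient(location);
  }
  return clients[location];
}

// Version resolution mirrors the fallback used when the predict client is created:
// the per-request config wins, then the modelRef's default version, then the bare name.
function resolveVersion(
  requestVersion: string | undefined,
  modelVersion: string | undefined,
  name: string
): string {
  return requestVersion || modelVersion || name;
}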
14 changes: 12 additions & 2 deletions js/plugins/vertexai/src/index.ts
@@ -55,7 +55,13 @@ import {
geminiPro,
geminiProVision,
} from './gemini.js';
import { imagen2, imagen2Model } from './imagen.js';
import {
SUPPORTED_IMAGEN_MODELS,
imagen2,
imagen3,
imagen3Fast,
imagenModel,
} from './imagen.js';
import {
SUPPORTED_OPENAI_FORMAT_MODELS,
llama3,
@@ -94,6 +100,8 @@ export {
geminiPro,
geminiProVision,
imagen2,
imagen3,
imagen3Fast,
llama3,
llama31,
textEmbedding004,
@@ -175,7 +183,9 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin(
: [];

const models = [
imagen2Model(authClient, { projectId, location }),
...Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) =>
imagenModel(name, authClient, { projectId, location })
),
...Object.keys(SUPPORTED_GEMINI_MODELS).map((name) =>
geminiModel(name, vertexClientFactory, { projectId, location })
),
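With the registration loop above, each key of SUPPORTED_IMAGEN_MODELS becomes a vertexai/<name> model. A hedged sketch of a per-request override, reusing the imports assumed in the earlier sketch; the config field names follow ImagenConfigSchema from this PR, while the region, version string, and prompt are illustrative.

import { generate } from '@genkit-ai/ai';
import { imagen3Fast } from '@genkit-ai/vertexai';

async function generateFast() {
  const response = await generate({
    model: imagen3Fast,
    prompt: 'Studio photo of a ceramic mug on a wooden table', // illustrative prompt
    config: {
      location: 'europe-west2', // routed through the per-location client cache
      version: 'imagen-3.0-fast-generate-001', // optional; falls back to model.version
      aspectRatio: '3:4',
      negativePrompt: 'text',
      addWatermark: false,
    },
    output: { format: 'media' },
  });
  return response.media();
}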