Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(js/plugins/vertexai): instruduced gemini model ref helper and ability to register versions #1668

Merged
merged 7 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions js/plugins/vertexai/src/common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,19 @@ function parseFirebaseProjectId(): string | undefined {
}
}

/** @hidden */
export function __setFakeDerivedParams(params: any) {
__fake_getDerivedParams = params;
}
let __fake_getDerivedParams: any;

export async function getDerivedParams(
options?: PluginOptions
): Promise<DerivedParams> {
if (__fake_getDerivedParams) {
return __fake_getDerivedParams;
}

let authOptions = options?.googleAuth;
let authClient: GoogleAuth;
const providedProjectId =
Expand Down
9 changes: 8 additions & 1 deletion js/plugins/vertexai/src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
* limitations under the License.
*/

import { ModelReference } from 'genkit';
import { GoogleAuthOptions } from 'google-auth-library';
import { GeminiConfigSchema } from '../gemini';

/** Common options for Vertex AI plugin configuration */
export interface CommonPluginOptions {
Expand All @@ -27,4 +29,9 @@ export interface CommonPluginOptions {
}

/** Combined plugin options, extending common options with subplugin-specific options */
export interface PluginOptions extends CommonPluginOptions {}
export interface PluginOptions extends CommonPluginOptions {
models?: (
| ModelReference</** @ignore */ typeof GeminiConfigSchema>
| string
)[];
}
130 changes: 115 additions & 15 deletions js/plugins/vertexai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import {
MediaPart,
MessageData,
ModelAction,
ModelInfo,
ModelMiddleware,
ModelReference,
Part,
Expand Down Expand Up @@ -166,6 +167,69 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
.optional(),
});

/**
* Known model names, to allow code completion for convenience. Allows other model names.
*/
export type GeminiVersionString =
| keyof typeof SUPPORTED_GEMINI_MODELS
| (string & {});

/**
* Returns a reference to a model that can be used in generate calls.
*
* ```js
* await ai.generate({
* prompt: 'hi',
* model: gemini('gemini-1.5-flash')
* });
* ```
*/
export function gemini(
version: GeminiVersionString,
options: GeminiConfig = {}
): ModelReference<typeof GeminiConfigSchema> {
const nearestModel = nearestGeminiModelRef(version);
return modelRef({
name: `vertexai/${version}`,
config: options,
configSchema: GeminiConfigSchema,
info: {
...nearestModel.info,
// If exact suffix match for a known model, use its label, otherwise create a new label
label: nearestModel.name.endsWith(version)
? nearestModel.info?.label
: `Vertex AI - ${version}`,
},
});
}

function nearestGeminiModelRef(
version: GeminiVersionString,
options: GeminiConfig = {}
): ModelReference<typeof GeminiConfigSchema> {
const matchingKey = longestMatchingPrefix(
version,
Object.keys(SUPPORTED_GEMINI_MODELS)
);
if (matchingKey) {
return SUPPORTED_GEMINI_MODELS[matchingKey].withConfig({
...options,
version,
});
}
return GENERIC_GEMINI_MODEL.withConfig({ ...options, version });
}

function longestMatchingPrefix(version: string, potentialMatches: string[]) {
return potentialMatches
.filter((p) => version.startsWith(p))
.reduce(
(longest, current) =>
current.length > longest.length ? current : longest,
''
);
}

/**
* Gemini model configuration options.
*
Expand Down Expand Up @@ -268,6 +332,21 @@ export const gemini20FlashExp = modelRef({
configSchema: GeminiConfigSchema,
});

export const GENERIC_GEMINI_MODEL = modelRef({
name: 'vertexai/gemini',
configSchema: GeminiConfigSchema,
info: {
label: 'Google Gemini',
supports: {
multiturn: true,
media: true,
tools: true,
toolChoice: true,
systemRole: true,
},
},
});

export const SUPPORTED_V1_MODELS = {
'gemini-1.0-pro': gemini10Pro,
};
Expand All @@ -281,19 +360,19 @@ export const SUPPORTED_V15_MODELS = {
export const SUPPORTED_GEMINI_MODELS = {
...SUPPORTED_V1_MODELS,
...SUPPORTED_V15_MODELS,
};
} as const;

function toGeminiRole(
role: MessageData['role'],
model?: ModelReference<z.ZodTypeAny>
modelInfo?: ModelInfo
): string {
switch (role) {
case 'user':
return 'user';
case 'model':
return 'model';
case 'system':
if (model && SUPPORTED_V15_MODELS[model.name]) {
if (modelInfo && modelInfo.supports?.systemRole) {
// We should have already pulled out the supported system messages,
// anything remaining is unsupported; throw an error.
throw new Error(
Expand Down Expand Up @@ -387,10 +466,10 @@ export function toGeminiSystemInstruction(message: MessageData): Content {

export function toGeminiMessage(
message: MessageData,
model?: ModelReference<z.ZodTypeAny>
modelInfo?: ModelInfo
): Content {
return {
role: toGeminiRole(message.role, model),
role: toGeminiRole(message.role, modelInfo),
parts: message.content.map(toGeminiPart),
};
}
Expand Down Expand Up @@ -581,7 +660,7 @@ export function cleanSchema(schema: JSONSchema): JSONSchema {
/**
* Define a Vertex AI Gemini model.
*/
export function defineGeminiModel(
export function defineGeminiKnownModel(
ai: Genkit,
name: string,
vertexClientFactory: (
Expand All @@ -594,19 +673,42 @@ export function defineGeminiModel(
const model: ModelReference<z.ZodTypeAny> = SUPPORTED_GEMINI_MODELS[name];
if (!model) throw new Error(`Unsupported model: ${name}`);

return defineGeminiModel(
ai,
modelName,
name,
model?.info,
vertexClientFactory,
options
);
}

/**
* Define a Vertex AI Gemini model.
*/
export function defineGeminiModel(
ai: Genkit,
modelName: string,
version: string,
modelInfo: ModelInfo | undefined,
vertexClientFactory: (
request: GenerateRequest<typeof GeminiConfigSchema>
) => VertexAI,
options: PluginOptions
): ModelAction {
const middlewares: ModelMiddleware[] = [];
if (SUPPORTED_V1_MODELS[name]) {
if (SUPPORTED_V1_MODELS[version]) {
middlewares.push(simulateSystemPrompt());
}
if (model?.info?.supports?.media) {
if (modelInfo?.supports?.media) {
// the gemini api doesn't support downloading media from http(s)
middlewares.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 20 }));
}

return ai.defineModel(
{
name: modelName,
...model.info,
...modelInfo,
configSchema: GeminiConfigSchema,
use: middlewares,
},
Expand All @@ -619,7 +721,7 @@ export function defineGeminiModel(

// Handle system instructions separately
let systemInstruction: Content | undefined = undefined;
if (SUPPORTED_V15_MODELS[name]) {
if (!SUPPORTED_V1_MODELS[version]) {
const systemMessage = messages.find((m) => m.role === 'system');
if (systemMessage) {
messages.splice(messages.indexOf(systemMessage), 1);
Expand Down Expand Up @@ -659,7 +761,7 @@ export function defineGeminiModel(
toolConfig,
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
.map((message) => toGeminiMessage(message, modelInfo)),
generationConfig: {
candidateCount: request.candidates || undefined,
temperature: request.config?.temperature,
Expand All @@ -673,9 +775,7 @@ export function defineGeminiModel(
};

// Handle cache
const modelVersion = (request.config?.version ||
model.version ||
name) as string;
const modelVersion = (request.config?.version || version) as string;
const cacheConfigDetails = extractCacheConfig(request);

const apiClient = new ApiClient(
Expand Down Expand Up @@ -727,7 +827,7 @@ export function defineGeminiModel(
});
}

const msg = toGeminiMessage(messages[messages.length - 1], model);
const msg = toGeminiMessage(messages[messages.length - 1], modelInfo);

if (cache) {
genModel = vertex.preview.getGenerativeModelFromCachedContent(
Expand Down
30 changes: 29 additions & 1 deletion js/plugins/vertexai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ import {
} from './embedder.js';
import {
SUPPORTED_GEMINI_MODELS,
defineGeminiKnownModel,
defineGeminiModel,
gemini,
gemini10Pro,
gemini15Flash,
gemini15Pro,
Expand All @@ -51,6 +53,7 @@ import {
} from './imagen.js';
export { type PluginOptions } from './common/types.js';
export {
gemini,
gemini10Pro,
gemini15Flash,
gemini15Pro,
Expand Down Expand Up @@ -78,8 +81,33 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin {
imagenModel(ai, name, authClient, { projectId, location })
);
Object.keys(SUPPORTED_GEMINI_MODELS).map((name) =>
defineGeminiModel(ai, name, vertexClientFactory, { projectId, location })
defineGeminiKnownModel(ai, name, vertexClientFactory, {
projectId,
location,
})
);
if (options?.models) {
for (const modelOrRef of options?.models) {
const modelName =
typeof modelOrRef === 'string'
? modelOrRef
: // strip out the `vertexai/` prefix
modelOrRef.name.split('/')[1];
const modelRef =
typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef;
defineGeminiModel(
ai,
modelRef.name,
modelName,
modelRef.info,
vertexClientFactory,
{
projectId,
location,
}
);
}
}

Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) =>
defineVertexAIEmbedder(ai, name, authClient, { projectId, location })
Expand Down
2 changes: 2 additions & 0 deletions js/plugins/vertexai/src/modelgarden/mistral.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ function toMistralRole(role: Role): MistralRole {
return 'tool';
case 'system':
return 'system';
default:
throw new Error(`Unknwon role ${role}`);
}
}
function toMistralToolRequest(toolRequest: Record<string, any>): FunctionCall {
Expand Down
Loading
Loading