Skip to content

Commit

Permalink
feat(google-genai): Support Gemini system instructions (langchain-ai#…
Browse files Browse the repository at this point in the history
…7235)

Co-authored-by: Gary Chen <thegary.chen@mail.utoronto.ca>
Co-authored-by: martinl498 <martinloo498@gmail.com>
Co-authored-by: Shannon Budiman <shannon.budiman@mail.utoronto.ca>
Co-authored-by: Jacob Lee <jacoblee93@gmail.com>
  • Loading branch information
5 people authored and syntaxsec committed Dec 13, 2024
1 parent 929f0ee commit 7bbfa1c
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 7 deletions.
56 changes: 52 additions & 4 deletions libs/langchain-google-genai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ export interface GoogleGenerativeAIChatInput
* @default false
*/
json?: boolean;

/**
* Whether or not model supports system instructions.
* The following models support system instructions:
* - All Gemini 1.5 Pro model versions
* - All Gemini 1.5 Flash model versions
* - Gemini 1.0 Pro version gemini-1.0-pro-002
*/
convertSystemMessageToHumanContent?: boolean | undefined;
}

/**
Expand Down Expand Up @@ -563,6 +572,8 @@ export class ChatGoogleGenerativeAI

streamUsage = true;

convertSystemMessageToHumanContent: boolean | undefined;

private client: GenerativeModel;

get _isMultimodalModel() {
Expand Down Expand Up @@ -651,6 +662,29 @@ export class ChatGoogleGenerativeAI
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
}

get useSystemInstruction(): boolean {
return typeof this.convertSystemMessageToHumanContent === "boolean"
? !this.convertSystemMessageToHumanContent
: this.computeUseSystemInstruction;
}

get computeUseSystemInstruction(): boolean {
// This works on models from April 2024 and later
// Vertex AI: gemini-1.5-pro and gemini-1.0-002 and later
// AI Studio: gemini-1.5-pro-latest
if (this.modelName === "gemini-1.0-pro-001") {
return false;
} else if (this.modelName.startsWith("gemini-pro-vision")) {
return false;
} else if (this.modelName.startsWith("gemini-1.0-pro-vision")) {
return false;
} else if (this.modelName === "gemini-pro") {
// on AI Studio gemini-pro is still pointing at gemini-1.0-pro-001
return false;
}
return true;
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
return {
ls_provider: "google_genai",
Expand Down Expand Up @@ -706,8 +740,15 @@ export class ChatGoogleGenerativeAI
): Promise<ChatResult> {
const prompt = convertBaseMessagesToContent(
messages,
this._isMultimodalModel
this._isMultimodalModel,
this.useSystemInstruction
);
let actualPrompt = prompt;
if (prompt[0].role === "system") {
const [systemInstruction] = prompt;
this.client.systemInstruction = systemInstruction;
actualPrompt = prompt.slice(1);
}
const parameters = this.invocationParams(options);

// Handle streaming
Expand All @@ -734,7 +775,7 @@ export class ChatGoogleGenerativeAI

const res = await this.completionWithRetry({
...parameters,
contents: prompt,
contents: actualPrompt,
});

let usageMetadata: UsageMetadata | undefined;
Expand Down Expand Up @@ -770,12 +811,19 @@ export class ChatGoogleGenerativeAI
): AsyncGenerator<ChatGenerationChunk> {
const prompt = convertBaseMessagesToContent(
messages,
this._isMultimodalModel
this._isMultimodalModel,
this.useSystemInstruction
);
let actualPrompt = prompt;
if (prompt[0].role === "system") {
const [systemInstruction] = prompt;
this.client.systemInstruction = systemInstruction;
actualPrompt = prompt.slice(1);
}
const parameters = this.invocationParams(options);
const request = {
...parameters,
contents: prompt,
contents: actualPrompt,
};
const stream = await this.caller.callWithOptions(
{ signal: options?.signal },
Expand Down
178 changes: 178 additions & 0 deletions libs/langchain-google-genai/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,181 @@ test("convertBaseMessagesToContent correctly creates properly formatted content"
},
]);
});

test("Input has single system message followed by one user message, convert system message is false", async () => {
const messages = [
new SystemMessage("You are a helpful assistant"),
new HumanMessage("What's the weather like in new york?"),
];
const messagesAsGoogleContent = convertBaseMessagesToContent(
messages,
false,
false
);

expect(messagesAsGoogleContent).toEqual([
{
role: "user",
parts: [
{ text: "You are a helpful assistant" },
{ text: "What's the weather like in new york?" },
],
},
]);
});

test("Input has a system message that is not the first message, convert system message is false", async () => {
const messages = [
new HumanMessage("What's the weather like in new york?"),
new SystemMessage("You are a helpful assistant"),
];
expect(() => {
convertBaseMessagesToContent(messages, false, false);
}).toThrow("System message should be the first one");
});

test("Input has multiple system messages, convert system message is false", async () => {
const messages = [
new SystemMessage("You are a helpful assistant"),
new SystemMessage("You are not a helpful assistant"),
];
expect(() => {
convertBaseMessagesToContent(messages, false, false);
}).toThrow("System message should be the first one");
});

test("Input has no system message and one user message, convert system message is false", async () => {
const messages = [new HumanMessage("What's the weather like in new york?")];
const messagesAsGoogleContent = convertBaseMessagesToContent(
messages,
false,
false
);

expect(messagesAsGoogleContent).toEqual([
{
role: "user",
parts: [{ text: "What's the weather like in new york?" }],
},
]);
});

test("Input has no system message and multiple user message, convert system message is false", async () => {
const messages = [
new HumanMessage("What's the weather like in new york?"),
new HumanMessage("What's the weather like in toronto?"),
new HumanMessage("What's the weather like in los angeles?"),
];
const messagesAsGoogleContent = convertBaseMessagesToContent(
messages,
false,
false
);

expect(messagesAsGoogleContent).toEqual([
{
role: "user",
parts: [{ text: "What's the weather like in new york?" }],
},
{
role: "user",
parts: [{ text: "What's the weather like in toronto?" }],
},
{
role: "user",
parts: [{ text: "What's the weather like in los angeles?" }],
},
]);
});

test("Input has single system message followed by one user message, convert system message is true", async () => {
const messages = [
new SystemMessage("You are a helpful assistant"),
new HumanMessage("What's the weather like in new york?"),
];

const messagesAsGoogleContent = convertBaseMessagesToContent(
messages,
false,
true
);

expect(messagesAsGoogleContent).toEqual([
{
role: "system",
parts: [{ text: "You are a helpful assistant" }],
},
{
role: "user",
parts: [{ text: "What's the weather like in new york?" }],
},
]);
});

test("Input has single system message that is not the first message, convert system message is true", async () => {
const messages = [
new HumanMessage("What's the weather like in new york?"),
new SystemMessage("You are a helpful assistant"),
];

expect(() => convertBaseMessagesToContent(messages, false, true)).toThrow(
"System message should be the first one"
);
});

test("Input has multiple system message, convert system message is true", async () => {
const messages = [
new SystemMessage("What's the weather like in new york?"),
new SystemMessage("You are a helpful assistant"),
];

expect(() => convertBaseMessagesToContent(messages, false, true)).toThrow(
"System message should be the first one"
);
});

test("Input has no system message and one user message, convert system message is true", async () => {
const messages = [new HumanMessage("What's the weather like in new york?")];

const messagesAsGoogleContent = convertBaseMessagesToContent(
messages,
false,
true
);

expect(messagesAsGoogleContent).toEqual([
{
role: "user",
parts: [{ text: "What's the weather like in new york?" }],
},
]);
});

test("Input has no system message and multiple user messages, convert system message is true", async () => {
const messages = [
new HumanMessage("What's the weather like in new york?"),
new HumanMessage("Will it rain today?"),
new HumanMessage("How about next week?"),
];

const messagesAsGoogleContent = convertBaseMessagesToContent(
messages,
false,
true
);

expect(messagesAsGoogleContent).toEqual([
{
role: "user",
parts: [{ text: "What's the weather like in new york?" }],
},
{
role: "user",
parts: [{ text: "Will it rain today?" }],
},
{
role: "user",
parts: [{ text: "How about next week?" }],
},
]);
});
12 changes: 9 additions & 3 deletions libs/langchain-google-genai/src/utils/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ export function convertAuthorToRole(
case "model": // getMessageAuthor returns message.name. code ex.: return message.name ?? type;
return "model";
case "system":
return "system";
case "human":
return "user";
case "tool":
Expand Down Expand Up @@ -179,7 +180,8 @@ export function convertMessageContentToParts(

export function convertBaseMessagesToContent(
messages: BaseMessage[],
isMultimodalModel: boolean
isMultimodalModel: boolean,
convertSystemMessageToHumanContent: boolean = false
) {
return messages.reduce<{
content: Content[];
Expand Down Expand Up @@ -223,7 +225,10 @@ export function convertBaseMessagesToContent(
};
}
let actualRole = role;
if (actualRole === "function") {
if (
actualRole === "function" ||
(actualRole === "system" && !convertSystemMessageToHumanContent)
) {
// GenerativeAI API will throw an error if the role is not "user" or "model."
actualRole = "user";
}
Expand All @@ -232,7 +237,8 @@ export function convertBaseMessagesToContent(
parts,
};
return {
mergeWithPreviousContent: author === "system",
mergeWithPreviousContent:
author === "system" && !convertSystemMessageToHumanContent,
content: [...acc.content, content],
};
},
Expand Down

0 comments on commit 7bbfa1c

Please sign in to comment.