From 4e5ab147170b6a23d107e90af29f486c118b0c34 Mon Sep 17 00:00:00 2001 From: Joe McIlvain Date: Wed, 27 Nov 2024 08:59:01 -0800 Subject: [PATCH] fix: system prompt formatting for Vertex AI Prior to this commit, system prompt messages were included by `KurtVertexAI` as part of the `contents` array of messages. However, Vertex AI requires system messages to be specified separately, in the request's `systemInstruction` field. This commit filters system messages out of `contents` and merges their parts into a single `systemInstruction` message. --- .../spec/generateNaturalLanguage.spec.ts | 10 +++++ ...formats_a_system_prompt_for_Vertex_AI.yaml | 43 +++++++++++++++++++ packages/kurt-vertex-ai/src/KurtVertexAI.ts | 15 ++++++- 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateNaturalLanguage_properly_formats_a_system_prompt_for_Vertex_AI.yaml diff --git a/packages/kurt-vertex-ai/spec/generateNaturalLanguage.spec.ts b/packages/kurt-vertex-ai/spec/generateNaturalLanguage.spec.ts index eaa6fd8..1864734 100644 --- a/packages/kurt-vertex-ai/spec/generateNaturalLanguage.spec.ts +++ b/packages/kurt-vertex-ai/spec/generateNaturalLanguage.spec.ts @@ -12,6 +12,16 @@ describe("KurtVertexAI generateNaturalLanguage", () => { expect(result.text).toEqual("Hello! 👋 😊\n") }) + test("properly formats a system prompt for Vertex AI", async () => { + const result = await snapshotAndMock((kurt) => + kurt.generateNaturalLanguage({ + systemPrompt: "Don't be evil.", // sometimes Google needs to remind themselves + prompt: "Say hello!", + }) + ) + expect(result.text).toEqual("Hello! 
👋 😊\n") + }) + test("writes a haiku with high temperature", async () => { const result = await snapshotAndMock((kurt) => kurt.generateNaturalLanguage({ diff --git a/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateNaturalLanguage_properly_formats_a_system_prompt_for_Vertex_AI.yaml b/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateNaturalLanguage_properly_formats_a_system_prompt_for_Vertex_AI.yaml new file mode 100644 index 0000000..4d0a92e --- /dev/null +++ b/packages/kurt-vertex-ai/spec/snapshots/KurtVertexAI_generateNaturalLanguage_properly_formats_a_system_prompt_for_Vertex_AI.yaml @@ -0,0 +1,43 @@ +step1Request: + generationConfig: + maxOutputTokens: 4096 + temperature: 0.5 + topP: 0.95 + contents: + - role: user + parts: + - text: Say hello! + systemInstruction: + role: system + parts: + - text: Don't be evil. +step2RawChunks: + - candidates: + - content: + role: model + parts: + - text: Hello + index: 0 + usageMetadata: {} + - candidates: + - content: + role: model + parts: + - text: | + ! 👋 😊 + finishReason: STOP + index: 0 + usageMetadata: + promptTokenCount: 9 + candidatesTokenCount: 6 + totalTokenCount: 15 +step3KurtEvents: + - chunk: Hello + - chunk: | + ! 👋 😊 + - finished: true + text: | + Hello! 👋 😊 + metadata: + totalInputTokens: 9 + totalOutputTokens: 6 diff --git a/packages/kurt-vertex-ai/src/KurtVertexAI.ts b/packages/kurt-vertex-ai/src/KurtVertexAI.ts index 54ec727..523fce2 100644 --- a/packages/kurt-vertex-ai/src/KurtVertexAI.ts +++ b/packages/kurt-vertex-ai/src/KurtVertexAI.ts @@ -91,13 +91,26 @@ export class KurtVertexAI model: this.options.model, }) as VertexAIGenerativeModel + // VertexAI requires that system messages be sent as a single message, + // so we filter them out from the main messages array to send separately. 
+ const normalMessages = options.messages.filter((m) => m.role !== "system") + const systemMessages = options.messages.filter((m) => m.role === "system") + const singleSystemMessage: VertexAIMessage | undefined = + systemMessages.length === 0 + ? undefined + : { + role: "system", + parts: systemMessages.flatMap((m) => m.parts), + } + const req: VertexAIRequest = { generationConfig: { maxOutputTokens: options.sampling.maxOutputTokens, temperature: options.sampling.temperature, topP: options.sampling.topP, }, - contents: options.messages, + contents: normalMessages, + systemInstruction: singleSystemMessage, } const tools = Object.values(options.tools)