diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 432620dcc52b..d230203a85de 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -2,7 +2,7 @@ import { v4 as uuidv4 } from "uuid"; import { AIMessage, AIMessageChunk, - AIMessageFields, + AIMessageChunkFields, BaseMessage, BaseMessageChunk, BaseMessageFields, @@ -566,7 +566,7 @@ export function chunkToString(chunk: BaseMessageChunk): string { } export function partToMessageChunk(part: GeminiPart): BaseMessageChunk { - const fields = partsToBaseMessageFields([part]); + const fields = partsToBaseMessageChunkFields([part]); if (typeof fields.content === "string") { return new AIMessageChunk(fields); } else if (fields.content.every((item) => item.type === "text")) { @@ -636,12 +636,15 @@ export function responseToBaseMessageFields( response: GoogleLLMResponse ): BaseMessageFields { const parts = responseToParts(response); - return partsToBaseMessageFields(parts); + return partsToBaseMessageChunkFields(parts); } -export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { - const fields: AIMessageFields = { +export function partsToBaseMessageChunkFields( + parts: GeminiPart[] +): AIMessageChunkFields { + const fields: AIMessageChunkFields = { content: partsToMessageContent(parts), + tool_call_chunks: [], tool_calls: [], invalid_tool_calls: [], }; @@ -650,6 +653,13 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { if (rawTools.length > 0) { const tools = toolsRawToTools(rawTools); for (const tool of tools) { + fields.tool_call_chunks?.push({ + name: tool.function.name, + args: tool.function.arguments, + id: tool.id, + type: "tool_call_chunk", + }); + try { fields.tool_calls?.push({ name: tool.function.name, @@ -661,7 +671,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { } catch (e: any) { 
fields.invalid_tool_calls?.push({ name: tool.function.name, - args: JSON.parse(tool.function.arguments), + args: tool.function.arguments, id: tool.id, error: e.message, type: "invalid_tool_call", diff --git a/libs/langchain-google-vertexai/package.json b/libs/langchain-google-vertexai/package.json index 34a52b83c702..867234a91548 100644 --- a/libs/langchain-google-vertexai/package.json +++ b/libs/langchain-google-vertexai/package.json @@ -70,7 +70,8 @@ "release-it": "^15.10.1", "rollup": "^4.5.2", "ts-jest": "^29.1.0", - "typescript": "<5.2.0" + "typescript": "<5.2.0", + "zod": "^3.22.4" }, "publishConfig": { "access": "public" diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index 054462d7d1c0..2fa428a20924 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -11,111 +11,70 @@ import { SystemMessage, ToolMessage, } from "@langchain/core/messages"; -import { ChatVertexAI } from "../chat_models.js"; +import { tool } from "@langchain/core/tools"; +import { concat } from "@langchain/core/utils/stream"; +import { z } from "zod"; import { GeminiTool } from "../types.js"; +import { ChatVertexAI } from "../chat_models.js"; describe("GAuth Chat", () => { test("invoke", async () => { const model = new ChatVertexAI(); - try { - const res = await model.invoke("What is 1 + 1?"); - expect(res).toBeDefined(); - expect(res._getType()).toEqual("ai"); - - const aiMessage = res as AIMessageChunk; - expect(aiMessage.content).toBeDefined(); - - expect(typeof aiMessage.content).toBe("string"); - const text = aiMessage.content as string; - expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? 
?/); + const res = await model.invoke("What is 1 + 1?"); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); - /* - expect(aiMessage.content.length).toBeGreaterThan(0); - expect(aiMessage.content[0]).toBeDefined(); - const content = aiMessage.content[0] as MessageContentComplex; - expect(content).toHaveProperty("type"); - expect(content.type).toEqual("text"); + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); - const textContent = content as MessageContentText; - expect(textContent.text).toBeDefined(); - expect(textContent.text).toEqual("2"); - */ - } catch (e) { - console.error(e); - throw e; - } + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); }); test("generate", async () => { const model = new ChatVertexAI(); - try { - const messages: BaseMessage[] = [ - new SystemMessage( - "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." - ), - new HumanMessage("Flip it"), - new AIMessage("T"), - new HumanMessage("Flip the coin again"), - ]; - const res = await model.predictMessages(messages); - expect(res).toBeDefined(); - expect(res._getType()).toEqual("ai"); - - const aiMessage = res as AIMessageChunk; - expect(aiMessage.content).toBeDefined(); - - expect(typeof aiMessage.content).toBe("string"); - const text = aiMessage.content as string; - expect(["H", "T"]).toContainEqual(text); - - /* - expect(aiMessage.content.length).toBeGreaterThan(0); - expect(aiMessage.content[0]).toBeDefined(); + const messages: BaseMessage[] = [ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." 
+ ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]; + const res = await model.predictMessages(messages); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); - const content = aiMessage.content[0] as MessageContentComplex; - expect(content).toHaveProperty("type"); - expect(content.type).toEqual("text"); + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); - const textContent = content as MessageContentText; - expect(textContent.text).toBeDefined(); - expect(["H", "T"]).toContainEqual(textContent.text); - */ - } catch (e) { - console.error(e); - throw e; - } + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(["H", "T"]).toContainEqual(text); }); test("stream", async () => { const model = new ChatVertexAI(); - try { - const input: BaseLanguageModelInput = new ChatPromptValue([ - new SystemMessage( - "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." - ), - new HumanMessage("Flip it"), - new AIMessage("T"), - new HumanMessage("Flip the coin again"), - ]); - const res = await model.stream(input); - const resArray: BaseMessageChunk[] = []; - for await (const chunk of res) { - resArray.push(chunk); - } - expect(resArray).toBeDefined(); - expect(resArray.length).toBeGreaterThanOrEqual(1); - - const lastChunk = resArray[resArray.length - 1]; - expect(lastChunk).toBeDefined(); - expect(lastChunk._getType()).toEqual("ai"); - const aiChunk = lastChunk as AIMessageChunk; - console.log(aiChunk); - - console.log(JSON.stringify(resArray, null, 2)); - } catch (e) { - console.error(e); - throw e; + const input: BaseLanguageModelInput = new ChatPromptValue([ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." 
+ ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]); + const res = await model.stream(input); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); } + expect(resArray).toBeDefined(); + expect(resArray.length).toBeGreaterThanOrEqual(1); + + const lastChunk = resArray[resArray.length - 1]; + expect(lastChunk).toBeDefined(); + expect(lastChunk._getType()).toEqual("ai"); }); test("function", async () => { @@ -209,7 +168,7 @@ describe("GAuth Chat", () => { for await (const chunk of res) { resArray.push(chunk); } - console.log(JSON.stringify(resArray, null, 2)); + // console.log(JSON.stringify(resArray, null, 2)); }); test("withStructuredOutput", async () => { @@ -249,7 +208,7 @@ test("Stream token count usage_metadata", async () => { res = res.concat(chunk); } } - console.log(res); + // console.log(res); expect(res?.usage_metadata).toBeDefined(); if (!res?.usage_metadata) { return; @@ -276,7 +235,7 @@ test("streamUsage excludes token usage", async () => { res = res.concat(chunk); } } - console.log(res); + // console.log(res); expect(res?.usage_metadata).not.toBeDefined(); }); @@ -286,7 +245,7 @@ test("Invoke token count usage_metadata", async () => { maxOutputTokens: 10, }); const res = await model.invoke("Why is the sky blue? 
Be concise."); - console.log(res); + // console.log(res); expect(res?.usage_metadata).toBeDefined(); if (!res?.usage_metadata) { return; @@ -322,3 +281,39 @@ test("Streaming true constructor param will stream", async () => { expect(totalTokenCount).toBeGreaterThan(1); }); + +test("ChatVertexAI can stream tools", async () => { + const model = new ChatVertexAI({}); + + const weatherTool = tool( + (_) => "The weather in San Francisco today is 18 degrees and sunny.", + { + name: "current_weather_tool", + description: "Get the current weather for a given location.", + schema: z.object({ + location: z.string().describe("The location to get the weather for."), + }), + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream( + "Whats the weather like today in San Francisco?" + ); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } + + expect(finalChunk).toBeDefined(); + if (!finalChunk) return; + + const toolCalls = finalChunk.tool_calls; + expect(toolCalls).toBeDefined(); + if (!toolCalls) { + throw new Error("tool_calls not in response"); + } + expect(toolCalls.length).toBe(1); + expect(toolCalls[0].name).toBe("current_weather_tool"); + expect(toolCalls[0].args).toHaveProperty("location"); +}); diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts index c44f36916ddc..60c5b6c421b0 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts @@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests< Cls: ChatVertexAI, chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, + invokeResponseType: AIMessageChunk, constructorArgs: { model: "gemini-1.5-pro",
}, @@ -32,6 +33,14 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests< "Not implemented." ); } + + async testInvokeMoreComplexTools() { + this.skipTestMessage( + "testInvokeMoreComplexTools", + "ChatVertexAI", + "Google VertexAI does not support tool schemas where the object properties are not defined." + ); + } } const testClass = new ChatVertexAIStandardIntegrationTests(); diff --git a/yarn.lock b/yarn.lock index c5dcede32a3e..94223547d02b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11695,6 +11695,7 @@ __metadata: rollup: ^4.5.2 ts-jest: ^29.1.0 typescript: <5.2.0 + zod: ^3.22.4 languageName: unknown linkType: soft