From 21066ccf2ed1032b04c3356e23b46adba04c4bf7 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 15 Jul 2024 16:52:24 -0700 Subject: [PATCH] partners[patch]: Add type field to all tool call integrations (#6067) * partnets[patch]: Add type field to all tool call integrations * chore: lint files --- libs/langchain-aws/src/common.ts | 3 +++ libs/langchain-cohere/src/chat_models.ts | 2 ++ .../langchain-community/src/chat_models/bedrock/web.ts | 10 +++++----- .../langchain-community/src/utils/bedrock/anthropic.ts | 7 ++++++- libs/langchain-google-common/src/utils/gemini.ts | 6 +++++- libs/langchain-google-genai/src/utils/common.ts | 6 +++++- libs/langchain-groq/src/chat_models.ts | 9 +++++---- libs/langchain-mistralai/src/chat_models.ts | 4 +++- libs/langchain-openai/src/chat_models.ts | 4 +++- 9 files changed, 37 insertions(+), 14 deletions(-) diff --git a/libs/langchain-aws/src/common.ts b/libs/langchain-aws/src/common.ts index 13d5223241c0..14afc63cad6e 100644 --- a/libs/langchain-aws/src/common.ts +++ b/libs/langchain-aws/src/common.ts @@ -351,6 +351,7 @@ export function convertConverseMessageToLangChainMessage( id: c.toolUse.toolUseId, name: c.toolUse.name, args: c.toolUse.input, + type: "tool_call", }); } else if ("text" in c && typeof c.text === "string") { content.push({ type: "text", text: c.text }); @@ -391,6 +392,7 @@ export function handleConverseStreamContentBlockDelta( { args: contentBlockDelta.delta.toolUse.input, index, + type: "tool_call_chunk", }, ], }), @@ -419,6 +421,7 @@ export function handleConverseStreamContentBlockStart( name: contentBlockStart.start.toolUse.name, id: contentBlockStart.start.toolUse.toolUseId, index, + type: "tool_call_chunk", }, ], }), diff --git a/libs/langchain-cohere/src/chat_models.ts b/libs/langchain-cohere/src/chat_models.ts index ab49603f755a..eb6887e53dc6 100644 --- a/libs/langchain-cohere/src/chat_models.ts +++ b/libs/langchain-cohere/src/chat_models.ts @@ -599,6 +599,7 @@ export class ChatCohere< name: toolCall.function.name, args: toolCall.function.arguments, id: toolCall.id, + type: "tool_call", })); } @@ -775,6 +776,7 @@ export class ChatCohere< args: toolCall.function.arguments, id: toolCall.id, index: toolCall.index, + type: "tool_call_chunk", })); } diff --git a/libs/langchain-community/src/chat_models/bedrock/web.ts b/libs/langchain-community/src/chat_models/bedrock/web.ts index a15dc949412a..3a2a350f97c9 100644 --- a/libs/langchain-community/src/chat_models/bedrock/web.ts +++ b/libs/langchain-community/src/chat_models/bedrock/web.ts @@ -33,7 +33,7 @@ import { } from "@langchain/core/outputs"; import { StructuredToolInterface } from "@langchain/core/tools"; import { isStructuredTool } from "@langchain/core/utils/function_calling"; -import { ToolCall } from "@langchain/core/messages/tool"; +import { ToolCall, ToolCallChunk } from "@langchain/core/messages/tool"; import { zodToJsonSchema } from "zod-to-json-schema"; import type { SerializedFields } from "../../load/map_keys.js"; @@ -562,14 +562,14 @@ export class BedrockChat options ); const result = generations[0].message as AIMessage; - const toolCallChunks = result.tool_calls?.map( - (toolCall: ToolCall, index: number) => ({ + const toolCallChunks: ToolCallChunk[] | undefined = + result.tool_calls?.map((toolCall: ToolCall, index: number) => ({ name: toolCall.name, args: JSON.stringify(toolCall.args), id: toolCall.id, index, - }) - ); + type: "tool_call_chunk", + })); yield new ChatGenerationChunk({ message: new AIMessageChunk({ content: result.content, diff --git a/libs/langchain-community/src/utils/bedrock/anthropic.ts b/libs/langchain-community/src/utils/bedrock/anthropic.ts index 4bb888a9de77..6ed262795c08 100644 --- a/libs/langchain-community/src/utils/bedrock/anthropic.ts +++ b/libs/langchain-community/src/utils/bedrock/anthropic.ts @@ -14,7 +14,12 @@ export function extractToolCalls(content: Record[]) { const toolCalls: ToolCall[] = []; for (const block of content) { if (block.type === "tool_use") { - toolCalls.push({ name: block.name, args: block.input, id: block.id }); + toolCalls.push({ + name: block.name, + args: block.input, + id: block.id, + type: "tool_call", + }); } } return toolCalls; diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 749b78ad27d9..432620dcc52b 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -21,6 +21,7 @@ import { ChatResult, Generation, } from "@langchain/core/outputs"; +import { ToolCallChunk } from "@langchain/core/messages/tool"; import type { GoogleLLMResponse, GoogleAIModelParams, @@ -597,13 +598,14 @@ export function responseToChatGenerations( if (ret.every((item) => typeof item.message.content === "string")) { const combinedContent = ret.map((item) => item.message.content).join(""); const combinedText = ret.map((item) => item.text).join(""); - const toolCallChunks = ret[ + const toolCallChunks: ToolCallChunk[] | undefined = ret[ ret.length - 1 ]?.message.additional_kwargs?.tool_calls?.map((toolCall, i) => ({ name: toolCall.function.name, args: toolCall.function.arguments, id: toolCall.id, index: i, + type: "tool_call_chunk", })); let usageMetadata: UsageMetadata | undefined; if ("usageMetadata" in response.data) { @@ -653,6 +655,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { name: tool.function.name, args: JSON.parse(tool.function.arguments), id: tool.id, + type: "tool_call", }); // eslint-disable-next-line @typescript-eslint/no-explicit-any } catch (e: any) { @@ -661,6 +664,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields { args: JSON.parse(tool.function.arguments), id: tool.id, error: e.message, + type: "invalid_tool_call", }); } } diff --git a/libs/langchain-google-genai/src/utils/common.ts b/libs/langchain-google-genai/src/utils/common.ts index afb60eb69c8d..1d4bdddbaa66 100644 --- a/libs/langchain-google-genai/src/utils/common.ts +++ b/libs/langchain-google-genai/src/utils/common.ts @@ -264,7 +264,10 @@ export function mapGenerateContentResultToChatResult( text, message: new AIMessage({ content: text, - tool_calls: functionCalls, + tool_calls: functionCalls?.map((fc) => ({ + ...fc, + type: "tool_call", + })), additional_kwargs: { ...generationInfo, }, @@ -300,6 +303,7 @@ export function convertResponseContentToChatGenerationChunk( ...fc, args: JSON.stringify(fc.args), index: extra.index, + type: "tool_call_chunk" as const, })) ); } diff --git a/libs/langchain-groq/src/chat_models.ts b/libs/langchain-groq/src/chat_models.ts index 1abe7b3db765..eb8f2dffc76d 100644 --- a/libs/langchain-groq/src/chat_models.ts +++ b/libs/langchain-groq/src/chat_models.ts @@ -64,6 +64,7 @@ import { } from "@langchain/core/output_parsers/openai_tools"; import { StructuredToolInterface } from "@langchain/core/tools"; import { convertToOpenAITool } from "@langchain/core/utils/function_calling"; +import { ToolCallChunk } from "@langchain/core/messages/tool"; export interface ChatGroqCallOptions extends BaseChatModelCallOptions { headers?: Record; @@ -403,14 +404,14 @@ export class ChatGroq extends BaseChatModel< ) { throw new Error("Could not parse Groq output."); } - const toolCallChunks = generationMessage.tool_calls?.map( - (toolCall, i) => ({ + const toolCallChunks: ToolCallChunk[] | undefined = + generationMessage.tool_calls?.map((toolCall, i) => ({ name: toolCall.name, args: JSON.stringify(toolCall.args), id: toolCall.id, index: i, - }) - ); + type: "tool_call_chunk", + })); yield new ChatGenerationChunk({ message: new AIMessageChunk({ content: generationMessage.content, diff --git a/libs/langchain-mistralai/src/chat_models.ts b/libs/langchain-mistralai/src/chat_models.ts index ea0317f34d84..1a1fdd36e12b 100644 --- a/libs/langchain-mistralai/src/chat_models.ts +++ b/libs/langchain-mistralai/src/chat_models.ts @@ -65,6 +65,7 @@ import { RunnableToolLike, } from "@langchain/core/runnables"; import { zodToJsonSchema } from "zod-to-json-schema"; +import { ToolCallChunk } from "@langchain/core/messages/tool"; interface TokenUsage { completionTokens?: number; @@ -321,7 +322,7 @@ function _convertDeltaToMessageChunk( } const content = delta.content ?? ""; let additional_kwargs; - const toolCallChunks = []; + const toolCallChunks: ToolCallChunk[] = []; if (rawToolCallChunksWithIndex !== undefined) { additional_kwargs = { tool_calls: rawToolCallChunksWithIndex, @@ -332,6 +333,7 @@ function _convertDeltaToMessageChunk( args: rawToolCallChunk.function?.arguments, id: rawToolCallChunk.id, index: rawToolCallChunk.index, + type: "tool_call_chunk", }); } } else { diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index 601d70d33fff..c3f1f03ffc1b 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -55,6 +55,7 @@ import { parseToolCall, } from "@langchain/core/output_parsers/openai_tools"; import { zodToJsonSchema } from "zod-to-json-schema"; +import { ToolCallChunk } from "@langchain/core/messages/tool"; import type { AzureOpenAIInput, OpenAICallOptions, @@ -184,7 +185,7 @@ function _convertDeltaToMessageChunk( if (role === "user") { return new HumanMessageChunk({ content }); } else if (role === "assistant") { - const toolCallChunks = []; + const toolCallChunks: ToolCallChunk[] = []; if (Array.isArray(delta.tool_calls)) { for (const rawToolCall of delta.tool_calls) { toolCallChunks.push({ @@ -192,6 +193,7 @@ function _convertDeltaToMessageChunk( args: rawToolCall.function?.arguments, id: rawToolCall.id, index: rawToolCall.index, + type: "tool_call_chunk", }); } }