Skip to content

Commit

Permalink
partners[patch]: Add type field to all tool call integrations (#6067)
Browse files Browse the repository at this point in the history
* partnets[patch]: Add type field to all tool call integrations

* chore: lint files
  • Loading branch information
bracesproul authored Jul 15, 2024
1 parent 7e47915 commit 21066cc
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 14 deletions.
3 changes: 3 additions & 0 deletions libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ export function convertConverseMessageToLangChainMessage(
id: c.toolUse.toolUseId,
name: c.toolUse.name,
args: c.toolUse.input,
type: "tool_call",
});
} else if ("text" in c && typeof c.text === "string") {
content.push({ type: "text", text: c.text });
Expand Down Expand Up @@ -391,6 +392,7 @@ export function handleConverseStreamContentBlockDelta(
{
args: contentBlockDelta.delta.toolUse.input,
index,
type: "tool_call_chunk",
},
],
}),
Expand Down Expand Up @@ -419,6 +421,7 @@ export function handleConverseStreamContentBlockStart(
name: contentBlockStart.start.toolUse.name,
id: contentBlockStart.start.toolUse.toolUseId,
index,
type: "tool_call_chunk",
},
],
}),
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain-cohere/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ export class ChatCohere<
name: toolCall.function.name,
args: toolCall.function.arguments,
id: toolCall.id,
type: "tool_call",
}));
}

Expand Down Expand Up @@ -775,6 +776,7 @@ export class ChatCohere<
args: toolCall.function.arguments,
id: toolCall.id,
index: toolCall.index,
type: "tool_call_chunk",
}));
}

Expand Down
10 changes: 5 additions & 5 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import {
} from "@langchain/core/outputs";
import { StructuredToolInterface } from "@langchain/core/tools";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { ToolCall } from "@langchain/core/messages/tool";
import { ToolCall, ToolCallChunk } from "@langchain/core/messages/tool";
import { zodToJsonSchema } from "zod-to-json-schema";

import type { SerializedFields } from "../../load/map_keys.js";
Expand Down Expand Up @@ -562,14 +562,14 @@ export class BedrockChat
options
);
const result = generations[0].message as AIMessage;
const toolCallChunks = result.tool_calls?.map(
(toolCall: ToolCall, index: number) => ({
const toolCallChunks: ToolCallChunk[] | undefined =
result.tool_calls?.map((toolCall: ToolCall, index: number) => ({
name: toolCall.name,
args: JSON.stringify(toolCall.args),
id: toolCall.id,
index,
})
);
type: "tool_call_chunk",
}));
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: result.content,
Expand Down
7 changes: 6 additions & 1 deletion libs/langchain-community/src/utils/bedrock/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ export function extractToolCalls(content: Record<string, any>[]) {
const toolCalls: ToolCall[] = [];
for (const block of content) {
if (block.type === "tool_use") {
toolCalls.push({ name: block.name, args: block.input, id: block.id });
toolCalls.push({
name: block.name,
args: block.input,
id: block.id,
type: "tool_call",
});
}
}
return toolCalls;
Expand Down
6 changes: 5 additions & 1 deletion libs/langchain-google-common/src/utils/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
ChatResult,
Generation,
} from "@langchain/core/outputs";
import { ToolCallChunk } from "@langchain/core/messages/tool";
import type {
GoogleLLMResponse,
GoogleAIModelParams,
Expand Down Expand Up @@ -597,13 +598,14 @@ export function responseToChatGenerations(
if (ret.every((item) => typeof item.message.content === "string")) {
const combinedContent = ret.map((item) => item.message.content).join("");
const combinedText = ret.map((item) => item.text).join("");
const toolCallChunks = ret[
const toolCallChunks: ToolCallChunk[] | undefined = ret[
ret.length - 1
]?.message.additional_kwargs?.tool_calls?.map((toolCall, i) => ({
name: toolCall.function.name,
args: toolCall.function.arguments,
id: toolCall.id,
index: i,
type: "tool_call_chunk",
}));
let usageMetadata: UsageMetadata | undefined;
if ("usageMetadata" in response.data) {
Expand Down Expand Up @@ -653,6 +655,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
name: tool.function.name,
args: JSON.parse(tool.function.arguments),
id: tool.id,
type: "tool_call",
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
Expand All @@ -661,6 +664,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
args: JSON.parse(tool.function.arguments),
id: tool.id,
error: e.message,
type: "invalid_tool_call",
});
}
}
Expand Down
6 changes: 5 additions & 1 deletion libs/langchain-google-genai/src/utils/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ export function mapGenerateContentResultToChatResult(
text,
message: new AIMessage({
content: text,
tool_calls: functionCalls,
tool_calls: functionCalls?.map((fc) => ({
...fc,
type: "tool_call",
})),
additional_kwargs: {
...generationInfo,
},
Expand Down Expand Up @@ -300,6 +303,7 @@ export function convertResponseContentToChatGenerationChunk(
...fc,
args: JSON.stringify(fc.args),
index: extra.index,
type: "tool_call_chunk" as const,
}))
);
}
Expand Down
9 changes: 5 additions & 4 deletions libs/langchain-groq/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import {
} from "@langchain/core/output_parsers/openai_tools";
import { StructuredToolInterface } from "@langchain/core/tools";
import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { ToolCallChunk } from "@langchain/core/messages/tool";

export interface ChatGroqCallOptions extends BaseChatModelCallOptions {
headers?: Record<string, string>;
Expand Down Expand Up @@ -403,14 +404,14 @@ export class ChatGroq extends BaseChatModel<
) {
throw new Error("Could not parse Groq output.");
}
const toolCallChunks = generationMessage.tool_calls?.map(
(toolCall, i) => ({
const toolCallChunks: ToolCallChunk[] | undefined =
generationMessage.tool_calls?.map((toolCall, i) => ({
name: toolCall.name,
args: JSON.stringify(toolCall.args),
id: toolCall.id,
index: i,
})
);
type: "tool_call_chunk",
}));
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: generationMessage.content,
Expand Down
4 changes: 3 additions & 1 deletion libs/langchain-mistralai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import {
RunnableToolLike,
} from "@langchain/core/runnables";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ToolCallChunk } from "@langchain/core/messages/tool";

interface TokenUsage {
completionTokens?: number;
Expand Down Expand Up @@ -321,7 +322,7 @@ function _convertDeltaToMessageChunk(
}
const content = delta.content ?? "";
let additional_kwargs;
const toolCallChunks = [];
const toolCallChunks: ToolCallChunk[] = [];
if (rawToolCallChunksWithIndex !== undefined) {
additional_kwargs = {
tool_calls: rawToolCallChunksWithIndex,
Expand All @@ -332,6 +333,7 @@ function _convertDeltaToMessageChunk(
args: rawToolCallChunk.function?.arguments,
id: rawToolCallChunk.id,
index: rawToolCallChunk.index,
type: "tool_call_chunk",
});
}
} else {
Expand Down
4 changes: 3 additions & 1 deletion libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import {
parseToolCall,
} from "@langchain/core/output_parsers/openai_tools";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ToolCallChunk } from "@langchain/core/messages/tool";
import type {
AzureOpenAIInput,
OpenAICallOptions,
Expand Down Expand Up @@ -184,14 +185,15 @@ function _convertDeltaToMessageChunk(
if (role === "user") {
return new HumanMessageChunk({ content });
} else if (role === "assistant") {
const toolCallChunks = [];
const toolCallChunks: ToolCallChunk[] = [];
if (Array.isArray(delta.tool_calls)) {
for (const rawToolCall of delta.tool_calls) {
toolCallChunks.push({
name: rawToolCall.function?.name,
args: rawToolCall.function?.arguments,
id: rawToolCall.id,
index: rawToolCall.index,
type: "tool_call_chunk",
});
}
}
Expand Down

0 comments on commit 21066cc

Please sign in to comment.