Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws[patch]: Fix fails when calling multiple tools simultaneously #6175

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,29 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
}
});

return { converseMessages, converseSystem };
// Combine consecutive user tool result messages into a single message
const combinedConverseMessages = converseMessages.reduce<BedrockMessage[]>(
(acc, curr) => {
const lastMessage = acc[acc.length - 1];

if (
lastMessage &&
lastMessage.role === "user" &&
lastMessage.content?.some((c) => "toolResult" in c) &&
curr.role === "user" &&
curr.content?.some((c) => "toolResult" in c)
) {
lastMessage.content = lastMessage.content.concat(curr.content);
} else {
acc.push(curr);
}

return acc;
},
[]
);

return { converseMessages: combinedConverseMessages, converseSystem };
}

export function isBedrockTool(tool: unknown): tool is BedrockTool {
Expand Down
346 changes: 287 additions & 59 deletions libs/langchain-aws/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,77 +4,305 @@ import {
AIMessage,
ToolMessage,
AIMessageChunk,
BaseMessage,
} from "@langchain/core/messages";
import { concat } from "@langchain/core/utils/stream";
import type {
Message as BedrockMessage,
SystemContentBlock as BedrockSystemContentBlock,
} from "@aws-sdk/client-bedrock-runtime";
import {
convertToConverseMessages,
handleConverseStreamContentBlockDelta,
} from "../common.js";

test("convertToConverseMessages works", () => {
const messages = [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
describe("convertToConverseMessages", () => {
const testCases: {
name: string;
input: BaseMessage[];
output: {
converseMessages: BedrockMessage[];
converseSystem: BedrockSystemContentBlock[];
};
}[] = [
{
name: "empty input",
input: [],
output: {
converseMessages: [],
converseSystem: [],
},
},
{
name: "simple messages",
input: [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
],
output: {
converseMessages: [
{
role: "user",
content: [
{
text: "What's the weather like today in Berkeley, CA? Use weather.com to check.",
},
],
},
{
role: "assistant",
content: [
{
toolUse: {
name: "retrieverTool",
toolUseId: "123_retriever_tool",
input: {
url: "https://weather.com",
},
},
},
],
},
{
role: "user",
content: [
{
toolResult: {
toolUseId: "123_retriever_tool",
content: [
{
text: "The weather in Berkeley, CA is 70 degrees and sunny.",
},
],
},
},
],
},
id: "123_retriever_tool",
},
],
converseSystem: [
{
text: "You're an advanced AI assistant.",
},
],
},
},
{
name: "consecutive user tool messages",
input: [
new SystemMessage("You're an advanced AI assistant."),
new HumanMessage(
"What's the weather like today in Berkeley, CA and in Paris, France? Use weather.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "123_retriever_tool",
},
{
name: "retrieverTool",
args: {
url: "https://weather.com",
},
id: "456_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
new ToolMessage({
tool_call_id: "456_retriever_tool",
content: "The weather in Paris, France is perfect.",
}),
new HumanMessage(
"What's the weather like today in Berkeley, CA and in Paris, France? Use meteofrance.com to check."
),
new AIMessage({
content: "",
tool_calls: [
{
name: "retrieverTool",
args: {
url: "https://meteofrance.com",
},
id: "321_retriever_tool",
},
{
name: "retrieverTool",
args: {
url: "https://meteofrance.com",
},
id: "654_retriever_tool",
},
],
}),
new ToolMessage({
tool_call_id: "321_retriever_tool",
content: "Why don't you check yourself?",
}),
new ToolMessage({
tool_call_id: "654_retriever_tool",
content: "The weather in Paris, France is horrible.",
}),
],
}),
new ToolMessage({
tool_call_id: "123_retriever_tool",
content: "The weather in Berkeley, CA is 70 degrees and sunny.",
}),
output: {
converseSystem: [
{
text: "You're an advanced AI assistant.",
},
],
converseMessages: [
{
role: "user",
content: [
{
text: "What's the weather like today in Berkeley, CA and in Paris, France? Use weather.com to check.",
},
],
},
{
role: "assistant",
content: [
{
toolUse: {
name: "retrieverTool",
toolUseId: "123_retriever_tool",
input: {
url: "https://weather.com",
},
},
},
{
toolUse: {
name: "retrieverTool",
toolUseId: "456_retriever_tool",
input: {
url: "https://weather.com",
},
},
},
],
},
{
role: "user",
content: [
{
toolResult: {
toolUseId: "123_retriever_tool",
content: [
{
text: "The weather in Berkeley, CA is 70 degrees and sunny.",
},
],
},
},
{
toolResult: {
toolUseId: "456_retriever_tool",
content: [
{
text: "The weather in Paris, France is perfect.",
},
],
},
},
],
},
{
role: "user",
content: [
{
text: "What's the weather like today in Berkeley, CA and in Paris, France? Use meteofrance.com to check.",
},
],
},
{
role: "assistant",
content: [
{
toolUse: {
name: "retrieverTool",
toolUseId: "321_retriever_tool",
input: {
url: "https://meteofrance.com",
},
},
},
{
toolUse: {
name: "retrieverTool",
toolUseId: "654_retriever_tool",
input: {
url: "https://meteofrance.com",
},
},
},
],
},
{
role: "user",
content: [
{
toolResult: {
toolUseId: "321_retriever_tool",
content: [
{
text: "Why don't you check yourself?",
},
],
},
},
{
toolResult: {
toolUseId: "654_retriever_tool",
content: [
{
text: "The weather in Paris, France is horrible.",
},
],
},
},
],
},
],
},
},
];

const { converseMessages, converseSystem } =
convertToConverseMessages(messages);

expect(converseSystem).toHaveLength(1);
expect(converseSystem[0].text).toBe("You're an advanced AI assistant.");

expect(converseMessages).toHaveLength(3);

const userMsgs = converseMessages.filter((msg) => msg.role === "user");
// Length of two because of the first user question, and tool use
// messages will have the user role.
expect(userMsgs).toHaveLength(2);
const textUserMsg = userMsgs.find((msg) => msg.content?.[0].text);
expect(textUserMsg?.content?.[0].text).toBe(
"What's the weather like today in Berkeley, CA? Use weather.com to check."
it.each(testCases.map((tc) => [tc.name, tc]))(
"convertToConverseMessages: case %s",
(_, tc) => {
const { converseMessages, converseSystem } = convertToConverseMessages(
tc.input
);
expect(converseMessages).toEqual(tc.output.converseMessages);
expect(converseSystem).toEqual(tc.output.converseSystem);
}
);

const toolUseUserMsg = userMsgs.find((msg) => msg.content?.[0].toolResult);
expect(toolUseUserMsg).toBeDefined();
expect(toolUseUserMsg?.content).toHaveLength(1);
if (!toolUseUserMsg?.content?.length) return;

const toolResultContent = toolUseUserMsg.content[0];
expect(toolResultContent).toBeDefined();
expect(toolResultContent.toolResult?.toolUseId).toBe("123_retriever_tool");
expect(toolResultContent.toolResult?.content?.[0].text).toBe(
"The weather in Berkeley, CA is 70 degrees and sunny."
);

const assistantMsg = converseMessages.find((msg) => msg.role === "assistant");
expect(assistantMsg).toBeDefined();
if (!assistantMsg) return;

const toolUseContent = assistantMsg.content?.find((c) => "toolUse" in c);
expect(toolUseContent).toBeDefined();
expect(toolUseContent?.toolUse?.name).toBe("retrieverTool");
expect(toolUseContent?.toolUse?.toolUseId).toBe("123_retriever_tool");
expect(toolUseContent?.toolUse?.input).toEqual({
url: "https://weather.com",
});
});

test("Streaming supports empty string chunks", async () => {
Expand Down
Loading