Skip to content

Commit

Permalink
Adds support for Bedrock Anthropic tool streaming (#6070)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Jul 16, 2024
1 parent 2f575fe commit 15fdeea
Show file tree
Hide file tree
Showing 5 changed files with 426 additions and 175 deletions.
247 changes: 113 additions & 134 deletions libs/langchain-community/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import {
} from "@langchain/core/outputs";
import { StructuredToolInterface } from "@langchain/core/tools";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { ToolCall, ToolCallChunk } from "@langchain/core/messages/tool";
import { zodToJsonSchema } from "zod-to-json-schema";

import type { SerializedFields } from "../../load/map_keys.js";
Expand All @@ -42,7 +41,10 @@ import {
BedrockLLMInputOutputAdapter,
type CredentialType,
} from "../../utils/bedrock/index.js";
import { isAnthropicTool } from "../../utils/bedrock/anthropic.js";
import {
_toolsInParams,
isAnthropicTool,
} from "../../utils/bedrock/anthropic.js";

type AnthropicTool = Record<string, unknown>;

Expand Down Expand Up @@ -556,145 +558,122 @@ export class BedrockChat
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
if (this._anthropicTools) {
const { generations } = await this._generateNonStreaming(
messages,
options
);
const result = generations[0].message as AIMessage;
const toolCallChunks: ToolCallChunk[] | undefined =
result.tool_calls?.map((toolCall: ToolCall, index: number) => ({
name: toolCall.name,
args: JSON.stringify(toolCall.args),
id: toolCall.id,
index,
type: "tool_call_chunk",
}));
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: result.content,
additional_kwargs: result.additional_kwargs,
tool_call_chunks: toolCallChunks,
}),
text: generations[0].text,
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(generations[0].text);
} else {
const provider = this.model.split(".")[0];
const service = "bedrock-runtime";

const endpointHost =
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;

const bedrockMethod =
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
? "invoke-with-response-stream"
: "invoke";

const response = await this._signedFetch(messages, options, {
bedrockMethod,
endpointHost,
provider,
});
const provider = this.model.split(".")[0];
const service = "bedrock-runtime";

if (response.status < 200 || response.status >= 300) {
throw Error(
`Failed to access underlying url '${endpointHost}': got ${
response.status
} ${response.statusText}: ${await response.text()}`
);
}
const endpointHost =
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;

if (
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
) {
const reader = response.body?.getReader();
const decoder = new TextDecoder();
for await (const chunk of this._readChunks(reader)) {
const event = this.codec.decode(chunk);
if (
(event.headers[":event-type"] !== undefined &&
event.headers[":event-type"].value !== "chunk") ||
event.headers[":content-type"].value !== "application/json"
) {
throw Error(`Failed to get event chunk: got ${chunk}`);
}
const body = JSON.parse(decoder.decode(event.body));
if (body.message) {
throw new Error(body.message);
}
if (body.bytes !== undefined) {
const chunkResult = JSON.parse(
decoder.decode(
Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)
)
);
if (this.usesMessagesApi) {
const chunk = BedrockLLMInputOutputAdapter.prepareMessagesOutput(
provider,
chunkResult
);
if (chunk === undefined) {
continue;
}
if (
provider === "anthropic" &&
chunk.generationInfo?.usage !== undefined
) {
// Avoid bad aggregation in chunks, rely on final Bedrock data
delete chunk.generationInfo.usage;
}
const finalMetrics =
chunk.generationInfo?.["amazon-bedrock-invocationMetrics"];
if (
finalMetrics != null &&
typeof finalMetrics === "object" &&
isAIMessage(chunk.message)
) {
chunk.message.usage_metadata = {
input_tokens: finalMetrics.inputTokenCount,
output_tokens: finalMetrics.outputTokenCount,
total_tokens:
finalMetrics.inputTokenCount +
finalMetrics.outputTokenCount,
};
}
if (isChatGenerationChunk(chunk)) {
yield chunk;
const bedrockMethod =
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
? "invoke-with-response-stream"
: "invoke";

const response = await this._signedFetch(messages, options, {
bedrockMethod,
endpointHost,
provider,
});

if (response.status < 200 || response.status >= 300) {
throw Error(
`Failed to access underlying url '${endpointHost}': got ${
response.status
} ${response.statusText}: ${await response.text()}`
);
}

if (
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta" ||
provider === "mistral"
) {
const toolsInParams = !_toolsInParams(options);
const reader = response.body?.getReader();
const decoder = new TextDecoder();
for await (const chunk of this._readChunks(reader)) {
const event = this.codec.decode(chunk);
if (
(event.headers[":event-type"] !== undefined &&
event.headers[":event-type"].value !== "chunk") ||
event.headers[":content-type"].value !== "application/json"
) {
throw Error(`Failed to get event chunk: got ${chunk}`);
}
const body = JSON.parse(decoder.decode(event.body));
if (body.message) {
throw new Error(body.message);
}
if (body.bytes !== undefined) {
const chunkResult = JSON.parse(
decoder.decode(
Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)
)
);
if (this.usesMessagesApi) {
const chunk = BedrockLLMInputOutputAdapter.prepareMessagesOutput(
provider,
chunkResult,
{
coerceContentToString: toolsInParams,
}
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(chunk.text);
} else {
const text = BedrockLLMInputOutputAdapter.prepareOutput(
provider,
chunkResult
);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
);
if (chunk === undefined) {
continue;
}
if (
provider === "anthropic" &&
chunk.generationInfo?.usage !== undefined
) {
// Avoid bad aggregation in chunks, rely on final Bedrock data
delete chunk.generationInfo.usage;
}
const finalMetrics =
chunk.generationInfo?.["amazon-bedrock-invocationMetrics"];
if (
finalMetrics != null &&
typeof finalMetrics === "object" &&
isAIMessage(chunk.message)
) {
chunk.message.usage_metadata = {
input_tokens: finalMetrics.inputTokenCount,
output_tokens: finalMetrics.outputTokenCount,
total_tokens:
finalMetrics.inputTokenCount + finalMetrics.outputTokenCount,
};
}
if (isChatGenerationChunk(chunk)) {
yield chunk;
}
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(chunk.text);
} else {
const text = BedrockLLMInputOutputAdapter.prepareOutput(
provider,
chunkResult
);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
}
}
} else {
const json = await response.json();
const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
}
} else {
const json = await response.json();
const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({ content: text }),
});
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(text);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { test, expect } from "@jest/globals";
import { HumanMessage } from "@langchain/core/messages";
import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { concat } from "@langchain/core/utils/stream";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { BedrockChat as BedrockChatWeb } from "../bedrock/web.js";
Expand Down Expand Up @@ -485,3 +486,46 @@ test.skip(".bindTools with openai tool format", async () => {
const { tool_calls } = response;
expect(tool_calls[0].name.toLowerCase()).toBe("weather_tool");
});

test("Streaming tool calls with Anthropic", async () => {
  // Tool input schema: weather lookup for a city, with an optional state.
  const weatherTool = z
    .object({
      city: z.string().describe("The city to get the weather for"),
      state: z.string().describe("The state to get the weather for").optional(),
    })
    .describe("Get the weather for a city");
  const model = new BedrockChatWeb({
    region: process.env.BEDROCK_AWS_REGION,
    model: "anthropic.claude-3-sonnet-20240229-v1:0",
    maxRetries: 0,
    credentials: {
      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
    },
  });
  // Bind the tool in Anthropic's native format (name/description/input_schema).
  const modelWithTools = model.bind({
    tools: [
      {
        name: "weather_tool",
        description: weatherTool.description,
        input_schema: zodToJsonSchema(weatherTool),
      },
    ],
  });
  const stream = await modelWithTools.stream(
    "Whats the weather like in san francisco?"
  );
  // Aggregate all streamed chunks into one message chunk so the
  // tool_call_chunks are merged into complete tool_calls.
  let aggregated;
  for await (const piece of stream) {
    aggregated = aggregated === undefined ? piece : concat(aggregated, piece);
  }
  const firstToolCall = aggregated?.tool_calls?.[0];
  if (firstToolCall === undefined) {
    throw new Error("No tool calls found in response");
  }
  expect(firstToolCall.name).toBe("weather_tool");
  expect(firstToolCall.args?.city).toBeDefined();
});
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* eslint-disable no-process-env */
/* eslint-disable @typescript-eslint/no-non-null-assertion */

import { describe, test, expect } from "@jest/globals";
import { Checkpoint, CheckpointTuple } from "@langchain/langgraph";
Expand Down
Loading

0 comments on commit 15fdeea

Please sign in to comment.