Skip to content

Commit

Permalink
aws[minor]: Implement WSO with tool_choice (#6443)
Browse files Browse the repository at this point in the history
* aws[minor]: Implement WSO with tool_choice

* cr

* chore: lint files

* cr

* add more tests

* chore: lint files

* allow users to pass tool choice supported values
  • Loading branch information
bracesproul authored Aug 7, 2024
1 parent 937d1fe commit 7321ca0
Show file tree
Hide file tree
Showing 4 changed files with 431 additions and 12 deletions.
1 change: 1 addition & 0 deletions libs/langchain-aws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/credential-provider-node": "^3.600.0",
"@langchain/core": ">=0.2.21 <0.3.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.22.5"
},
"devDependencies": {
Expand Down
194 changes: 187 additions & 7 deletions libs/langchain-aws/src/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import type { BaseMessage } from "@langchain/core/messages";
import { AIMessageChunk } from "@langchain/core/messages";
import type { BaseLanguageModelInput } from "@langchain/core/language_models/base";
import type {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
} from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
type BaseChatModelParams,
Expand All @@ -24,12 +28,15 @@ import {
DefaultProviderInit,
} from "@aws-sdk/credential-provider-node";
import type { DocumentType as __DocumentType } from "@smithy/types";
import { Runnable } from "@langchain/core/runnables";
import {
ChatBedrockConverseToolType,
ConverseCommandParams,
CredentialType,
} from "./types.js";
Runnable,
RunnableLambda,
RunnablePassthrough,
RunnableSequence,
} from "@langchain/core/runnables";
import { zodToJsonSchema } from "zod-to-json-schema";
import { isZodSchema } from "@langchain/core/utils/types";
import { z } from "zod";
import {
convertToConverseTools,
convertToBedrockToolChoice,
Expand All @@ -40,6 +47,11 @@ import {
handleConverseStreamContentBlockStart,
BedrockConverseToolChoice,
} from "./common.js";
import {
ChatBedrockConverseToolType,
ConverseCommandParams,
CredentialType,
} from "./types.js";

/**
* Inputs for ChatBedrockConverse.
Expand Down Expand Up @@ -120,6 +132,14 @@ export interface ChatBedrockConverseInput
* Configuration information for a guardrail that you want to use in the request.
*/
guardrailConfig?: GuardrailConfiguration;

/**
* Which types of `tool_choice` values the model supports.
*
* Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3'
* model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise.
*/
supportsToolChoiceValues?: Array<"auto" | "any" | "tool">;
}

export interface ChatBedrockConverseCallOptions
Expand Down Expand Up @@ -214,6 +234,14 @@ export class ChatBedrockConverse

client: BedrockRuntimeClient;

/**
* Which types of `tool_choice` values the model supports.
*
* Inferred if not specified. Inferred as ['auto', 'any', 'tool'] if a 'claude-3'
* model is used, ['auto', 'any'] if a 'mistral-large' model is used, empty otherwise.
*/
supportsToolChoiceValues?: Array<"auto" | "any" | "tool">;

constructor(fields?: ChatBedrockConverseInput) {
super(fields ?? {});
const {
Expand Down Expand Up @@ -264,6 +292,18 @@ export class ChatBedrockConverse
this.additionalModelRequestFields = rest?.additionalModelRequestFields;
this.streamUsage = rest?.streamUsage ?? this.streamUsage;
this.guardrailConfig = rest?.guardrailConfig;

if (rest?.supportsToolChoiceValues === undefined) {
if (this.model.includes("claude-3")) {
this.supportsToolChoiceValues = ["auto", "any", "tool"];
} else if (this.model.includes("mistral-large")) {
this.supportsToolChoiceValues = ["auto", "any"];
} else {
this.supportsToolChoiceValues = undefined;
}
} else {
this.supportsToolChoiceValues = rest.supportsToolChoiceValues;
}
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
Expand Down Expand Up @@ -303,7 +343,10 @@ export class ChatBedrockConverse
toolConfig = {
tools,
toolChoice: options.tool_choice
? convertToBedrockToolChoice(options.tool_choice, tools)
? convertToBedrockToolChoice(options.tool_choice, tools, {
model: this.model,
supportsToolChoiceValues: this.supportsToolChoiceValues,
})
: undefined,
};
}
Expand Down Expand Up @@ -430,4 +473,141 @@ export class ChatBedrockConverse
}
}
}

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<false>
): Runnable<BaseLanguageModelInput, RunOutput>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<true>
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<boolean>
):
| Runnable<BaseLanguageModelInput, RunOutput>
| Runnable<
BaseLanguageModelInput,
{
raw: BaseMessage;
parsed: RunOutput;
}
> {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const schema: z.ZodType<RunOutput> | Record<string, any> = outputSchema;
const name = config?.name;
const description = schema.description ?? "A function available to call.";
const method = config?.method;
const includeRaw = config?.includeRaw;
if (method === "jsonMode") {
throw new Error(`ChatBedrockConverse does not support 'jsonMode'.`);
}

let functionName = name ?? "extract";
let tools: ToolDefinition[];
if (isZodSchema(schema)) {
tools = [
{
type: "function",
function: {
name: functionName,
description,
parameters: zodToJsonSchema(schema),
},
},
];
} else {
if ("name" in schema) {
functionName = schema.name;
}
tools = [
{
type: "function",
function: {
name: functionName,
description,
parameters: schema,
},
},
];
}

const supportsToolChoiceValues = this.supportsToolChoiceValues ?? [];
let toolChoiceObj: { tool_choice: string } | undefined;
if (supportsToolChoiceValues.includes("tool")) {
toolChoiceObj = {
tool_choice: tools[0].function.name,
};
} else if (supportsToolChoiceValues.includes("any")) {
toolChoiceObj = {
tool_choice: "any",
};
}

const llm = this.bindTools(tools, toolChoiceObj);
const outputParser = RunnableLambda.from<AIMessageChunk, RunOutput>(
(input: AIMessageChunk): RunOutput => {
if (!input.tool_calls || input.tool_calls.length === 0) {
throw new Error("No tool calls found in the response.");
}
const toolCall = input.tool_calls.find(
(tc) => tc.name === functionName
);
if (!toolCall) {
throw new Error(`No tool call found with name ${functionName}.`);
}
return toolCall.args as RunOutput;
}
);

if (!includeRaw) {
return llm.pipe(outputParser).withConfig({
runName: "StructuredOutput",
}) as Runnable<BaseLanguageModelInput, RunOutput>;
}

const parserAssign = RunnablePassthrough.assign({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
parsed: (input: any, config) => outputParser.invoke(input.raw, config),
});
const parserNone = RunnablePassthrough.assign({
parsed: () => null,
});
const parsedWithFallback = parserAssign.withFallbacks({
fallbacks: [parserNone],
});
return RunnableSequence.from<
BaseLanguageModelInput,
{ raw: BaseMessage; parsed: RunOutput }
>([
{
raw: llm,
},
parsedWithFallback,
]).withConfig({
runName: "StructuredOutputRunnable",
});
}
}
44 changes: 39 additions & 5 deletions libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,27 @@ export type BedrockConverseToolChoice =

export function convertToBedrockToolChoice(
toolChoice: BedrockConverseToolChoice,
tools: BedrockTool[]
tools: BedrockTool[],
fields: {
model: string;
supportsToolChoiceValues?: Array<"auto" | "any" | "tool">;
}
): BedrockToolChoice {
const supportsToolChoiceValues = fields.supportsToolChoiceValues ?? [];

let bedrockToolChoice: BedrockToolChoice;
if (typeof toolChoice === "string") {
switch (toolChoice) {
case "any":
return {
bedrockToolChoice = {
any: {},
};
break;
case "auto":
return {
bedrockToolChoice = {
auto: {},
};
break;
default: {
const foundTool = tools.find(
(tool) => tool.toolSpec?.name === toolChoice
Expand All @@ -292,15 +301,40 @@ export function convertToBedrockToolChoice(
`Tool with name ${toolChoice} not found in tools list.`
);
}
return {
bedrockToolChoice = {
tool: {
name: toolChoice,
},
};
}
}
} else {
bedrockToolChoice = toolChoice;
}

const toolChoiceType = Object.keys(bedrockToolChoice)[0] as
| "auto"
| "any"
| "tool";
if (!supportsToolChoiceValues.includes(toolChoiceType)) {
let supportedTxt = "";
if (supportsToolChoiceValues.length) {
supportedTxt =
`Model ${fields.model} does not currently support 'tool_choice' ` +
`of type ${toolChoiceType}. The following 'tool_choice' types ` +
`are supported: ${supportsToolChoiceValues.join(", ")}.`;
} else {
supportedTxt = `Model ${fields.model} does not currently support 'tool_choice'.`;
}

throw new Error(
`${supportedTxt} Please see` +
"https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html" +
"for the latest documentation on models that support tool choice."
);
}
return toolChoice;

return bedrockToolChoice;
}

export function convertConverseMessageToLangChainMessage(
Expand Down
Loading

0 comments on commit 7321ca0

Please sign in to comment.