core[minor]: Standardize tool choice #6111

Merged: 4 commits, Jul 17, 2024
21 changes: 20 additions & 1 deletion langchain-core/src/language_models/chat_models.ts
@@ -46,6 +46,9 @@ import { concat } from "../utils/stream.js";
import { RunnablePassthrough } from "../runnables/passthrough.js";
import { isZodSchema } from "../utils/types/is_zod_schema.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type ToolChoice = string | Record<string, any> | "auto" | "any";

/**
* Represents a serialized chat model.
*/
@@ -73,7 +76,23 @@ export type BaseChatModelParams = BaseLanguageModelParams;
/**
* Represents the call options for a base chat model.
*/
export type BaseChatModelCallOptions = BaseLanguageModelCallOptions;
export type BaseChatModelCallOptions = BaseLanguageModelCallOptions & {
/**
* Specifies how the chat model should use tools.
* @default undefined
*
* Possible values:
* - "auto": The model may choose to use any of the provided tools, or none.
* - "any": The model must use one of the provided tools.
* - "none": The model must not use any tools.
* - A string (not "auto", "any", or "none"): The name of a specific tool the model must use.
* - An object: A custom schema specifying tool choice parameters. Specific to the provider.
*
* Note: Not all providers support tool_choice. An error will be thrown
* if used with a model that does not support it.
*/
tool_choice?: ToolChoice;
};

/**
* Creates a transform stream for encoding chat message chunks.
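For orientation, a minimal sketch of how the new standardized call option might be used from application code; the tool, its schema, and the model name below are illustrative placeholders, not part of this diff:

```ts
import { ChatAnthropic } from "@langchain/anthropic";
import { tool } from "@langchain/core/tools";
import { z } from "zod";

// Hypothetical tool used only for illustration.
const getWeather = tool(async ({ city }) => `It is sunny in ${city}.`, {
  name: "get_weather",
  description: "Get the current weather for a city",
  schema: z.object({ city: z.string() }),
});

const model = new ChatAnthropic({ model: "claude-3-5-sonnet-20240620" });

// "auto" lets the model decide, "any" forces it to call some bound tool,
// and a bare tool name (e.g. "get_weather") forces that specific tool.
const result = await model
  .bindTools([getWeather], { tool_choice: "any" })
  .invoke("What's the weather in Paris?");
```

Because the option now lives on BaseChatModelCallOptions, the same shape can be passed to any provider that supports it; unsupporting providers throw, per the note above.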
42 changes: 11 additions & 31 deletions libs/langchain-anthropic/src/chat_models.ts
@@ -21,12 +21,12 @@ import {
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import {
BaseChatModel,
BaseChatModelCallOptions,
LangSmithParams,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import {
type StructuredOutputMethodOptions,
type BaseLanguageModelCallOptions,
type BaseLanguageModelInput,
type ToolDefinition,
isOpenAITool,
@@ -53,30 +53,23 @@ import {
extractToolCalls,
} from "./output_parsers.js";
import { AnthropicToolResponse } from "./types.js";
import {
AnthropicToolChoice,
AnthropicToolTypes,
handleToolChoice,
} from "./utils.js";

type AnthropicMessage = Anthropic.MessageParam;
type AnthropicMessageCreateParams = Anthropic.MessageCreateParamsNonStreaming;
type AnthropicStreamingMessageCreateParams =
Anthropic.MessageCreateParamsStreaming;
type AnthropicMessageStreamEvent = Anthropic.MessageStreamEvent;
type AnthropicRequestOptions = Anthropic.RequestOptions;
type AnthropicToolChoice =
| {
type: "tool";
name: string;
}
| "any"
| "auto";

export interface ChatAnthropicCallOptions
extends BaseLanguageModelCallOptions,
extends BaseChatModelCallOptions,
Pick<AnthropicInput, "streamUsage"> {
tools?: (
| StructuredToolInterface
| AnthropicTool
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike
)[];
tools?: AnthropicToolTypes[];
/**
* Whether or not to specify what tool the model should use
* @default "auto"
@@ -855,24 +848,11 @@ export class ChatAnthropicMessages<
"messages"
> &
Kwargs {
let tool_choice:
const tool_choice:
| MessageCreateParams.ToolChoiceAuto
| MessageCreateParams.ToolChoiceAny
| MessageCreateParams.ToolChoiceTool
| undefined;
if (options?.tool_choice) {
if (options?.tool_choice === "any") {
tool_choice = {
type: "any",
};
} else if (options?.tool_choice === "auto") {
tool_choice = {
type: "auto",
};
} else {
tool_choice = options?.tool_choice;
}
}
| undefined = handleToolChoice(options?.tool_choice);

return {
model: this.model,
51 changes: 51 additions & 0 deletions libs/langchain-anthropic/src/utils.ts
@@ -0,0 +1,51 @@
import type {
MessageCreateParams,
Tool as AnthropicTool,
} from "@anthropic-ai/sdk/resources/index.mjs";
import { ToolDefinition } from "@langchain/core/language_models/base";
import { RunnableToolLike } from "@langchain/core/runnables";
import { StructuredToolInterface } from "@langchain/core/tools";

export type AnthropicToolChoice =
| {
type: "tool";
name: string;
}
| "any"
| "auto"
| "none"
| string;

export type AnthropicToolTypes =
| StructuredToolInterface
| AnthropicTool
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike;

export function handleToolChoice(
toolChoice?: AnthropicToolChoice
):
| MessageCreateParams.ToolChoiceAuto
| MessageCreateParams.ToolChoiceAny
| MessageCreateParams.ToolChoiceTool
| undefined {
if (!toolChoice) {
return undefined;
} else if (toolChoice === "any") {
return {
type: "any",
};
} else if (toolChoice === "auto") {
return {
type: "auto",
};
} else if (typeof toolChoice === "string") {
return {
type: "tool",
name: toolChoice,
};
} else {
return toolChoice;
}
}
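As a quick illustration of the helper above, here is how the different input shapes are normalized (the "get_weather" name is a hypothetical example):

```ts
import { handleToolChoice } from "./utils.js";

// String shorthands are normalized into Anthropic's structured shapes.
handleToolChoice("auto");        // -> { type: "auto" }
handleToolChoice("any");         // -> { type: "any" }
handleToolChoice("get_weather"); // -> { type: "tool", name: "get_weather" }

// Already-structured values pass through, and undefined stays undefined.
handleToolChoice({ type: "tool", name: "get_weather" }); // -> returned as-is
handleToolChoice(undefined);                             // -> undefined
```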
13 changes: 5 additions & 8 deletions libs/langchain-aws/src/chat_models.ts
@@ -2,14 +2,14 @@ import type { BaseMessage } from "@langchain/core/messages";
import { AIMessageChunk } from "@langchain/core/messages";
import type {
ToolDefinition,
BaseLanguageModelCallOptions,
BaseLanguageModelInput,
} from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
type BaseChatModelParams,
BaseChatModel,
LangSmithParams,
BaseChatModelCallOptions,
} from "@langchain/core/language_models/chat_models";
import type {
ToolConfiguration,
@@ -30,11 +30,7 @@ import {
import type { DocumentType as __DocumentType } from "@smithy/types";
import { StructuredToolInterface } from "@langchain/core/tools";
import { Runnable, RunnableToolLike } from "@langchain/core/runnables";
import {
BedrockToolChoice,
ConverseCommandParams,
CredentialType,
} from "./types.js";
import { ConverseCommandParams, CredentialType } from "./types.js";
import {
convertToConverseTools,
convertToBedrockToolChoice,
@@ -43,6 +39,7 @@ import {
handleConverseStreamContentBlockDelta,
handleConverseStreamMetadata,
handleConverseStreamContentBlockStart,
BedrockConverseToolChoice,
} from "./common.js";

/**
@@ -127,7 +124,7 @@ export interface ChatBedrockConverseInput
}

export interface ChatBedrockConverseCallOptions
extends BaseLanguageModelCallOptions,
extends BaseChatModelCallOptions,
Pick<
ChatBedrockConverseInput,
"additionalModelRequestFields" | "streamUsage"
@@ -149,7 +146,7 @@ export interface ChatBedrockConverseCallOptions
* or whether to generate text instead.
* If a tool name is passed, it will force the model to call that specific tool.
*/
tool_choice?: "any" | "auto" | string | BedrockToolChoice;
tool_choice?: BedrockConverseToolChoice;
}

/**
8 changes: 7 additions & 1 deletion libs/langchain-aws/src/common.ts
@@ -258,8 +258,14 @@ export function convertToConverseTools(
);
}

export type BedrockConverseToolChoice =
| "any"
| "auto"
| string
| BedrockToolChoice;

export function convertToBedrockToolChoice(
toolChoice: string | BedrockToolChoice,
toolChoice: BedrockConverseToolChoice,
tools: BedrockTool[]
): BedrockToolChoice {
if (typeof toolChoice === "string") {
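A hedged usage sketch of the widened type from the caller's side; the model id, region, and tool name are illustrative placeholders, and the actual string-to-object conversion happens inside convertToBedrockToolChoice above:

```ts
import { ChatBedrockConverse } from "@langchain/aws";

const model = new ChatBedrockConverse({
  model: "anthropic.claude-3-sonnet-20240229-v1:0",
  region: "us-east-1",
});

// BedrockConverseToolChoice accepts "auto", "any", a bare tool name,
// or a raw Bedrock ToolChoice object such as { tool: { name: "get_weather" } }.
const result = await model
  .bindTools([/* ...tools to expose... */])
  .invoke("What's the weather in Paris?", { tool_choice: "get_weather" });
```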
2 changes: 1 addition & 1 deletion libs/langchain-groq/src/chat_models.ts
@@ -69,7 +69,7 @@ import { ToolCallChunk } from "@langchain/core/messages/tool";
export interface ChatGroqCallOptions extends BaseChatModelCallOptions {
headers?: Record<string, string>;
tools?: OpenAIClient.ChatCompletionTool[];
tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption;
tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption | "any" | string;
response_format?: { type: "json_object" };
}

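For completeness, a small sketch of the widened Groq option; the model id is a placeholder, the tools array is elided, and how bare names are mapped to Groq's API is handled internally rather than shown in this diff:

```ts
import { ChatGroq } from "@langchain/groq";

const model = new ChatGroq({ model: "llama3-70b-8192" });

// In addition to the OpenAI-style values ("auto", "none", "required",
// or { type: "function", function: { name: "..." } }), the option now
// also accepts "any" and a bare tool name as shorthands.
const result = await model
  .bindTools([/* ...tools to expose... */])
  .invoke("What's the weather in Paris?", { tool_choice: "any" });
```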
10 changes: 7 additions & 3 deletions libs/langchain-openai/src/chat_models.ts
@@ -64,7 +64,11 @@ import type {
LegacyOpenAIInput,
} from "./types.js";
import { type OpenAIEndpointConfig, getEndpoint } from "./utils/azure.js";
import { wrapOpenAIClientError } from "./utils/openai.js";
import {
OpenAIToolChoice,
formatToOpenAIToolChoice,
wrapOpenAIClientError,
} from "./utils/openai.js";
import {
FunctionDef,
formatFunctionDefinitions,
@@ -274,7 +278,7 @@ export interface ChatOpenAICallOptions
extends OpenAICallOptions,
BaseFunctionCallOptions {
tools?: StructuredToolInterface[] | OpenAIClient.ChatCompletionTool[];
tool_choice?: OpenAIClient.ChatCompletionToolChoiceOption;
tool_choice?: OpenAIToolChoice;
promptIndex?: number;
response_format?: { type: "json_object" };
seed?: number;
@@ -613,7 +617,7 @@ export class ChatOpenAI<
tools: isStructuredToolArray(options?.tools)
? options?.tools.map(convertToOpenAITool)
: options?.tools,
tool_choice: options?.tool_choice,
tool_choice: formatToOpenAIToolChoice(options?.tool_choice),
response_format: options?.response_format,
seed: options?.seed,
...streamOptionsConfig,
34 changes: 33 additions & 1 deletion libs/langchain-openai/src/utils/openai.ts
@@ -1,4 +1,8 @@
import { APIConnectionTimeoutError, APIUserAbortError } from "openai";
import {
APIConnectionTimeoutError,
APIUserAbortError,
OpenAI as OpenAIClient,
} from "openai";
import { zodToJsonSchema } from "zod-to-json-schema";
import type { StructuredToolInterface } from "@langchain/core/tools";
import {
@@ -36,3 +40,31 @@ export function formatToOpenAIAssistantTool(tool: StructuredToolInterface) {
},
};
}

export type OpenAIToolChoice =
| OpenAIClient.ChatCompletionToolChoiceOption
| "any"
| string;

export function formatToOpenAIToolChoice(
toolChoice?: OpenAIToolChoice
): OpenAIClient.ChatCompletionToolChoiceOption | undefined {
if (!toolChoice) {
return undefined;
} else if (toolChoice === "any" || toolChoice === "required") {
return "required";
} else if (toolChoice === "auto") {
return "auto";
} else if (toolChoice === "none") {
return "none";
} else if (typeof toolChoice === "string") {
return {
type: "function",
function: {
name: toolChoice,
},
};
} else {
return toolChoice;
}
}
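The mapping implemented above, sketched as input/output pairs (the "get_weather" name is a hypothetical example):

```ts
import { formatToOpenAIToolChoice } from "./utils/openai.js";

formatToOpenAIToolChoice("any");      // -> "required" ("any" is normalized)
formatToOpenAIToolChoice("required"); // -> "required"
formatToOpenAIToolChoice("auto");     // -> "auto"
formatToOpenAIToolChoice("none");     // -> "none"

// Any other string is treated as the name of a specific tool to force.
formatToOpenAIToolChoice("get_weather");
// -> { type: "function", function: { name: "get_weather" } }

// Structured objects pass through unchanged; undefined stays undefined.
formatToOpenAIToolChoice(undefined);  // -> undefined
```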