Skip to content

Commit

Permalink
langchain[patch]: Make AgentExecutor pass config object through to to…
Browse files Browse the repository at this point in the history
…ols (#4436)

* Make AgentExecutor pass config object through to tools

* Formatting
  • Loading branch information
jacoblee93 committed Feb 16, 2024
1 parent c12bc7f commit 8c918dc
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
14 changes: 11 additions & 3 deletions langchain/src/agents/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import {
ToolInputParsingException,
Tool,
} from "@langchain/core/tools";
import { Runnable, type RunnableConfig } from "@langchain/core/runnables";
import {
Runnable,
type RunnableConfig,
patchConfig,
} from "@langchain/core/runnables";
import { AgentAction, AgentFinish, AgentStep } from "@langchain/core/agents";
import { ChainValues } from "@langchain/core/utils/types";
import {
Expand Down Expand Up @@ -428,7 +432,8 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
/** @ignore */
async _call(
inputs: ChainValues,
runManager?: CallbackManagerForChainRun
runManager?: CallbackManagerForChainRun,
config?: RunnableConfig
): Promise<AgentExecutorOutput> {
const toolsByName = Object.fromEntries(
this.tools.map((t) => [t.name.toLowerCase(), t])
Expand Down Expand Up @@ -511,7 +516,10 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
let observation;
try {
observation = tool
? await tool.call(action.toolInput, runManager?.getChild())
? await tool.invoke(
action.toolInput,
patchConfig(config, { callbacks: runManager?.getChild() })
)
: `${action.tool} is not a valid tool, try another one.`;
} catch (e) {
// eslint-disable-next-line no-instanceof/no-instanceof
Expand Down
7 changes: 4 additions & 3 deletions langchain/src/chains/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ export abstract class BaseChain<
try {
outputValues = await (fullValues.signal
? (Promise.race([
this._call(fullValues as RunInput, runManager),
this._call(fullValues as RunInput, runManager, config),
new Promise((_, reject) => {
fullValues.signal?.addEventListener("abort", () => {
reject(new Error("AbortError"));
});
}),
]) as Promise<RunOutput>)
: this._call(fullValues as RunInput, runManager));
: this._call(fullValues as RunInput, runManager, config));
} catch (e) {
await runManager?.handleChainError(e);
throw e;
Expand Down Expand Up @@ -165,7 +165,8 @@ export abstract class BaseChain<
*/
abstract _call(
values: RunInput,
runManager?: CallbackManagerForChainRun
runManager?: CallbackManagerForChainRun,
config?: RunnableConfig
): Promise<RunOutput>;

/**
Expand Down

0 comments on commit 8c918dc

Please sign in to comment.