diff --git a/langchain-core/src/messages/ai.ts b/langchain-core/src/messages/ai.ts index 9c542af4e28b..c8a5311f040c 100644 --- a/langchain-core/src/messages/ai.ts +++ b/langchain-core/src/messages/ai.ts @@ -143,6 +143,10 @@ export function isAIMessage(x: BaseMessage): x is AIMessage { return x._getType() === "ai"; } +export function isAIMessageChunk(x: BaseMessageChunk): x is AIMessageChunk { + return x._getType() === "ai"; +} + export type AIMessageChunkFields = AIMessageFields & { tool_call_chunks?: ToolCallChunk[]; }; diff --git a/langchain-core/src/output_parsers/openai_tools/json_output_tools_parsers.ts b/langchain-core/src/output_parsers/openai_tools/json_output_tools_parsers.ts index fad6f82206a5..fe7987617d60 100644 --- a/langchain-core/src/output_parsers/openai_tools/json_output_tools_parsers.ts +++ b/langchain-core/src/output_parsers/openai_tools/json_output_tools_parsers.ts @@ -1,8 +1,13 @@ import { z } from "zod"; -import { ChatGeneration } from "../../outputs.js"; -import { BaseLLMOutputParser, OutputParserException } from "../base.js"; +import { ChatGeneration, ChatGenerationChunk } from "../../outputs.js"; +import { OutputParserException } from "../base.js"; import { parsePartialJson } from "../json.js"; import { InvalidToolCall, ToolCall } from "../../messages/tool.js"; +import { + BaseCumulativeTransformOutputParser, + BaseCumulativeTransformOutputParserInput, +} from "../transform.js"; +import { isAIMessage } from "../../messages/ai.js"; export type ParsedToolCall = { id?: string; @@ -23,7 +28,7 @@ export type ParsedToolCall = { export type JsonOutputToolsParserParams = { /** Whether to return the tool call id. */ returnId?: boolean; -}; +} & BaseCumulativeTransformOutputParserInput; export function parseToolCall( // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -35,6 +40,11 @@ export function parseToolCall( rawToolCall: Record, options?: { returnId?: boolean; partial?: false } ): ToolCall; +export function parseToolCall( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + rawToolCall: Record, + options?: { returnId?: boolean; partial?: boolean } +): ToolCall | undefined; export function parseToolCall( // eslint-disable-next-line @typescript-eslint/no-explicit-any rawToolCall: Record, @@ -112,9 +122,9 @@ export function makeInvalidToolCall( /** * Class for parsing the output of a tool-calling LLM into a JSON object. */ -export class JsonOutputToolsParser extends BaseLLMOutputParser< - ParsedToolCall[] -> { +export class JsonOutputToolsParser< + T +> extends BaseCumulativeTransformOutputParser { static lc_name() { return "JsonOutputToolsParser"; } @@ -130,31 +140,64 @@ export class JsonOutputToolsParser extends BaseLLMOutputParser< this.returnId = fields?.returnId ?? this.returnId; } + protected _diff() { + throw new Error("Not supported."); + } + + async parse(): Promise { + throw new Error("Not implemented."); + } + + async parseResult(generations: ChatGeneration[]): Promise { + const result = await this.parsePartialResult(generations, false); + return result; + } + /** * Parses the output and returns a JSON object. If `argsOnly` is true, * only the arguments of the function call are returned. * @param generations The output of the LLM to parse. * @returns A JSON object representation of the function call or its arguments. */ - async parseResult(generations: ChatGeneration[]): Promise { - const toolCalls = generations[0].message.additional_kwargs.tool_calls; - if (!toolCalls) { - throw new Error( - `No tools_call in message ${JSON.stringify(generations)}` + async parsePartialResult( + generations: ChatGenerationChunk[] | ChatGeneration[], + partial = true + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ): Promise { + const message = generations[0].message; + let toolCalls; + if (isAIMessage(message) && message.tool_calls?.length) { + toolCalls = message.tool_calls.map((toolCall) => { + const { id, ...rest } = toolCall; + if (!this.returnId) { + return rest; + } + return { + id, + ...rest, + }; + }); + } else if (message.additional_kwargs.tool_calls !== undefined) { + const rawToolCalls = JSON.parse( + JSON.stringify(message.additional_kwargs.tool_calls) ); + toolCalls = rawToolCalls.map((rawToolCall: Record) => { + return parseToolCall(rawToolCall, { returnId: this.returnId, partial }); + }); + } + if (!toolCalls) { + return []; } - const clonedToolCalls = JSON.parse(JSON.stringify(toolCalls)); const parsedToolCalls = []; - for (const toolCall of clonedToolCalls) { - const parsedToolCall = parseToolCall(toolCall, { partial: true }); - if (parsedToolCall !== undefined) { + for (const toolCall of toolCalls) { + if (toolCall !== undefined) { // backward-compatibility with previous // versions of Langchain JS, which uses `name` and `arguments` // @ts-expect-error name and arguemnts are defined by Object.defineProperty const backwardsCompatibleToolCall: ParsedToolCall = { - type: parsedToolCall.name, - args: parsedToolCall.args, - id: parsedToolCall.id, + type: toolCall.name, + args: toolCall.args, + id: toolCall.id, }; Object.defineProperty(backwardsCompatibleToolCall, "name", { get() { @@ -180,10 +223,8 @@ export type JsonOutputKeyToolsParserParams< > = { keyName: string; returnSingle?: boolean; - /** Whether to return the tool call id. */ - returnId?: boolean; zodSchema?: z.ZodType; -}; +} & JsonOutputToolsParserParams; /** * Class for parsing the output of a tool-calling LLM into a JSON object if you are @@ -192,7 +233,7 @@ export type JsonOutputKeyToolsParserParams< export class JsonOutputKeyToolsParser< // eslint-disable-next-line @typescript-eslint/no-explicit-any T extends Record = Record -> extends BaseLLMOutputParser { +> extends JsonOutputToolsParser { static lc_name() { return "JsonOutputKeyToolsParser"; } @@ -209,15 +250,12 @@ export class JsonOutputKeyToolsParser< /** Whether to return only the first tool call. */ returnSingle = false; - initialParser: JsonOutputToolsParser; - zodSchema?: z.ZodType; constructor(params: JsonOutputKeyToolsParserParams) { super(params); this.keyName = params.keyName; this.returnSingle = params.returnSingle ?? this.returnSingle; - this.initialParser = new JsonOutputToolsParser(params); this.zodSchema = params.zodSchema; } @@ -240,17 +278,45 @@ export class JsonOutputKeyToolsParser< } } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async parsePartialResult(generations: ChatGeneration[]): Promise { + const results = await super.parsePartialResult(generations); + const matchingResults = results.filter( + (result: ParsedToolCall) => result.type === this.keyName + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let returnedValues: ParsedToolCall[] | Record[] = + matchingResults; + if (!matchingResults.length) { + return undefined; + } + if (!this.returnId) { + returnedValues = matchingResults.map( + (result: ParsedToolCall) => result.args + ); + } + if (this.returnSingle) { + return returnedValues[0]; + } + return returnedValues; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any async parseResult(generations: ChatGeneration[]): Promise { - const results = await this.initialParser.parseResult(generations); + const results = await super.parsePartialResult(generations, false); const matchingResults = results.filter( - (result) => result.type === this.keyName + (result: ParsedToolCall) => result.type === this.keyName ); // eslint-disable-next-line @typescript-eslint/no-explicit-any let returnedValues: ParsedToolCall[] | Record[] = matchingResults; + if (!matchingResults.length) { + return undefined; + } if (!this.returnId) { - returnedValues = matchingResults.map((result) => result.args); + returnedValues = matchingResults.map( + (result: ParsedToolCall) => result.args + ); } if (this.returnSingle) { return this._validateResult(returnedValues[0]); diff --git a/langchain-core/src/output_parsers/openai_tools/tests/json_output_tools_parser.test.ts b/langchain-core/src/output_parsers/openai_tools/tests/json_output_tools_parser.test.ts index bffe7e0249af..131a4d59dd39 100644 --- a/langchain-core/src/output_parsers/openai_tools/tests/json_output_tools_parser.test.ts +++ b/langchain-core/src/output_parsers/openai_tools/tests/json_output_tools_parser.test.ts @@ -1,8 +1,10 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { test, expect } from "@jest/globals"; import { z } from "zod"; import { JsonOutputKeyToolsParser } from "../json_output_tools_parsers.js"; -import { AIMessage } from "../../../messages/index.js"; import { OutputParserException } from "../../base.js"; +import { AIMessage, AIMessageChunk } from "../../../messages/ai.js"; +import { RunnableLambda } from "../../../runnables/base.js"; test("JSONOutputKeyToolsParser invoke", async () => { const outputParser = new JsonOutputKeyToolsParser({ @@ -87,3 +89,144 @@ test("JSONOutputKeyToolsParser can validate a proper input", async () => { ); expect(result).toEqual({ testKey: "testval" }); }); + +test("JSONOutputKeyToolsParser invoke with a top-level tool call", async () => { + const outputParser = new JsonOutputKeyToolsParser({ + keyName: "testing", + returnSingle: true, + }); + const result = await outputParser.invoke( + new AIMessage({ + content: "", + tool_calls: [ + { + id: "test", + name: "testing", + args: { testKey: 9 }, + }, + ], + }) + ); + expect(result).toEqual({ testKey: 9 }); +}); + +test("JSONOutputKeyToolsParser with a top-level tool call and passed schema throws", async () => { + const outputParser = new JsonOutputKeyToolsParser({ + keyName: "testing", + returnSingle: true, + zodSchema: z.object({ + testKey: z.string(), + }), + }); + try { + await outputParser.invoke( + new AIMessage({ + content: "", + tool_calls: [ + { + id: "test", + name: "testing", + args: { testKey: 9 }, + }, + ], + }) + ); + } catch (e) { + expect(e).toBeInstanceOf(OutputParserException); + } +}); + +test("JSONOutputKeyToolsParser with a top-level tool call can validate a proper input", async () => { + const outputParser = new JsonOutputKeyToolsParser({ + keyName: "testing", + returnSingle: true, + zodSchema: z.object({ + testKey: z.string(), + }), + }); + const result = await outputParser.invoke( + new AIMessage({ + content: "", + tool_calls: [ + { + id: "test", + name: "testing", + args: { testKey: "testval" }, + }, + ], + }) + ); + expect(result).toEqual({ testKey: "testval" }); +}); + +test("JSONOutputKeyToolsParser can handle streaming input", async () => { + const outputParser = new JsonOutputKeyToolsParser({ + keyName: "testing", + returnSingle: true, + zodSchema: z.object({ + testKey: z.string(), + }), + }); + const fakeModel = RunnableLambda.from(async function* () { + yield new AIMessageChunk({ + content: "", + tool_call_chunks: [ + { + index: 0, + id: "test", + name: "testing", + args: `{ "testKey":`, + type: "tool_call_chunk", + }, + ], + }); + yield new AIMessageChunk({ + content: "", + tool_call_chunks: [], + }); + yield new AIMessageChunk({ + content: "", + tool_call_chunks: [ + { + index: 0, + id: "test", + args: ` "testv`, + type: "tool_call_chunk", + }, + ], + }); + yield new AIMessageChunk({ + content: "", + tool_call_chunks: [ + { + index: 0, + id: "test", + args: `al" }`, + type: "tool_call_chunk", + }, + ], + }); + }); + const stream = await (fakeModel as any).pipe(outputParser).stream(); + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + expect(chunks.length).toBeGreaterThan(1); + expect(chunks.at(-1)).toEqual({ testKey: "testval" }); + // TODO: Fix typing issue + const result = await (fakeModel as any).pipe(outputParser).invoke( + new AIMessage({ + content: "", + tool_calls: [ + { + id: "test", + name: "testing", + args: { testKey: "testval" }, + type: "tool_call", + }, + ], + }) + ); + expect(result).toEqual({ testKey: "testval" }); +}); diff --git a/langchain-core/src/output_parsers/transform.ts b/langchain-core/src/output_parsers/transform.ts index 64647e10feb1..384c3285f74c 100644 --- a/langchain-core/src/output_parsers/transform.ts +++ b/langchain-core/src/output_parsers/transform.ts @@ -134,4 +134,8 @@ export abstract class BaseCumulativeTransformOutputParser< } } } + + getFormatInstructions(): string { + return ""; + } } diff --git a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts index c6a6c97dbeda..bb3057dbe7d0 100644 --- a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts @@ -25,10 +25,8 @@ test("withStructuredOutput zod schema function calling", async () => { ); const prompt = ChatPromptTemplate.fromMessages([ - "system", - "You are VERY bad at math and must always use a calculator.", - "human", - "Please help me!! What is 2 + 2?", + ["system", "You are VERY bad at math and must always use a calculator."], + ["human", "Please help me!! What is 2 + 2?"], ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({}); @@ -38,6 +36,41 @@ test("withStructuredOutput zod schema function calling", async () => { expect("number2" in result).toBe(true); }); +test("withStructuredOutput zod schema streaming", async () => { + const model = new ChatOpenAI({ + temperature: 0, + modelName: "gpt-4-turbo-preview", + }); + + const calculatorSchema = z.object({ + operation: z.enum(["add", "subtract", "multiply", "divide"]), + number1: z.number(), + number2: z.number(), + }); + const modelWithStructuredOutput = model.withStructuredOutput( + calculatorSchema, + { + name: "calculator", + } + ); + + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are VERY bad at math and must always use a calculator."], + ["human", "Please help me!! What is 2 + 2?"], + ]); + const chain = prompt.pipe(modelWithStructuredOutput); + const stream = await chain.stream({}); + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + expect(chunks.length).toBeGreaterThan(1); + const result = chunks.at(-1) ?? {}; + expect("operation" in result).toBe(true); + expect("number1" in result).toBe(true); + expect("number2" in result).toBe(true); +}); + test("withStructuredOutput zod schema JSON mode", async () => { const model = new ChatOpenAI({ temperature: 0, @@ -58,15 +91,16 @@ test("withStructuredOutput zod schema JSON mode", async () => { ); const prompt = ChatPromptTemplate.fromMessages([ - "system", - `You are VERY bad at math and must always use a calculator. + [ + "system", + `You are VERY bad at math and must always use a calculator. Respond with a JSON object containing three keys: 'operation': the type of operation to execute, either 'add', 'subtract', 'multiply' or 'divide', 'number1': the first number to operate on, 'number2': the second number to operate on. `, - "human", - "Please help me!! What is 2 + 2?", + ], + ["human", "Please help me!! What is 2 + 2?"], ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({}); @@ -93,10 +127,8 @@ test("withStructuredOutput JSON schema function calling", async () => { }); const prompt = ChatPromptTemplate.fromMessages([ - "system", - `You are VERY bad at math and must always use a calculator.`, - "human", - "Please help me!! What is 2 + 2?", + ["system", `You are VERY bad at math and must always use a calculator.`], + ["human", "Please help me!! What is 2 + 2?"], ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({}); @@ -123,10 +155,8 @@ test("withStructuredOutput OpenAI function definition function calling", async ( }); const prompt = ChatPromptTemplate.fromMessages([ - "system", - `You are VERY bad at math and must always use a calculator.`, - "human", - "Please help me!! What is 2 + 2?", + ["system", `You are VERY bad at math and must always use a calculator.`], + ["human", "Please help me!! What is 2 + 2?"], ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({}); @@ -156,15 +186,16 @@ test("withStructuredOutput JSON schema JSON mode", async () => { ); const prompt = ChatPromptTemplate.fromMessages([ - "system", - `You are VERY bad at math and must always use a calculator. + [ + "system", + `You are VERY bad at math and must always use a calculator. Respond with a JSON object containing three keys: 'operation': the type of operation to execute, either 'add', 'subtract', 'multiply' or 'divide', 'number1': the first number to operate on, 'number2': the second number to operate on. `, - "human", - "Please help me!! What is 2 + 2?", + ], + ["human", "Please help me!! What is 2 + 2?"], ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({}); @@ -196,15 +227,16 @@ test("withStructuredOutput JSON schema", async () => { const modelWithStructuredOutput = model.withStructuredOutput(jsonSchema); const prompt = ChatPromptTemplate.fromMessages([ - "system", - `You are VERY bad at math and must always use a calculator. + [ + "system", + `You are VERY bad at math and must always use a calculator. Respond with a JSON object containing three keys: 'operation': the type of operation to execute, either 'add', 'subtract', 'multiply' or 'divide', 'number1': the first number to operate on, 'number2': the second number to operate on. `, - "human", - "Please help me!! What is 2 + 2?", + ], + ["human", "Please help me!! What is 2 + 2?"], ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({}); @@ -234,10 +266,8 @@ test("withStructuredOutput includeRaw true", async () => { ); const prompt = ChatPromptTemplate.fromMessages([ - "system", - "You are VERY bad at math and must always use a calculator.", - "human", - "Please help me!! What is 2 + 2?", + ["system", "You are VERY bad at math and must always use a calculator."], + ["human", "Please help me!! What is 2 + 2?"], ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({});