Skip to content

Commit

Permalink
feat(core,openai): Adds streaming support for OpenAI withStructuredOu…
Browse files Browse the repository at this point in the history
…tput (#6721)
  • Loading branch information
jacoblee93 authored Sep 10, 2024
1 parent 93fb71f commit b7dfae0
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 57 deletions.
4 changes: 4 additions & 0 deletions langchain-core/src/messages/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ export function isAIMessage(x: BaseMessage): x is AIMessage {
return x._getType() === "ai";
}

/**
 * Type guard that narrows a `BaseMessageChunk` to `AIMessageChunk`.
 * The check is based solely on the value returned by `_getType()`.
 * @param x The message chunk to inspect.
 * @returns `true` when `x._getType()` is `"ai"`, otherwise `false`.
 */
export function isAIMessageChunk(x: BaseMessageChunk): x is AIMessageChunk {
return x._getType() === "ai";
}

/**
 * `AIMessageFields` extended with an optional `tool_call_chunks` array.
 */
export type AIMessageChunkFields = AIMessageFields & {
/** Optional list of `ToolCallChunk` entries carried by the chunk. */
tool_call_chunks?: ToolCallChunk[];
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import { z } from "zod";
import { ChatGeneration } from "../../outputs.js";
import { BaseLLMOutputParser, OutputParserException } from "../base.js";
import { ChatGeneration, ChatGenerationChunk } from "../../outputs.js";
import { OutputParserException } from "../base.js";
import { parsePartialJson } from "../json.js";
import { InvalidToolCall, ToolCall } from "../../messages/tool.js";
import {
BaseCumulativeTransformOutputParser,
BaseCumulativeTransformOutputParserInput,
} from "../transform.js";
import { isAIMessage } from "../../messages/ai.js";

export type ParsedToolCall = {
id?: string;
Expand All @@ -23,7 +28,7 @@ export type ParsedToolCall = {
export type JsonOutputToolsParserParams = {
/** Whether to return the tool call id. */
returnId?: boolean;
};
} & BaseCumulativeTransformOutputParserInput;

export function parseToolCall(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -35,6 +40,11 @@ export function parseToolCall(
rawToolCall: Record<string, any>,
options?: { returnId?: boolean; partial?: false }
): ToolCall;
export function parseToolCall(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
rawToolCall: Record<string, any>,
options?: { returnId?: boolean; partial?: boolean }
): ToolCall | undefined;
export function parseToolCall(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
rawToolCall: Record<string, any>,
Expand Down Expand Up @@ -112,9 +122,9 @@ export function makeInvalidToolCall(
/**
* Class for parsing the output of a tool-calling LLM into a JSON object.
*/
export class JsonOutputToolsParser extends BaseLLMOutputParser<
ParsedToolCall[]
> {
export class JsonOutputToolsParser<
T
> extends BaseCumulativeTransformOutputParser<T> {
static lc_name() {
return "JsonOutputToolsParser";
}
Expand All @@ -130,31 +140,64 @@ export class JsonOutputToolsParser extends BaseLLMOutputParser<
this.returnId = fields?.returnId ?? this.returnId;
}

/**
 * Diffing is not available on this parser.
 * @throws Error Always, with the message "Not supported.".
 */
protected _diff() {
throw new Error("Not supported.");
}

/**
 * Text-based parsing is not implemented on this parser; use `parseResult`
 * or `parsePartialResult`, which operate on generations.
 * Because the method is `async`, the error surfaces as a rejected promise.
 * @throws Error Always, with the message "Not implemented.".
 */
async parse(): Promise<T> {
throw new Error("Not implemented.");
}

/**
 * Parses a finished set of generations by delegating to
 * `parsePartialResult` with its `partial` flag set to `false`.
 * @param generations The generations to parse.
 * @returns Whatever `parsePartialResult` resolves to for these generations.
 */
async parseResult(generations: ChatGeneration[]): Promise<T> {
  return this.parsePartialResult(generations, false);
}

/**
* Parses the output and returns a JSON object. If `argsOnly` is true,
* only the arguments of the function call are returned.
* @param generations The output of the LLM to parse.
* @returns A JSON object representation of the function call or its arguments.
*/
async parseResult(generations: ChatGeneration[]): Promise<ParsedToolCall[]> {
const toolCalls = generations[0].message.additional_kwargs.tool_calls;
if (!toolCalls) {
throw new Error(
`No tools_call in message ${JSON.stringify(generations)}`
async parsePartialResult(
generations: ChatGenerationChunk[] | ChatGeneration[],
partial = true
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): Promise<any> {
const message = generations[0].message;
let toolCalls;
if (isAIMessage(message) && message.tool_calls?.length) {
toolCalls = message.tool_calls.map((toolCall) => {
const { id, ...rest } = toolCall;
if (!this.returnId) {
return rest;
}
return {
id,
...rest,
};
});
} else if (message.additional_kwargs.tool_calls !== undefined) {
const rawToolCalls = JSON.parse(
JSON.stringify(message.additional_kwargs.tool_calls)
);
toolCalls = rawToolCalls.map((rawToolCall: Record<string, unknown>) => {
return parseToolCall(rawToolCall, { returnId: this.returnId, partial });
});
}
if (!toolCalls) {
return [];
}
const clonedToolCalls = JSON.parse(JSON.stringify(toolCalls));
const parsedToolCalls = [];
for (const toolCall of clonedToolCalls) {
const parsedToolCall = parseToolCall(toolCall, { partial: true });
if (parsedToolCall !== undefined) {
for (const toolCall of toolCalls) {
if (toolCall !== undefined) {
// backward-compatibility with previous
// versions of Langchain JS, which uses `name` and `arguments`
// @ts-expect-error name and arguments are defined by Object.defineProperty
const backwardsCompatibleToolCall: ParsedToolCall = {
type: parsedToolCall.name,
args: parsedToolCall.args,
id: parsedToolCall.id,
type: toolCall.name,
args: toolCall.args,
id: toolCall.id,
};
Object.defineProperty(backwardsCompatibleToolCall, "name", {
get() {
Expand All @@ -180,10 +223,8 @@ export type JsonOutputKeyToolsParserParams<
> = {
keyName: string;
returnSingle?: boolean;
/** Whether to return the tool call id. */
returnId?: boolean;
zodSchema?: z.ZodType<T>;
};
} & JsonOutputToolsParserParams;

/**
* Class for parsing the output of a tool-calling LLM into a JSON object if you are
Expand All @@ -192,7 +233,7 @@ export type JsonOutputKeyToolsParserParams<
export class JsonOutputKeyToolsParser<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends Record<string, any> = Record<string, any>
> extends BaseLLMOutputParser<T> {
> extends JsonOutputToolsParser<T> {
static lc_name() {
return "JsonOutputKeyToolsParser";
}
Expand All @@ -209,15 +250,12 @@ export class JsonOutputKeyToolsParser<
/** Whether to return only the first tool call. */
returnSingle = false;

initialParser: JsonOutputToolsParser;

zodSchema?: z.ZodType<T>;

constructor(params: JsonOutputKeyToolsParserParams<T>) {
super(params);
this.keyName = params.keyName;
this.returnSingle = params.returnSingle ?? this.returnSingle;
this.initialParser = new JsonOutputToolsParser(params);
this.zodSchema = params.zodSchema;
}

Expand All @@ -240,17 +278,45 @@ export class JsonOutputKeyToolsParser<
}
}

/**
 * Parses generations via the parent parser's `parsePartialResult`, then keeps
 * only the tool calls whose `type` equals `keyName`.
 * @param generations The generations to parse.
 * @returns `undefined` when no tool call matches `keyName`. Otherwise the
 * matching tool calls (or just their `args` when `returnId` is falsy),
 * reduced to the first entry when `returnSingle` is set.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
async parsePartialResult(generations: ChatGeneration[]): Promise<any> {
  const parsedCalls = await super.parsePartialResult(generations);
  const callsForKey = parsedCalls.filter(
    (call: ParsedToolCall) => call.type === this.keyName
  );
  // Nothing matched the configured key: signal "no value" rather than [].
  if (callsForKey.length === 0) {
    return undefined;
  }
  // Without `returnId`, callers only get the argument payloads.
  const values = this.returnId
    ? callsForKey
    : callsForKey.map((call: ParsedToolCall) => call.args);
  return this.returnSingle ? values[0] : values;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
async parseResult(generations: ChatGeneration[]): Promise<any> {
const results = await this.initialParser.parseResult(generations);
const results = await super.parsePartialResult(generations, false);
const matchingResults = results.filter(
(result) => result.type === this.keyName
(result: ParsedToolCall) => result.type === this.keyName
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let returnedValues: ParsedToolCall[] | Record<string, any>[] =
matchingResults;
if (!matchingResults.length) {
return undefined;
}
if (!this.returnId) {
returnedValues = matchingResults.map((result) => result.args);
returnedValues = matchingResults.map(
(result: ParsedToolCall) => result.args
);
}
if (this.returnSingle) {
return this._validateResult(returnedValues[0]);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { test, expect } from "@jest/globals";
import { z } from "zod";
import { JsonOutputKeyToolsParser } from "../json_output_tools_parsers.js";
import { AIMessage } from "../../../messages/index.js";
import { OutputParserException } from "../../base.js";
import { AIMessage, AIMessageChunk } from "../../../messages/ai.js";
import { RunnableLambda } from "../../../runnables/base.js";

test("JSONOutputKeyToolsParser invoke", async () => {
const outputParser = new JsonOutputKeyToolsParser({
Expand Down Expand Up @@ -87,3 +89,144 @@ test("JSONOutputKeyToolsParser can validate a proper input", async () => {
);
expect(result).toEqual({ testKey: "testval" });
});

// A message whose tool call is supplied through the top-level `tool_calls`
// field should yield that call's args when its name matches `keyName`.
test("JSONOutputKeyToolsParser invoke with a top-level tool call", async () => {
  const parser = new JsonOutputKeyToolsParser({
    keyName: "testing",
    returnSingle: true,
  });
  const message = new AIMessage({
    content: "",
    tool_calls: [{ id: "test", name: "testing", args: { testKey: 9 } }],
  });
  const parsed = await parser.invoke(message);
  expect(parsed).toEqual({ testKey: 9 });
});

// A top-level tool call whose args violate the supplied zod schema must make
// `invoke` reject with an OutputParserException.
test("JSONOutputKeyToolsParser with a top-level tool call and passed schema throws", async () => {
  const outputParser = new JsonOutputKeyToolsParser({
    keyName: "testing",
    returnSingle: true,
    zodSchema: z.object({
      testKey: z.string(),
    }),
  });
  // `testKey` is a number here, so it cannot satisfy the `z.string()` schema.
  const invalidMessage = new AIMessage({
    content: "",
    tool_calls: [
      {
        id: "test",
        name: "testing",
        args: { testKey: 9 },
      },
    ],
  });
  // Assert on the rejection itself. With the assertion placed only inside a
  // `catch` block, the test would pass silently if `invoke` never threw.
  await expect(outputParser.invoke(invalidMessage)).rejects.toBeInstanceOf(
    OutputParserException
  );
});

// A top-level tool call whose args satisfy the supplied zod schema should be
// returned unchanged.
test("JSONOutputKeyToolsParser with a top-level tool call can validate a proper input", async () => {
  const schema = z.object({ testKey: z.string() });
  const parser = new JsonOutputKeyToolsParser({
    keyName: "testing",
    returnSingle: true,
    zodSchema: schema,
  });
  const message = new AIMessage({
    content: "",
    tool_calls: [{ id: "test", name: "testing", args: { testKey: "testval" } }],
  });
  const parsed = await parser.invoke(message);
  expect(parsed).toEqual({ testKey: "testval" });
});

// Streams tool call chunks whose `args` fragments concatenate to
// `{ "testKey": "testval" }` through the parser, then checks that the same
// pipeline also works through `invoke`.
test("JSONOutputKeyToolsParser can handle streaming input", async () => {
  const parser = new JsonOutputKeyToolsParser({
    keyName: "testing",
    returnSingle: true,
    zodSchema: z.object({
      testKey: z.string(),
    }),
  });
  // Stand-in for a chat model: ignores its input and emits a fixed sequence
  // of chunks, including one with no tool call chunks at all.
  const streamingModel = RunnableLambda.from(async function* () {
    yield new AIMessageChunk({
      content: "",
      tool_call_chunks: [
        {
          index: 0,
          id: "test",
          name: "testing",
          args: `{ "testKey":`,
          type: "tool_call_chunk",
        },
      ],
    });
    yield new AIMessageChunk({ content: "", tool_call_chunks: [] });
    yield new AIMessageChunk({
      content: "",
      tool_call_chunks: [
        { index: 0, id: "test", args: ` "testv`, type: "tool_call_chunk" },
      ],
    });
    yield new AIMessageChunk({
      content: "",
      tool_call_chunks: [
        { index: 0, id: "test", args: `al" }`, type: "tool_call_chunk" },
      ],
    });
  });
  // TODO: Fix typing issue
  const buildChain = () => (streamingModel as any).pipe(parser);

  const emitted = [];
  for await (const piece of await buildChain().stream()) {
    emitted.push(piece);
  }
  expect(emitted.length).toBeGreaterThan(1);
  expect(emitted[emitted.length - 1]).toEqual({ testKey: "testval" });

  const finalMessage = new AIMessage({
    content: "",
    tool_calls: [
      {
        id: "test",
        name: "testing",
        args: { testKey: "testval" },
        type: "tool_call",
      },
    ],
  });
  const invoked = await buildChain().invoke(finalMessage);
  expect(invoked).toEqual({ testKey: "testval" });
});
4 changes: 4 additions & 0 deletions langchain-core/src/output_parsers/transform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,8 @@ export abstract class BaseCumulativeTransformOutputParser<
}
}
}

/**
 * Default format instructions for this base class.
 * @returns An empty string.
 */
getFormatInstructions(): string {
return "";
}
}
Loading

0 comments on commit b7dfae0

Please sign in to comment.