openai[minor]: Add support for json schema response format (langchain-ai#6438)

* openai[minor]: Add support for json schema response format

* cr

* lowkey I did it?

* add tests and more implementation details

* fix build yo

* cr

* fix build
bracesproul authored and CarterMorris committed Nov 10, 2024
1 parent 7edbf71 commit f3d7e9e
Showing 6 changed files with 403 additions and 19 deletions.
langchain-core/src/language_models/base.ts (2 changes: 1 addition & 1 deletion)

@@ -269,7 +269,7 @@ export type StructuredOutputType = z.infer<z.ZodObject<any, any, any, any>>;
 export type StructuredOutputMethodOptions<IncludeRaw extends boolean = false> =
   {
     name?: string;
-    method?: "functionCalling" | "jsonMode";
+    method?: "functionCalling" | "jsonMode" | "jsonSchema" | string;
     includeRaw?: IncludeRaw;
   };
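For context, a minimal usage sketch of the new "jsonSchema" method (the model name and schema are illustrative, not part of this diff):

```ts
import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";

// Illustrative Zod schema; any object schema works here.
const joke = z.object({
  setup: z.string().describe("The setup of the joke"),
  punchline: z.string().describe("The punchline of the joke"),
});

const model = new ChatOpenAI({ model: "gpt-4o-2024-08-06" });

// "jsonSchema" routes through OpenAI's native json_schema response
// format instead of function calling or plain JSON mode.
const structuredModel = model.withStructuredOutput(joke, {
  name: "joke",
  method: "jsonSchema",
});

const result = await structuredModel.invoke("Tell me a joke about cats");
// result is parsed and validated against the Zod schema.
```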
libs/langchain-openai/langchain.config.js (2 changes: 1 addition & 1 deletion)

@@ -11,7 +11,7 @@ function abs(relativePath) {
 
 
 export const config = {
-  internals: [/node\:/, /@langchain\/core\//],
+  internals: [/node\:/, /@langchain\/core\//, "openai/helpers/zod"],
   entrypoints: {
     index: "index",
   },
libs/langchain-openai/src/chat_models.ts (154 changes: 139 additions & 15 deletions)

@@ -36,7 +36,6 @@ import {
   type StructuredOutputMethodParams,
 } from "@langchain/core/language_models/base";
 import { NewTokenIndices } from "@langchain/core/callbacks/base";
-import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
 import { z } from "zod";
 import {
   Runnable,
@@ -56,12 +55,20 @@
 } from "@langchain/core/output_parsers/openai_tools";
 import { zodToJsonSchema } from "zod-to-json-schema";
 import { ToolCallChunk } from "@langchain/core/messages/tool";
+import { zodResponseFormat } from "openai/helpers/zod";
+import type {
+  ResponseFormatText,
+  ResponseFormatJSONObject,
+  ResponseFormatJSONSchema,
+} from "openai/resources/shared";
+import { ParsedChatCompletion } from "openai/resources/beta/chat/completions.mjs";
 import type {
   AzureOpenAIInput,
   OpenAICallOptions,
   OpenAIChatInput,
   OpenAICoreRequestOptions,
   LegacyOpenAIInput,
+  ChatOpenAIResponseFormat,
 } from "./types.js";
 import { type OpenAIEndpointConfig, getEndpoint } from "./utils/azure.js";
 import {
@@ -73,6 +80,7 @@
   FunctionDef,
   formatFunctionDefinitions,
 } from "./utils/openai-format-fndef.js";
+import { _convertToOpenAITool } from "./utils/tools.js";
 
 export type { AzureOpenAIInput, OpenAICallOptions, OpenAIChatInput };
 
@@ -295,7 +303,7 @@ function _convertChatOpenAIToolTypeToOpenAITool(
 
     return tool;
   }
-  return convertToOpenAITool(tool, fields);
+  return _convertToOpenAITool(tool, fields);
 }
 
 export interface ChatOpenAIStructuredOutputMethodOptions<
Expand Down Expand Up @@ -324,7 +332,7 @@ export interface ChatOpenAICallOptions
tools?: ChatOpenAIToolType[];
tool_choice?: OpenAIToolChoice;
promptIndex?: number;
response_format?: { type: "json_object" };
response_format?: ChatOpenAIResponseFormat;
seed?: number;
/**
* Additional options to pass to streamed completions.
@@ -1027,6 +1035,34 @@ export class ChatOpenAI<
     } as Partial<CallOptions>);
   }
 
+  private createResponseFormat(
+    resFormat?: CallOptions["response_format"]
+  ):
+    | ResponseFormatText
+    | ResponseFormatJSONObject
+    | ResponseFormatJSONSchema
+    | undefined {
+    if (
+      resFormat &&
+      resFormat.type === "json_schema" &&
+      resFormat.json_schema.schema &&
+      isZodSchema(resFormat.json_schema.schema)
+    ) {
+      return zodResponseFormat(
+        resFormat.json_schema.schema,
+        resFormat.json_schema.name,
+        {
+          description: resFormat.json_schema.description,
+        }
+      );
+    }
+    return resFormat as
+      | ResponseFormatText
+      | ResponseFormatJSONObject
+      | ResponseFormatJSONSchema
+      | undefined;
+  }
+
   /**
    * Get the parameters used to invoke the model
    */
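createResponseFormat defers to the OpenAI SDK's zodResponseFormat helper when the schema in json_schema.schema is a Zod schema. A minimal sketch of what that helper returns (the schema, name, and description are illustrative):

```ts
import { z } from "zod";
import { zodResponseFormat } from "openai/helpers/zod";

// Illustrative schema.
const step = z.object({
  explanation: z.string(),
  output: z.string(),
});

// zodResponseFormat converts a Zod schema into OpenAI's
// { type: "json_schema", json_schema: { name, schema, ... } } shape.
const responseFormat = zodResponseFormat(step, "step", {
  description: "A single reasoning step",
});

console.log(responseFormat.type); // "json_schema"
```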
@@ -1049,6 +1085,7 @@
     } else if (this.streamUsage && (this.streaming || extra?.streaming)) {
       streamOptionsConfig = { stream_options: { include_usage: true } };
     }
+
     const params: Omit<
       OpenAIClient.Chat.ChatCompletionCreateParams,
       "messages"
@@ -1075,7 +1112,7 @@
           )
         : undefined,
       tool_choice: formatToOpenAIToolChoice(options?.tool_choice),
-      response_format: options?.response_format,
+      response_format: this.createResponseFormat(options?.response_format),
       seed: options?.seed,
       ...streamOptionsConfig,
       parallel_tool_calls: options?.parallel_tool_calls,
@@ -1113,6 +1150,32 @@
       stream: true as const,
     };
     let defaultRole: OpenAIRoleEnum | undefined;
+    if (
+      params.response_format &&
+      params.response_format.type === "json_schema"
+    ) {
+      console.warn(
+        `OpenAI does not yet support streaming with "response_format" set to "json_schema". Falling back to non-streaming mode.`
+      );
+      const res = await this._generate(messages, options, runManager);
+      const chunk = new ChatGenerationChunk({
+        message: new AIMessageChunk({
+          ...res.generations[0].message,
+        }),
+        text: res.generations[0].text,
+        generationInfo: res.generations[0].generationInfo,
+      });
+      yield chunk;
+      return runManager?.handleLLMNewToken(
+        res.generations[0].text ?? "",
+        undefined,
+        undefined,
+        undefined,
+        undefined,
+        { chunk }
+      );
+    }
+
     const streamIterable = await this.completionWithRetry(params, options);
     let usage: OpenAIClient.Completions.CompletionUsage | undefined;
     for await (const data of streamIterable) {
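Because OpenAI did not support streaming when response_format is json_schema, the branch above generates the full completion and yields it as one chunk. A caller-side sketch of that behavior, reusing the illustrative model and joke schema from the first example:

```ts
// Streaming with a json_schema response format falls back to a single
// aggregated chunk rather than incremental tokens.
const stream = await model.stream("Tell me a joke about cats", {
  response_format: {
    type: "json_schema",
    json_schema: {
      name: "joke",
      schema: joke, // Zod schemas are converted via zodResponseFormat
    },
  },
});

for await (const chunk of stream) {
  // With the fallback, this loop body runs exactly once.
  console.log(chunk.content);
}
```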
@@ -1248,17 +1311,36 @@
       tokenUsage.totalTokens = promptTokenUsage + completionTokenUsage;
       return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
     } else {
-      const data = await this.completionWithRetry(
-        {
-          ...params,
-          stream: false,
-          messages: messagesMapped,
-        },
-        {
-          signal: options?.signal,
-          ...options?.options,
-        }
-      );
+      let data;
+      if (
+        options.response_format &&
+        options.response_format.type === "json_schema"
+      ) {
+        data = await this.betaParsedCompletionWithRetry(
+          {
+            ...params,
+            stream: false,
+            messages: messagesMapped,
+          },
+          {
+            signal: options?.signal,
+            ...options?.options,
+          }
+        );
+      } else {
+        data = await this.completionWithRetry(
+          {
+            ...params,
+            stream: false,
+            messages: messagesMapped,
+          },
+          {
+            signal: options?.signal,
+            ...options?.options,
+          }
+        );
+      }
+
       const {
         completion_tokens: completionTokens,
         prompt_tokens: promptTokens,
@@ -1478,6 +1560,31 @@
     });
   }
 
+  /**
+   * Call the beta chat completions parse endpoint. This should only be called
+   * when response_format is set to "json_schema".
+   * @param {OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming} request
+   * @param {OpenAICoreRequestOptions | undefined} options
+   */
+  async betaParsedCompletionWithRetry(
+    request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
+    options?: OpenAICoreRequestOptions
+  ): Promise<ParsedChatCompletion<null>> {
+    const requestOptions = this._getClientOptions(options);
+    return this.caller.call(async () => {
+      try {
+        const res = await this.client.beta.chat.completions.parse(
+          request,
+          requestOptions
+        );
+        return res;
+      } catch (e) {
+        const error = wrapOpenAIClientError(e);
+        throw error;
+      }
+    });
+  }
+
   protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
     if (!this.client) {
       const openAIEndpointConfig: OpenAIEndpointConfig = {
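betaParsedCompletionWithRetry wraps the OpenAI SDK's beta parse endpoint in the same retry and error-wrapping logic as completionWithRetry. A standalone sketch of the underlying SDK call (model, messages, and schema are illustrative):

```ts
import OpenAI from "openai";
import { z } from "zod";
import { zodResponseFormat } from "openai/helpers/zod";

const client = new OpenAI();

// The beta parse endpoint validates the completion against the supplied
// response_format and exposes the result on message.parsed.
const completion = await client.beta.chat.completions.parse({
  model: "gpt-4o-2024-08-06",
  messages: [{ role: "user", content: "Extract: Jane is 31 years old." }],
  response_format: zodResponseFormat(
    z.object({ name: z.string(), age: z.number() }),
    "person"
  ),
});

console.log(completion.choices[0].message.parsed); // e.g. { name: "Jane", age: 31 }
```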
@@ -1620,6 +1727,23 @@ export class ChatOpenAI<
       } else {
         outputParser = new JsonOutputParser<RunOutput>();
       }
+    } else if (method === "jsonSchema") {
+      llm = this.bind({
+        response_format: {
+          type: "json_schema",
+          json_schema: {
+            name: name ?? "extract",
+            description: schema.description,
+            schema,
+            strict: config?.strict,
+          },
+        },
+      } as Partial<CallOptions>);
+      if (isZodSchema(schema)) {
+        outputParser = StructuredOutputParser.fromZodSchema(schema);
+      } else {
+        outputParser = new JsonOutputParser<RunOutput>();
+      }
     } else {
       let functionName = name ?? "extract";
       // Is function calling
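Per the isZodSchema check above, a plain JSON schema object is also accepted; the output is then parsed with JsonOutputParser rather than validated by Zod. A minimal sketch (the schema and strict flag are illustrative):

```ts
// A raw JSON schema object works too; output is parsed as JSON but not
// validated against a Zod schema.
const structuredFromJson = model.withStructuredOutput(
  {
    type: "object",
    properties: {
      setup: { type: "string" },
      punchline: { type: "string" },
    },
    required: ["setup", "punchline"],
    additionalProperties: false,
  },
  { name: "joke", method: "jsonSchema", strict: true }
);
```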