Skip to content

Commit

Permalink
core[patch]: Treat OpenAI message format as a message like (#6654)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Aug 28, 2024
1 parent 0d85e8b commit dc4497f
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 3 deletions.
58 changes: 58 additions & 0 deletions langchain-core/src/language_models/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,64 @@ test("Test ChatModel accepts object shorthand for messages", async () => {
expect(response.content).toEqual("Hello there!");
});

test("Test ChatModel accepts object with role for messages", async () => {
const model = new FakeChatModel({});
const response = await model.invoke([
{
role: "human",
content: "Hello there!!",
example: true,
},
]);
expect(response.content).toEqual("Hello there!!");
});

test("Test ChatModel accepts several messages as objects with role", async () => {
const model = new FakeChatModel({});
const response = await model.invoke([
{
role: "system",
content: "You are an assistant.",
},
{
role: "human",
content: [{ type: "text", text: "What is the weather in SF?" }],
example: true,
},
{
role: "assistant",
content: "",
tool_calls: [
{
id: "call_123",
function: {
name: "get_weather",
arguments: JSON.stringify({ location: "sf" }),
},
type: "function",
},
],
},
{
role: "tool",
content: "Pretty nice right now!",
tool_call_id: "call_123",
},
]);
expect(response.content).toEqual(
[
"You are an assistant.",
JSON.stringify(
[{ type: "text", text: "What is the weather in SF?" }],
null,
2
),
"",
"Pretty nice right now!",
].join("\n")
);
});

test("Test ChatModel uses callbacks", async () => {
const model = new FakeChatModel({});
let acc = "";
Expand Down
13 changes: 13 additions & 0 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -450,12 +450,25 @@ export abstract class BaseMessageChunk extends BaseMessage {
abstract concat(chunk: BaseMessageChunk): BaseMessageChunk;
}

export type MessageFieldWithRole = {
role: StringWithAutocomplete<"user" | "assistant" | MessageType>;
content: MessageContent;
name?: string;
} & Record<string, unknown>;

export function _isMessageFieldWithRole(
x: BaseMessageLike
): x is MessageFieldWithRole {
return typeof (x as MessageFieldWithRole).role === "string";
}

export type BaseMessageLike =
| BaseMessage
| ({
type: MessageType | "user" | "assistant" | "placeholder";
} & BaseMessageFields &
Record<string, unknown>)
| MessageFieldWithRole
| [
StringWithAutocomplete<
MessageType | "user" | "assistant" | "placeholder"
Expand Down
62 changes: 62 additions & 0 deletions langchain-core/src/messages/tests/base_message.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import {
ToolMessage,
ToolMessageChunk,
AIMessageChunk,
coerceMessageLikeToMessage,
SystemMessage,
} from "../index.js";
import { load } from "../../load/index.js";

Expand Down Expand Up @@ -334,3 +336,63 @@ describe("Complex AIMessageChunk concat", () => {
);
});
});

describe("Message like coercion", () => {
it("Should convert OpenAI format messages", async () => {
const messages = [
{
id: "foobar",
role: "system",
content: "6",
},
{
role: "user",
content: [{ type: "image_url", image_url: { url: "7.1" } }],
},
{
role: "assistant",
content: [{ type: "text", text: "8.1" }],
tool_calls: [
{
id: "8.5",
function: {
name: "8.4",
arguments: JSON.stringify({ "8.2": "8.3" }),
},
type: "function",
},
],
},
{
role: "tool",
content: "10.2",
tool_call_id: "10.2",
},
].map(coerceMessageLikeToMessage);
expect(messages).toEqual([
new SystemMessage({
id: "foobar",
content: "6",
}),
new HumanMessage({
content: [{ type: "image_url", image_url: { url: "7.1" } }],
}),
new AIMessage({
content: [{ type: "text", text: "8.1" }],
tool_calls: [
{
id: "8.5",
name: "8.4",
args: { "8.2": "8.3" },
type: "tool_call",
},
],
}),
new ToolMessage({
name: undefined,
content: "10.2",
tool_call_id: "10.2",
}),
]);
});
});
48 changes: 45 additions & 3 deletions langchain-core/src/messages/utils.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { _isToolCall } from "../tools/utils.js";
import { AIMessage, AIMessageChunk, AIMessageChunkFields } from "./ai.js";
import {
BaseMessageLike,
Expand All @@ -6,6 +7,7 @@ import {
StoredMessage,
StoredMessageV1,
BaseMessageFields,
_isMessageFieldWithRole,
} from "./base.js";
import {
ChatMessage,
Expand All @@ -19,16 +21,53 @@ import {
} from "./function.js";
import { HumanMessage, HumanMessageChunk } from "./human.js";
import { SystemMessage, SystemMessageChunk } from "./system.js";
import { ToolMessage, ToolMessageFieldsWithToolCallId } from "./tool.js";
import {
ToolCall,
ToolMessage,
ToolMessageFieldsWithToolCallId,
} from "./tool.js";

function _coerceToolCall(
toolCall: ToolCall | Record<string, unknown>
): ToolCall {
if (_isToolCall(toolCall)) {
return toolCall;
} else if (
typeof toolCall.id === "string" &&
toolCall.type === "function" &&
typeof toolCall.function === "object" &&
toolCall.function !== null &&
"arguments" in toolCall.function &&
typeof toolCall.function.arguments === "string" &&
"name" in toolCall.function &&
typeof toolCall.function.name === "string"
) {
// Handle OpenAI tool call format
return {
id: toolCall.id,
args: JSON.parse(toolCall.function.arguments),
name: toolCall.function.name,
type: "tool_call",
};
} else {
// TODO: Throw an error?
return toolCall as ToolCall;
}
}

function _constructMessageFromParams(
params: BaseMessageFields & { type: string }
params: BaseMessageFields & { type: string } & Record<string, unknown>
) {
const { type, ...rest } = params;
if (type === "human" || type === "user") {
return new HumanMessage(rest);
} else if (type === "ai" || type === "assistant") {
return new AIMessage(rest);
const { tool_calls: rawToolCalls, ...other } = rest;
if (!Array.isArray(rawToolCalls)) {
return new AIMessage(rest);
}
const tool_calls = rawToolCalls.map(_coerceToolCall);
return new AIMessage({ ...other, tool_calls });
} else if (type === "system") {
return new SystemMessage(rest);
} else if (type === "tool" && "tool_call_id" in rest) {
Expand Down Expand Up @@ -56,6 +95,9 @@ export function coerceMessageLikeToMessage(
if (Array.isArray(messageLike)) {
const [type, content] = messageLike;
return _constructMessageFromParams({ type, content });
} else if (_isMessageFieldWithRole(messageLike)) {
const { role: type, ...rest } = messageLike;
return _constructMessageFromParams({ ...rest, type });
} else {
return _constructMessageFromParams(messageLike);
}
Expand Down

0 comments on commit dc4497f

Please sign in to comment.