diff --git a/langchain-core/src/language_models/tests/chat_models.test.ts b/langchain-core/src/language_models/tests/chat_models.test.ts index 940cc50802b2..70ff187243e8 100644 --- a/langchain-core/src/language_models/tests/chat_models.test.ts +++ b/langchain-core/src/language_models/tests/chat_models.test.ts @@ -27,6 +27,64 @@ test("Test ChatModel accepts object shorthand for messages", async () => { expect(response.content).toEqual("Hello there!"); }); +test("Test ChatModel accepts object with role for messages", async () => { + const model = new FakeChatModel({}); + const response = await model.invoke([ + { + role: "human", + content: "Hello there!!", + example: true, + }, + ]); + expect(response.content).toEqual("Hello there!!"); +}); + +test("Test ChatModel accepts several messages as objects with role", async () => { + const model = new FakeChatModel({}); + const response = await model.invoke([ + { + role: "system", + content: "You are an assistant.", + }, + { + role: "human", + content: [{ type: "text", text: "What is the weather in SF?" }], + example: true, + }, + { + role: "assistant", + content: "", + tool_calls: [ + { + id: "call_123", + function: { + name: "get_weather", + arguments: JSON.stringify({ location: "sf" }), + }, + type: "function", + }, + ], + }, + { + role: "tool", + content: "Pretty nice right now!", + tool_call_id: "call_123", + }, + ]); + expect(response.content).toEqual( + [ + "You are an assistant.", + JSON.stringify( + [{ type: "text", text: "What is the weather in SF?" }], + null, + 2 + ), + "", + "Pretty nice right now!", + ].join("\n") + ); +}); + test("Test ChatModel uses callbacks", async () => { const model = new FakeChatModel({}); let acc = ""; diff --git a/langchain-core/src/messages/base.ts b/langchain-core/src/messages/base.ts index 639fcf226025..7f46ce761221 100644 --- a/langchain-core/src/messages/base.ts +++ b/langchain-core/src/messages/base.ts @@ -450,12 +450,25 @@ export abstract class BaseMessageChunk extends BaseMessage { abstract concat(chunk: BaseMessageChunk): BaseMessageChunk; } +export type MessageFieldWithRole = { + role: StringWithAutocomplete<"user" | "assistant" | MessageType>; + content: MessageContent; + name?: string; +} & Record; + +export function _isMessageFieldWithRole( + x: BaseMessageLike +): x is MessageFieldWithRole { + return typeof (x as MessageFieldWithRole).role === "string"; +} + export type BaseMessageLike = | BaseMessage | ({ type: MessageType | "user" | "assistant" | "placeholder"; } & BaseMessageFields & Record) + | MessageFieldWithRole | [ StringWithAutocomplete< MessageType | "user" | "assistant" | "placeholder" diff --git a/langchain-core/src/messages/tests/base_message.test.ts b/langchain-core/src/messages/tests/base_message.test.ts index cc6926a1e6b6..0e6883c89dc0 100644 --- a/langchain-core/src/messages/tests/base_message.test.ts +++ b/langchain-core/src/messages/tests/base_message.test.ts @@ -6,6 +6,8 @@ import { ToolMessage, ToolMessageChunk, AIMessageChunk, + coerceMessageLikeToMessage, + SystemMessage, } from "../index.js"; import { load } from "../../load/index.js"; @@ -334,3 +336,63 @@ describe("Complex AIMessageChunk concat", () => { ); }); }); + +describe("Message like coercion", () => { + it("Should convert OpenAI format messages", async () => { + const messages = [ + { + id: "foobar", + role: "system", + content: "6", + }, + { + role: "user", + content: [{ type: "image_url", image_url: { url: "7.1" } }], + }, + { + role: "assistant", + content: [{ type: "text", text: "8.1" }], + tool_calls: [ + { + id: "8.5", + function: { + name: "8.4", + arguments: JSON.stringify({ "8.2": "8.3" }), + }, + type: "function", + }, + ], + }, + { + role: "tool", + content: "10.2", + tool_call_id: "10.2", + }, + ].map(coerceMessageLikeToMessage); + expect(messages).toEqual([ + new SystemMessage({ + id: "foobar", + content: "6", + }), + new HumanMessage({ + content: [{ type: "image_url", image_url: { url: "7.1" } }], + }), + new AIMessage({ + content: [{ type: "text", text: "8.1" }], + tool_calls: [ + { + id: "8.5", + name: "8.4", + args: { "8.2": "8.3" }, + type: "tool_call", + }, + ], + }), + new ToolMessage({ + name: undefined, + content: "10.2", + tool_call_id: "10.2", + }), + ]); + }); +}); diff --git a/langchain-core/src/messages/utils.ts b/langchain-core/src/messages/utils.ts index 4a34e03984ee..b9df034ee923 100644 --- a/langchain-core/src/messages/utils.ts +++ b/langchain-core/src/messages/utils.ts @@ -1,3 +1,4 @@ +import { _isToolCall } from "../tools/utils.js"; import { AIMessage, AIMessageChunk, AIMessageChunkFields } from "./ai.js"; import { BaseMessageLike, @@ -6,6 +7,7 @@ import { StoredMessage, StoredMessageV1, BaseMessageFields, + _isMessageFieldWithRole, } from "./base.js"; import { ChatMessage, @@ -19,16 +21,53 @@ import { } from "./function.js"; import { HumanMessage, HumanMessageChunk } from "./human.js"; import { SystemMessage, SystemMessageChunk } from "./system.js"; -import { ToolMessage, ToolMessageFieldsWithToolCallId } from "./tool.js"; +import { + ToolCall, + ToolMessage, + ToolMessageFieldsWithToolCallId, +} from "./tool.js"; + +function _coerceToolCall( + toolCall: ToolCall | Record +): ToolCall { + if (_isToolCall(toolCall)) { + return toolCall; + } else if ( + typeof toolCall.id === "string" && + toolCall.type === "function" && + typeof toolCall.function === "object" && + toolCall.function !== null && + "arguments" in toolCall.function && + typeof toolCall.function.arguments === "string" && + "name" in toolCall.function && + typeof toolCall.function.name === "string" + ) { + // Handle OpenAI tool call format + return { + id: toolCall.id, + args: JSON.parse(toolCall.function.arguments), + name: toolCall.function.name, + type: "tool_call", + }; + } else { + // TODO: Throw an error? + return toolCall as ToolCall; + } +} function _constructMessageFromParams( - params: BaseMessageFields & { type: string } + params: BaseMessageFields & { type: string } & Record ) { const { type, ...rest } = params; if (type === "human" || type === "user") { return new HumanMessage(rest); } else if (type === "ai" || type === "assistant") { - return new AIMessage(rest); + const { tool_calls: rawToolCalls, ...other } = rest; + if (!Array.isArray(rawToolCalls)) { + return new AIMessage(rest); + } + const tool_calls = rawToolCalls.map(_coerceToolCall); + return new AIMessage({ ...other, tool_calls }); } else if (type === "system") { return new SystemMessage(rest); } else if (type === "tool" && "tool_call_id" in rest) { @@ -56,6 +95,9 @@ export function coerceMessageLikeToMessage( if (Array.isArray(messageLike)) { const [type, content] = messageLike; return _constructMessageFromParams({ type, content }); + } else if (_isMessageFieldWithRole(messageLike)) { + const { role: type, ...rest } = messageLike; + return _constructMessageFromParams({ ...rest, type }); } else { return _constructMessageFromParams(messageLike); }