Skip to content

Commit

Permalink
feat(core): always provide message ID inferred from run ID if not pre…
Browse files Browse the repository at this point in the history
…sent while streaming
  • Loading branch information
dqbd committed Aug 27, 2024
1 parent 160c83c commit ab08f1d
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 47 deletions.
12 changes: 12 additions & 0 deletions langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ export abstract class BaseChatModel<
callOptions,
runManagers?.[0]
)) {
if (chunk.message.id == null) {
const runId = runManagers?.at(0)?.runId;
if (runId != null) chunk.message._updateId(`run-${runId}`);
}
chunk.message.response_metadata = {
...chunk.generationInfo,
...chunk.message.response_metadata,
Expand Down Expand Up @@ -362,6 +366,10 @@ export abstract class BaseChatModel<
);
let aggregated;
for await (const chunk of stream) {
if (chunk.message.id == null) {
const runId = runManagers?.at(0)?.runId;
if (runId != null) chunk.message._updateId(`run-${runId}`);
}
if (aggregated === undefined) {
aggregated = chunk;
} else {
Expand Down Expand Up @@ -397,6 +405,10 @@ export abstract class BaseChatModel<
if (pResult.status === "fulfilled") {
const result = pResult.value;
for (const generation of result.generations) {
if (generation.message.id == null) {
const runId = runManagers?.at(0)?.runId;
if (runId != null) generation.message._updateId(`run-${runId}`);
}
generation.message.response_metadata = {
...generation.generationInfo,
...generation.message.response_metadata,
Expand Down
10 changes: 10 additions & 0 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,16 @@ export abstract class BaseMessage
};
}

// this private method is used to update the ID for the runtime
// value as well as in lc_kwargs for serialisation
_updateId(value: string | undefined) {
this.id = value;

// lc_attributes wouldn't work here, because jest compares the
// whole object
this.lc_kwargs.id = value;
}

get [Symbol.toStringTag]() {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (this.constructor as any).lc_name();
Expand Down
23 changes: 19 additions & 4 deletions langchain-core/src/runnables/tests/runnable_history.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import {
import { ChatPromptTemplate, MessagesPlaceholder } from "../../prompts/chat.js";
import { StringOutputParser } from "../../output_parsers/string.js";

const anyString = expect.any(String) as unknown as string;

// For `BaseChatMessageHistory`
async function getGetSessionHistory(): Promise<
(sessionId: string) => Promise<BaseChatMessageHistory>
Expand Down Expand Up @@ -107,9 +109,15 @@ test("Runnable with message history with a chat model", async () => {
const sessionHistory = await getMessageHistory("2");
expect(await sessionHistory.getMessages()).toEqual([
new HumanMessage("hello"),
new AIMessage("Hello world!"),
new AIMessage({
id: anyString,
content: "Hello world!",
}),
new HumanMessage("good bye"),
new AIMessageChunk("Hello world!"),
new AIMessageChunk({
id: anyString,
content: "Hello world!",
}),
]);
});

Expand All @@ -129,6 +137,7 @@ test("Runnable with message history with a messages in, messages out chain", asy
config: {},
getMessageHistory,
});

const config: RunnableConfig = { configurable: { sessionId: "2" } };
const output = await withHistory.invoke([new HumanMessage("hello")], config);
expect(output.content).toBe("So long and thanks for the fish!!");
Expand All @@ -147,9 +156,15 @@ test("Runnable with message history with a messages in, messages out chain", asy
const sessionHistory = await getMessageHistory("2");
expect(await sessionHistory.getMessages()).toEqual([
new HumanMessage("hello"),
new AIMessage("So long and thanks for the fish!!"),
new AIMessage({
id: anyString,
content: "So long and thanks for the fish!!",
}),
new HumanMessage("good bye"),
new AIMessageChunk("So long and thanks for the fish!!"),
new AIMessageChunk({
id: anyString,
content: "So long and thanks for the fish!!",
}),
]);
});

Expand Down
22 changes: 12 additions & 10 deletions langchain-core/src/runnables/tests/runnable_stream_events.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ function reverse(s: string) {
return s.split("").reverse().join("");
}

const anyString = expect.any(String) as unknown as string;

const originalCallbackValue = process.env.LANGCHAIN_CALLBACKS_BACKGROUND;

afterEach(() => {
Expand Down Expand Up @@ -733,7 +735,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
ls_stop: undefined,
},
name: "my_model",
data: { chunk: new AIMessageChunk("R") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "R" }) },
},
{
event: "on_chain_stream",
Expand All @@ -743,7 +745,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
foo: "bar",
},
name: "my_chain",
data: { chunk: new AIMessageChunk("R") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "R" }) },
},
{
event: "on_llm_stream",
Expand All @@ -756,7 +758,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
ls_stop: undefined,
},
name: "my_model",
data: { chunk: new AIMessageChunk("O") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "O" }) },
},
{
event: "on_chain_stream",
Expand All @@ -766,7 +768,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
foo: "bar",
},
name: "my_chain",
data: { chunk: new AIMessageChunk("O") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "O" }) },
},
{
event: "on_llm_stream",
Expand All @@ -779,7 +781,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
ls_stop: undefined,
},
name: "my_model",
data: { chunk: new AIMessageChunk("A") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "A" }) },
},
{
event: "on_chain_stream",
Expand All @@ -789,7 +791,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
foo: "bar",
},
name: "my_chain",
data: { chunk: new AIMessageChunk("A") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "A" }) },
},
{
event: "on_llm_stream",
Expand All @@ -802,7 +804,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
ls_stop: undefined,
},
name: "my_model",
data: { chunk: new AIMessageChunk("R") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "R" }) },
},
{
event: "on_chain_stream",
Expand All @@ -812,7 +814,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
foo: "bar",
},
name: "my_chain",
data: { chunk: new AIMessageChunk("R") },
data: { chunk: new AIMessageChunk({ id: anyString, content: "R" }) },
},
{
event: "on_llm_end",
Expand All @@ -836,7 +838,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
[
new ChatGenerationChunk({
generationInfo: {},
message: new AIMessageChunk("ROAR"),
message: new AIMessageChunk({ id: anyString, content: "ROAR" }),
text: "ROAR",
}),
],
Expand All @@ -853,7 +855,7 @@ test("Runnable streamEvents method with chat model chain", async () => {
foo: "bar",
},
data: {
output: new AIMessageChunk("ROAR"),
output: new AIMessageChunk({ id: anyString, content: "ROAR" }),
},
},
]);
Expand Down
Loading

0 comments on commit ab08f1d

Please sign in to comment.