
core[patch]: Fix double streaming issue when streamEvents is called directly on chat models/LLMs #6155
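The fix can be summarized with a self-contained sketch (the names `StreamEvent` and `emitTokenEvent` are hypothetical stand-ins, not the actual API): when `streamEvents` is called directly on a chat model or LLM, the top-level output iterable is already tapped for chunk events, so per-token callbacks for runs without a parent must be suppressed to avoid emitting each chunk twice.

```typescript
// Minimal stand-in types; the real handler lives in
// langchain-core/src/tracers/event_stream.ts.
type StreamEvent = { event: string; chunk: string };

// Hypothetical simplification of the dedup logic: a per-token callback
// that bails out for top-level runs, because the tapped output iterable
// already surfaces those chunks as on_*_stream events.
function emitTokenEvent(
  events: StreamEvent[],
  chunk: string,
  parentRunId?: string
): void {
  if (parentRunId === undefined) {
    // Top-level streaming is covered by the tapped output iterable.
    return;
  }
  events.push({ event: "on_llm_stream", chunk });
}
```

Before this change, a model invoked directly with `streamEvents` emitted each token once from the token callback and once from the tapped iterable, which is why the duplicated `on_llm_stream` entries are removed from the test fixtures in this PR.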

Merged 3 commits on Jul 19, 2024
2 changes: 2 additions & 0 deletions docs/core_docs/docs/how_to/streaming.ipynb
@@ -663,6 +663,8 @@
"| on_prompt_start | [template_name] | | {\"question\": \"hello\"} | |\n",
"| on_prompt_end | [template_name] | | {\"question\": \"hello\"} | ChatPromptValue(messages: [SystemMessage, ...]) |\n",
"\n",
"`streamEvents` will also emit dispatched custom events in `v2`. Please see [this guide](/docs/how_to/callbacks_custom_events/) for more.\n",
"\n",
"### Chat Model\n",
"\n",
"Let's start off by looking at the events produced by a chat model."
52 changes: 49 additions & 3 deletions langchain-core/src/runnables/base.ts
@@ -762,11 +762,11 @@ export abstract class Runnable<
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_llm_end | [model name] | | 'Hello human!' | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_chain_start | format_docs | | | |
* | on_chain_start | some_runnable | | | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_chain_stream | format_docs | "hello world!, goodbye world!" | | |
* | on_chain_stream | some_runnable | "hello world!, goodbye world!" | | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_chain_end | format_docs | | [Document(...)] | "hello world!, goodbye world!" |
* | on_chain_end | some_runnable | | [Document(...)] | "hello world!, goodbye world!" |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_tool_start | some_tool | | {"x": 1, "y": "2"} | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
@@ -780,6 +780,52 @@
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_prompt_end | [template_name] | | {"question": "hello"} | ChatPromptValue(messages: [SystemMessage, ...]) |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
*
* The "on_chain_*" events are the default for Runnables that don't fit one of the above categories.
*
* In addition to the standard events above, users can also dispatch custom events.
*
* Custom events will only be surfaced in the `v2` version of the API!
*
* A custom event has the following format:
*
* +-----------+------+-----------------------------------------------------------------------------------------------------------+
* | Attribute | Type | Description |
* +===========+======+===========================================================================================================+
* | name      | str  | A user-defined name for the event.                                                                        |
* +-----------+------+-----------------------------------------------------------------------------------------------------------+
* | data | Any | The data associated with the event. This can be anything, though we suggest making it JSON serializable. |
* +-----------+------+-----------------------------------------------------------------------------------------------------------+
*
* Here's an example:
* @example
* ```ts
* import { RunnableLambda } from "@langchain/core/runnables";
* import { dispatchCustomEvent } from "@langchain/core/callbacks/dispatch";
* // Use this import for web environments that don't support "async_hooks"
* // and manually pass config to child runs.
* // import { dispatchCustomEvent } from "@langchain/core/callbacks/dispatch/web";
*
* const slowThing = RunnableLambda.from(async (someInput: string) => {
* // Placeholder for some slow operation
* await new Promise((resolve) => setTimeout(resolve, 100));
* await dispatchCustomEvent("progress_event", {
* message: "Finished step 1 of 2",
* });
* await new Promise((resolve) => setTimeout(resolve, 100));
* return "Done";
* });
*
* const eventStream = await slowThing.streamEvents("hello world", {
* version: "v2",
* });
*
* for await (const event of eventStream) {
* if (event.event === "on_custom_event") {
* console.log(event);
* }
* }
* ```
*/
streamEvents(
input: RunInput,
102 changes: 54 additions & 48 deletions langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts
@@ -80,6 +80,60 @@ test("Runnable streamEvents method", async () => {
]);
});

test("Runnable streamEvents method on a chat model", async () => {
const model = new FakeListChatModel({
responses: ["abc"],
});

const events = [];
const eventStream = await model.streamEvents("hello", { version: "v2" });
for await (const event of eventStream) {
events.push(event);
}
expect(events).toMatchObject([
{
data: { input: "hello" },
event: "on_chat_model_start",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ content: "a" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ content: "b" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ content: "c" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { output: new AIMessageChunk({ content: "abc" }) },
event: "on_chat_model_end",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
]);
});

test("Runnable streamEvents method with three runnables", async () => {
const r = RunnableLambda.from(reverse);

@@ -599,18 +653,6 @@ test("Runnable streamEvents method with llm", async () => {
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "h",
},
},
{
event: "on_llm_stream",
data: {
@@ -625,18 +667,6 @@
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "e",
},
},
{
event: "on_llm_stream",
data: {
Expand All @@ -651,18 +681,6 @@ test("Runnable streamEvents method with llm", async () => {
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "y",
},
},
{
event: "on_llm_stream",
data: {
Expand All @@ -677,18 +695,6 @@ test("Runnable streamEvents method with llm", async () => {
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "!",
},
},
{
event: "on_llm_end",
data: {
Expand Down
17 changes: 15 additions & 2 deletions langchain-core/src/tracers/event_stream.ts
@@ -244,6 +244,13 @@ export class EventStreamCallbackHandler extends BaseTracer {
yield firstChunk.value;
return;
}
// Match format from handlers below
function _formatOutputChunk(eventType: string, data: unknown) {
if (eventType === "llm" && typeof data === "string") {
return new GenerationChunk({ text: data });
}
return data;
}
let tappedPromise = this.tappedPromises.get(runId);
// if we are the first to tap, issue stream events
if (tappedPromise === undefined) {
@@ -264,7 +271,9 @@
await this.send(
{
...event,
data: { chunk: firstChunk.value },
data: {
chunk: _formatOutputChunk(runInfo.runType, firstChunk.value),
},
},
runInfo
);
Expand All @@ -276,7 +285,7 @@ export class EventStreamCallbackHandler extends BaseTracer {
{
...event,
data: {
chunk,
chunk: _formatOutputChunk(runInfo.runType, chunk),
},
},
runInfo
@@ -354,6 +363,10 @@
if (runInfo === undefined) {
throw new Error(`onLLMNewToken: Run ID ${run.id} not found in run map.`);
}
// Top-level streaming events are covered by tapOutputIterable
if (run.parent_run_id === undefined) {
return;
}
if (runInfo.runType === "chat_model") {
eventName = "on_chat_model_stream";
if (kwargs?.chunk === undefined) {
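The `_formatOutputChunk` helper added in `event_stream.ts` can be illustrated in isolation. This is a sketch: the `GenerationChunk` below is a minimal stand-in for langchain-core's real class, kept only to show the shape of the normalization. Raw LLM output arrives as plain strings, and wrapping those strings keeps top-level stream events shaped like the ones produced by the callback handlers.

```typescript
// Minimal stand-in for langchain-core's GenerationChunk class.
class GenerationChunk {
  constructor(public fields: { text: string }) {}
}

// Sketch of the normalization this PR adds: only raw string chunks from
// plain LLM runs get wrapped; chat-model and chain chunks pass through.
function formatOutputChunk(eventType: string, data: unknown): unknown {
  if (eventType === "llm" && typeof data === "string") {
    return new GenerationChunk({ text: data });
  }
  return data;
}
```

Without this, a top-level `on_llm_stream` event would carry a bare string while the callback path emits `GenerationChunk` objects, so consumers would see two different chunk shapes for the same run type.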