Skip to content

Commit

Permalink
Ensure connectivity to private HF spaces with SSE protocol (#8181)
Browse files Browse the repository at this point in the history
* add msw setup and initialisation tests

* add changeset

* add eventsource polyfill for node and browser envs

* add changeset

* add changeset

* config tweak

* types

* update eventsource usage

* add changeset

* add walk_and_store_blobs improvements and add tests

* add changeset

* api_info tests

* add direct space URL link tests

* fix tests

* add view_api tests

* add post_message test

* tweak

* add spaces tests

* jwt and protocol tests

* add post_data tests

* test tweaks

* dynamically import eventsource

* revet eventsource imports

* add jwt param to sse requests

* add stream test

* add changeset

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
2 people authored and dawoodkhan82 committed May 6, 2024
1 parent 1d7fc6f commit c5277a3
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 10 deletions.
7 changes: 7 additions & 0 deletions .changeset/yummy-paws-eat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/app": patch
"@gradio/client": patch
"gradio": patch
---

fix:Ensure connectivity to private HF spaces with SSE protocol
12 changes: 11 additions & 1 deletion client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ export class NodeBlob extends Blob {
}
}

if (typeof window === "undefined") {
import("eventsource")
.then((EventSourceModule) => {
global.EventSource = EventSourceModule.default as any;
})
.catch((error) =>
console.error("Failed to load EventSource module:", error)
);
}

export class Client {
app_reference: string;
options: ClientOptions;
Expand All @@ -51,7 +61,7 @@ export class Client {
stream_status = { open: false };
pending_stream_messages: Record<string, any[][]> = {};
pending_diff_streams: Record<string, any[][]> = {};
event_callbacks: Record<string, () => Promise<void>> = {};
event_callbacks: Record<string, (data?: unknown) => Promise<void>> = {};
unclosed_events: Set<string> = new Set();
heartbeat_event: EventSource | null = null;

Expand Down
11 changes: 11 additions & 0 deletions client/js/src/test/mock_eventsource.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { vi } from "vitest";

Object.defineProperty(window, "EventSource", {
writable: true,
value: vi.fn().mockImplementation(() => ({
close: vi.fn(() => {}),
addEventListener: vi.fn(),
onmessage: vi.fn((_event: MessageEvent) => {}),
onerror: vi.fn((_event: Event) => {})
}))
});
67 changes: 67 additions & 0 deletions client/js/src/test/stream.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { vi } from "vitest";
import { Client } from "../client";
import { initialise_server } from "./server";

import { describe, it, expect, afterEach } from "vitest";
import "./mock_eventsource.ts";

const server = initialise_server();

beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
afterAll(() => server.close());

describe("open_stream", () => {
let mock_eventsource: any;
let app: any;

beforeEach(async () => {
app = await Client.connect("hmb/hello_world");
app.eventSource_factory = vi.fn().mockImplementation(() => {
mock_eventsource = new EventSource("");
return mock_eventsource;
});
});

afterEach(() => {
vi.clearAllMocks();
});

it("should throw an error if config is not defined", () => {
app.config = undefined;

expect(() => {
app.open_stream();
}).toThrow("Could not resolve app config");
});

it("should connect to the SSE endpoint and handle messages", async () => {
app.open_stream();

const eventsource_mock_call = app.eventSource_factory.mock.calls[0][0];

expect(eventsource_mock_call.href).toMatch(
/https:\/\/hmb-hello-world\.hf\.space\/queue\/data\?session_hash/
);

expect(app.eventSource_factory).toHaveBeenCalledWith(eventsource_mock_call);

const onMessageCallback = mock_eventsource.onmessage;
const onErrorCallback = mock_eventsource.onerror;

const message = { msg: "hello jerry" };

onMessageCallback({ data: JSON.stringify(message) });
expect(app.stream_status.open).toBe(true);

expect(app.event_callbacks).toEqual({});
expect(app.pending_stream_messages).toEqual({});

const close_stream_message = { msg: "close_stream" };
onMessageCallback({ data: JSON.stringify(close_stream_message) });
expect(app.stream_status.open).toBe(false);

onErrorCallback({ data: JSON.stringify("404") });
expect(app.stream_status.open).toBe(false);
});
});
18 changes: 11 additions & 7 deletions client/js/src/utils/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ export function open_stream(this: Client): void {
unclosed_events,
pending_stream_messages,
stream_status,
config
config,
jwt
} = this;

if (!config) {
Expand All @@ -22,10 +23,16 @@ export function open_stream(this: Client): void {
}).toString();

let url = new URL(`${config.root}/queue/data?${params}`);

if (jwt) {
url.searchParams.set("__sign", jwt);
}

event_source = this.eventSource_factory(url);

if (!event_source) {
throw new Error("Cannot connect to sse endpoint: " + url.toString());
console.warn("Cannot connect to SSE endpoint: " + url.toString());
return;
}

event_source.onmessage = async function (event: MessageEvent) {
Expand All @@ -37,10 +44,8 @@ export function open_stream(this: Client): void {
const event_id = _data.event_id;
if (!event_id) {
await Promise.all(
Object.keys(event_callbacks).map(
(event_id) =>
// @ts-ignore
event_callbacks[event_id](_data) // todo: check event_callbacks
Object.keys(event_callbacks).map((event_id) =>
event_callbacks[event_id](_data)
)
);
} else if (event_callbacks[event_id] && config) {
Expand Down Expand Up @@ -70,7 +75,6 @@ export function open_stream(this: Client): void {
event_source.onerror = async function () {
await Promise.all(
Object.keys(event_callbacks).map((event_id) =>
// @ts-ignore
event_callbacks[event_id]({
msg: "unexpected_error",
message: BROKEN_CONNECTION_MSG
Expand Down
8 changes: 6 additions & 2 deletions client/js/src/utils/submit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,15 @@ export function submit(
}${params}`
);

if (this.jwt) {
url.searchParams.set("__sign", this.jwt);
}

event_source = this.eventSource_factory(url);

if (!event_source) {
throw new Error(
"Cannot connect to sse endpoint: " + url.toString()
return Promise.reject(
new Error("Cannot connect to SSE endpoint: " + url.toString())
);
}

Expand Down
2 changes: 2 additions & 0 deletions js/app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@
"@gradio/utils": "workspace:^",
"@gradio/video": "workspace:^",
"@gradio/wasm": "workspace:^",
"@types/eventsource": "^1.1.15",
"cross-env": "^7.0.3",
"d3-dsv": "^3.0.1",
"eventsource": "^2.0.2",
"mime-types": "^2.1.34",
"postcss": "^8.4.21",
"postcss-prefix-selector": "^1.16.0"
Expand Down
6 changes: 6 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c5277a3

Please sign in to comment.