From 11b6e7558cd481f7ee1c2eb23b4948784734267d Mon Sep 17 00:00:00 2001 From: Tmk Date: Sat, 3 Feb 2024 23:49:42 +0800 Subject: [PATCH] refactor(messaging): post data through one-time requests (#28) --- apps/ai-assistant/src/content-scripts/app.tsx | 4 +- packages/messaging/demo/background/index.ts | 8 +- .../messaging/demo/content-scripts/index.ts | 4 +- .../messaging/demo/pages/options/index.ts | 4 +- packages/messaging/demo/pages/popup/index.ts | 4 +- packages/messaging/e2e/basic.spec.ts | 33 ++-- packages/messaging/playwright.config.ts | 1 + packages/messaging/src/background.ts | 136 ++++--------- packages/messaging/src/client-base.ts | 98 +++++----- packages/messaging/src/content-script.ts | 8 +- .../src/core/__tests__/index.spec.ts | 124 +----------- .../src/core/__tests__/test-utils.ts | 20 ++ .../messaging/src/core/__tests__/trpc.spec.ts | 15 +- packages/messaging/src/core/index.ts | 184 ++++++------------ packages/messaging/src/core/trpc.ts | 26 ++- packages/messaging/src/shared.ts | 29 ++- 16 files changed, 241 insertions(+), 457 deletions(-) create mode 100644 packages/messaging/src/core/__tests__/test-utils.ts diff --git a/apps/ai-assistant/src/content-scripts/app.tsx b/apps/ai-assistant/src/content-scripts/app.tsx index 9a5eab7e..c81789a9 100644 --- a/apps/ai-assistant/src/content-scripts/app.tsx +++ b/apps/ai-assistant/src/content-scripts/app.tsx @@ -7,7 +7,7 @@ import { isSelectionValid, rangeToReference, } from '@webx-kit/runtime/content-scripts'; -import { createTrpcHandler } from '@webx-kit/messaging/content-script'; +import { createTrpcClient } from '@webx-kit/messaging/content-script'; import clsx from 'clsx'; import type { AppRouter } from '@/background/router'; import { DialogTrigger, TooltipTrigger } from 'react-aria-components'; @@ -15,7 +15,7 @@ import { Button, Popover, Tooltip } from '@/components'; import { Provider } from './features/provider'; import './global.less'; -const { client } = createTrpcHandler({}); +const { client } = createTrpcClient({}); export const App = () => { const [visible, setVisible] = useState(false); diff --git a/packages/messaging/demo/background/index.ts b/packages/messaging/demo/background/index.ts index b0b0e991..34c9c722 100644 --- a/packages/messaging/demo/background/index.ts +++ b/packages/messaging/demo/background/index.ts @@ -2,11 +2,12 @@ import { createCustomHandler } from '@/background'; const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); -const { connections } = createCustomHandler({ +// @ts-expect-error +globalThis.__messaging = createCustomHandler({ async requestHandler(message) { return { reply: 'background', - data: message.data, + data: message, }; }, async streamHandler(_message, subscriber) { @@ -19,6 +20,3 @@ const { connections } = createCustomHandler({ subscriber.complete(); }, }); - -// @ts-expect-error -globalThis.__webxConnections = connections; diff --git a/packages/messaging/demo/content-scripts/index.ts b/packages/messaging/demo/content-scripts/index.ts index 168ff78c..b1f8b65f 100644 --- a/packages/messaging/demo/content-scripts/index.ts +++ b/packages/messaging/demo/content-scripts/index.ts @@ -6,7 +6,7 @@ const { messaging } = createCustomHandler({ requestHandler(message) { return { reply: 'content-script', - data: message.data, + data: message, }; }, async streamHandler(_message, subscriber) { @@ -21,4 +21,4 @@ const { messaging } = createCustomHandler({ }); // @ts-expect-error -globalThis.__client = messaging; +globalThis.__clientMessaging = messaging; diff --git a/packages/messaging/demo/pages/options/index.ts b/packages/messaging/demo/pages/options/index.ts index 772f1199..c99be514 100644 --- a/packages/messaging/demo/pages/options/index.ts +++ b/packages/messaging/demo/pages/options/index.ts @@ -7,7 +7,7 @@ const { messaging } = createCustomHandler({ requestHandler(message) { return { reply: 'options', - data: message.data, + data: message, }; }, async streamHandler(_message, subscriber) { @@ -22,4 +22,4 @@ const { messaging } = createCustomHandler({ }); // @ts-expect-error -globalThis.__client = messaging; +globalThis.__clientMessaging = messaging; diff --git a/packages/messaging/demo/pages/popup/index.ts b/packages/messaging/demo/pages/popup/index.ts index 3316b0e2..e08dd83d 100644 --- a/packages/messaging/demo/pages/popup/index.ts +++ b/packages/messaging/demo/pages/popup/index.ts @@ -7,7 +7,7 @@ const { messaging } = createCustomHandler({ requestHandler(message) { return { reply: 'popup', - data: message.data, + data: message, }; }, async streamHandler(_message, subscriber) { @@ -22,4 +22,4 @@ const { messaging } = createCustomHandler({ }); // @ts-expect-error -globalThis.__client = messaging; +globalThis.__clientMessaging = messaging; diff --git a/packages/messaging/e2e/basic.spec.ts b/packages/messaging/e2e/basic.spec.ts index 5ce7cffc..c926c3d0 100644 --- a/packages/messaging/e2e/basic.spec.ts +++ b/packages/messaging/e2e/basic.spec.ts @@ -1,43 +1,54 @@ import { setupStaticServer } from '@webx-kit/test-utils/playwright'; import { expect, test } from './context'; import type { Messaging } from '@/core'; -import type { WrappedMessaging } from '@/client-base'; +import type { WrappedMessaging } from '@/shared'; const getWebpageURL = setupStaticServer(test); declare module globalThis { /** only in background */ - const __webxConnections: Set | undefined; - const __client: WrappedMessaging; + const __messaging: Messaging; + const __clientMessaging: WrappedMessaging; } test('Background', async ({ background }) => { - await expect(background.evaluate(() => typeof globalThis.__webxConnections)).resolves.toBe('object'); + await expect(background.evaluate(() => typeof globalThis.__messaging)).resolves.toBe('object'); }); test('Messaging', async ({ context, getURL }) => { const optionsPage = await context.newPage(); const popupPage = await context.newPage(); - const contentScript = await context.newPage(); + const contentPage = await context.newPage(); await Promise.all([ optionsPage.goto(await getURL('options.html')), popupPage.goto(await getURL('popup.html')), - contentScript.goto(getWebpageURL()), + contentPage.goto(getWebpageURL()), ]); - await expect(optionsPage.evaluate(() => globalThis.__client.request('options', 'popup'))).resolves.toEqual({ - reply: 'popup', + await expect(optionsPage.evaluate(() => globalThis.__clientMessaging.request('options'))).resolves.toEqual({ + reply: 'background', data: 'options', }); - await expect(popupPage.evaluate(() => globalThis.__client.request('popup', 'options'))).resolves.toEqual({ + await expect(popupPage.evaluate(() => globalThis.__clientMessaging.requestTo('options', 'popup'))).resolves.toEqual({ reply: 'options', data: 'popup', }); + const [webpage] = await popupPage.evaluate(() => + chrome.tabs.query({}).then((tabs) => + tabs + .filter((tab) => tab.url?.startsWith('http')) + .map((tab) => ({ + id: tab.id, + url: tab.url, + })) + ) + ); + await expect( - popupPage.evaluate(() => globalThis.__client.request('popup to content-script', 'content-script')) + popupPage.evaluate((tabId) => globalThis.__clientMessaging.requestTo(tabId!, 'popup to content-script'), webpage.id) ).resolves.toEqual({ reply: 'content-script', data: 'popup to content-script', @@ -52,7 +63,7 @@ test('Stream', async ({ context, getURL }) => { const result = await optionsPage.evaluate(() => { return new Promise((resolve, reject) => { const result: unknown[] = []; - globalThis.__client.stream('options', { + globalThis.__clientMessaging.stream('options', { next: (value) => result.push(value), error: reject, complete: () => resolve(result), diff --git a/packages/messaging/playwright.config.ts b/packages/messaging/playwright.config.ts index deb4d8ab..43131db7 100644 --- a/packages/messaging/playwright.config.ts +++ b/packages/messaging/playwright.config.ts @@ -6,4 +6,5 @@ import { defineConfig } from '@playwright/test'; export default defineConfig({ testDir: './e2e', retries: 2, + timeout: 5000, }); diff --git a/packages/messaging/src/background.ts b/packages/messaging/src/background.ts index efe39417..a64d731c 100644 --- a/packages/messaging/src/background.ts +++ b/packages/messaging/src/background.ts @@ -1,93 +1,46 @@ import { AnyTRPCRouter } from '@trpc/server'; -import { Messaging, createMessaging, fromChromePort } from './core'; +import { Port, RequestHandler, StreamHandler, createMessaging } from './core'; import { applyMessagingHandler } from './core/trpc'; -import { NAMESPACE, RequestHandler, StreamHandler, WebxMessage, isWebxMessage } from './shared'; - -export type WebxMessageMiddleware = ( - message: WebxMessage, - port: chrome.runtime.Port -) => WebxMessage | false | void | Promise; - -// #region middlewares -const senderInfoMiddleware: WebxMessageMiddleware = (message, port) => { - return { - from: port.name.slice(NAMESPACE.length), - tabId: port.sender?.tab?.id, - ...message, - }; -}; -// #endregion - -const middlewares = new Set([senderInfoMiddleware]); - -async function applyMiddlewares(message: WebxMessage, port: chrome.runtime.Port) { - for (const middleware of middlewares) { - const modifiedMessage = await middleware(message, port); - if (modifiedMessage === false) return message; - if (modifiedMessage) message = modifiedMessage; - } - return message; -} +import { WebxMessage, isWebxMessage } from './shared'; export interface CustomHandlerOptions { requestHandler?: RequestHandler; streamHandler?: StreamHandler; } -export function createCustomHandler({ - requestHandler = () => Promise.reject(), - streamHandler = (_, subscriber) => { - subscriber.error('unimplemented'); - }, -}: CustomHandlerOptions) { - const connections = new Set(); - - const listener = (port: chrome.runtime.Port): void => { - if (!port.name.startsWith(NAMESPACE)) return; - const messaging = createMessaging(fromChromePort(port), { - async onRequest(message) { - if (!isWebxMessage(message)) return Promise.reject('unknown message'); - const webxMessage = await applyMiddlewares(message, port); +const NAME = 'background'; - if (!webxMessage.to) return requestHandler(webxMessage); - - for (const connection of connections) { - if (connection.name === port.name) continue; - if (connection.name.slice(NAMESPACE.length).startsWith(webxMessage.to)) { - return connection.request(webxMessage); - } - } - - return Promise.reject('no target'); - }, - async onStream(message, subscriber) { - if (!isWebxMessage(message)) return subscriber.error('unknown message'); - const webxMessage = await applyMiddlewares(message, port); - - if (!webxMessage.to) return streamHandler(webxMessage, subscriber); +const backgroundPort: Port = { + name: NAME, + onMessage(listener) { + chrome.runtime.onMessage.addListener(listener); + return () => { + chrome.runtime.onMessage.removeListener(listener); + }; + }, + send(message, originMessage?: Parameters[0]>) { + if (!originMessage) return chrome.runtime.sendMessage(message); + const [, sender] = originMessage; + if (!sender.tab?.id) return chrome.runtime.sendMessage(message); + chrome.tabs.sendMessage(sender.tab.id, message, { documentId: sender.documentId, frameId: sender.frameId }); + }, +}; - for (const connection of connections) { - if (connection.name === port.name) continue; - if (connection.name.slice(NAMESPACE.length).startsWith(webxMessage.to)) { - return connection.stream(webxMessage, subscriber); - } - } +function shouldSkip(data: unknown) { + if (!isWebxMessage(data)) return true; + return !(!data.to || data.to === NAME); +} - subscriber.error('no target'); - }, - onDispose() { - connections.delete(messaging); - }, - }); - connections.add(messaging); - }; - chrome.runtime.onConnect.addListener(listener); - return { - connections, - dispose() { - chrome.runtime.onConnect.removeListener(listener); - }, - }; +export function createCustomHandler({ requestHandler, streamHandler }: CustomHandlerOptions) { + return createMessaging(backgroundPort, { + intercept: (data: WebxMessage, abort) => (shouldSkip(data) ? abort : data.data), + onRequest: requestHandler || (() => Promise.reject()), + onStream: + streamHandler || + ((_, subscriber) => { + subscriber.error('unimplemented'); + }), + }); } export interface TrpcHandlerOptions { @@ -95,24 +48,9 @@ export interface TrpcHandlerOptions { } export function createTrpcHandler({ router }: TrpcHandlerOptions) { - const connections = new Set(); - - const listener = (port: chrome.runtime.Port): void => { - if (!port.name.startsWith(NAMESPACE)) return; - const messaging = applyMessagingHandler({ - port: fromChromePort(port), - router, - onDispose() { - connections.delete(messaging); - }, - }); - connections.add(messaging); - }; - chrome.runtime.onConnect.addListener(listener); - return { - connections, - dispose() { - chrome.runtime.onConnect.removeListener(listener); - }, - }; + return applyMessagingHandler({ + port: backgroundPort, + router, + intercept: (data: WebxMessage, abort) => (shouldSkip(data) ? abort : data.data), + }); } diff --git a/packages/messaging/src/client-base.ts b/packages/messaging/src/client-base.ts index 4db6c4e9..d1f7e698 100644 --- a/packages/messaging/src/client-base.ts +++ b/packages/messaging/src/client-base.ts @@ -1,39 +1,54 @@ -import type { Observer } from 'type-fest'; -import { Messaging, createMessaging, fromChromePort } from './core'; +import { Port, RequestHandler, StreamHandler, createMessaging } from './core'; import { randomID } from './core/utils'; -import { ClientType, NAMESPACE, RequestHandler, StreamHandler, WebxMessage } from './shared'; import { createTRPCClient } from '@trpc/client'; +import { ClientType, MessageTarget, WebxMessage, isWebxMessage, wrapMessaging } from './shared'; import type { AnyTRPCRouter } from '@trpc/server'; import { messagingLink } from './core/trpc'; +const clientPort: Port = { + onMessage(listener) { + chrome.runtime.onMessage.addListener(listener); + return () => { + chrome.runtime.onMessage.removeListener(listener); + }; + }, + send(message) { + if (typeof message?.d?.to === 'number') { + chrome.tabs.sendMessage(message.d.to, message); + return; + } + chrome.runtime.sendMessage(message); + }, +}; + +function shouldSkip(name: string, data: unknown) { + if (!isWebxMessage(data)) return true; + const { to } = data; + // - string: all extension pages (including background) will receive + // - number: only the specific page will receive the message, so it's unnecessary to check + return !((typeof to === 'string' && name.startsWith(to)) || typeof to === 'number'); +} + export interface CustomHandlerOptions { type: ClientType; requestHandler?: RequestHandler; streamHandler?: StreamHandler; } -export function createCustomHandler({ - type, - requestHandler = () => Promise.reject(), - streamHandler = (_, subscriber) => { - subscriber.error('unimplemented'); - }, -}: CustomHandlerOptions) { +export function createCustomHandler({ type, requestHandler, streamHandler }: CustomHandlerOptions) { const id = randomID(); + const name = `${type}@${id}`; - let port: chrome.runtime.Port | null = null; const messaging = createMessaging( - fromChromePort(() => (port ||= chrome.runtime.connect({ name: `${NAMESPACE}${type}@${id}` }))), + { ...clientPort, name }, { - onRequest(message) { - return requestHandler(message); - }, - onStream(message, subscriber) { - return streamHandler(message, subscriber); - }, - onDispose() { - port = null; - }, + intercept: (data: WebxMessage, abort) => (shouldSkip(name, data) ? abort : data.data), + onRequest: requestHandler || (() => Promise.reject()), + onStream: + streamHandler || + ((_, subscriber) => { + subscriber.error('unimplemented'); + }), } ); @@ -47,27 +62,30 @@ export function createCustomHandler({ }; } -export interface TrpcHandlerOptions { +export interface TrpcClientOptions { type: ClientType; + to?: MessageTarget; } -export function createTrpcHandler({ type }: TrpcHandlerOptions) { +export function createTrpcClient({ type, to = 'background' }: TrpcClientOptions) { const id = randomID(); - let port: chrome.runtime.Port | null = null; + const name = `${type}@${id}`; - const link = messagingLink({ - port: fromChromePort(() => (port ||= chrome.runtime.connect({ name: `${NAMESPACE}${type}@${id}` }))), - messagingOptions: { - onDispose() { - port = null; - }, - }, - }); + const messaging = wrapMessaging( + createMessaging( + { ...clientPort, name }, + { + intercept: (data: WebxMessage, abort) => (shouldSkip(name, data) ? abort : data.data), + } + ) + ); + messaging.request = (message) => messaging.requestTo(to, message); + messaging.stream = (message, subscriber) => messaging.streamTo(to, message, subscriber); + const link = messagingLink({ messaging }); const client = createTRPCClient({ links: [link], }); - const messaging = wrapMessaging(link.messaging); return { messaging, client, @@ -78,17 +96,3 @@ export function createTrpcHandler({ type }: TrpcH }, }; } - -function wrapMessaging(messaging: Messaging) { - return { - request(data: unknown, to?: ClientType) { - return messaging.request({ data, to } satisfies WebxMessage); - }, - stream(data: unknown, observer: Partial>, to?: ClientType) { - return messaging.stream({ data, to } satisfies WebxMessage, observer); - }, - dispose: messaging.dispose, - }; -} - -export type WrappedMessaging = ReturnType; diff --git a/packages/messaging/src/content-script.ts b/packages/messaging/src/content-script.ts index f477ed4d..9417fed1 100644 --- a/packages/messaging/src/content-script.ts +++ b/packages/messaging/src/content-script.ts @@ -2,13 +2,13 @@ import { AnyTRPCRouter } from '@trpc/server'; import { SetOptional } from 'type-fest'; import { CustomHandlerOptions, - TrpcHandlerOptions, + TrpcClientOptions, createCustomHandler as internalCreateCustomHandler, - createTrpcHandler as internalCreateTrpcHandler, + createTrpcClient as internalCreateTrpcClient, } from './client-base'; export const createCustomHandler = (options: SetOptional) => internalCreateCustomHandler({ type: 'content-script', ...options }); -export const createTrpcHandler = (options: SetOptional) => - internalCreateTrpcHandler({ type: 'content-script', ...options }); +export const createTrpcClient = (options: SetOptional) => + internalCreateTrpcClient({ type: 'content-script', ...options }); diff --git a/packages/messaging/src/core/__tests__/index.spec.ts b/packages/messaging/src/core/__tests__/index.spec.ts index f2019a13..78469ac3 100644 --- a/packages/messaging/src/core/__tests__/index.spec.ts +++ b/packages/messaging/src/core/__tests__/index.spec.ts @@ -1,37 +1,8 @@ import { setTimeout as sleep } from 'node:timers/promises'; import { expect, it, vi } from 'vitest'; -import { Messaging, createMessaging, fromMessagePort } from '../index'; +import { createMessaging } from '../index'; import { withResolvers } from '../utils'; - -function expectMessagingIsNotLeaked(messaging: Messaging) { - // @ts-expect-error - expect(messaging.ongoingRequestResolvers).toHaveLength(0); - // @ts-expect-error - expect(messaging.ongoingStreamObservers).toHaveLength(0); -} - -it('should on/off listener', async () => { - const { port1, port2 } = new MessageChannel(); - const listenerFn = vi.fn(); - - const resolver = withResolvers(); - const receiver = createMessaging(fromMessagePort(port1), { - on: (...args) => { - listenerFn(...args); - resolver.resolve(); - }, - }); - - await (port2.postMessage('hello'), resolver.promise); - expect(listenerFn).toBeCalledTimes(1); - expect(listenerFn).toBeCalledWith('hello'); - - receiver.dispose(); - await (port2.postMessage('hello'), sleep(10)); - expect(listenerFn).toBeCalledTimes(1); - - expectMessagingIsNotLeaked(receiver); -}); +import { expectMessagingIsNotLeaked, fromMessagePort } from './test-utils'; it('should support request', async () => { const { port1, port2 } = new MessageChannel(); @@ -159,97 +130,6 @@ it('should support abort stream', async () => { expectMessagingIsNotLeaked(sender); }); -it('should support relay request', async () => { - const { port1, port2 } = new MessageChannel(); - const { port1: port3, port2: port4 } = new MessageChannel(); - - const destination = createMessaging(fromMessagePort(port1), { - async onRequest(message) { - switch (message.name) { - case 'hello': - return await sleep(0, `Hello, ${message.user}`); - default: - throw new Error('Unknown method'); - } - }, - }); - const relay1 = createMessaging(fromMessagePort(port2)); - const relay2 = createMessaging(fromMessagePort(port3), { - onRequest() { - return this.relay(relay1); - }, - }); - const sender = createMessaging(fromMessagePort(port4)); - - await expect(sender.request({ name: 'hello', user: 'Tmk' })).resolves.toEqual('Hello, Tmk'); - await expect(sender.request({ name: 'greet', user: 'Tmk' })).rejects.toThrow('Unknown method'); - - expectMessagingIsNotLeaked(destination); - expectMessagingIsNotLeaked(relay1); - expectMessagingIsNotLeaked(relay2); - expectMessagingIsNotLeaked(sender); -}); - -it('should support relay stream', async () => { - const { port1, port2 } = new MessageChannel(); - const { port1: port3, port2: port4 } = new MessageChannel(); - - const destination = createMessaging(fromMessagePort(port1), { - async onStream(message, subscriber) { - switch (message.name) { - case 'hello': { - subscriber.next(1); - subscriber.next(2); - subscriber.next(3); - subscriber.complete(); - } - default: - throw new Error('Unknown method'); - } - }, - }); - const relay1 = createMessaging(fromMessagePort(port2)); - const relay2 = createMessaging(fromMessagePort(port3), { - onStream() { - return this.relay(relay1); - }, - }); - const sender = createMessaging(fromMessagePort(port4)); - - await expect( - new Promise((resolve, reject) => { - const result: unknown[] = []; - sender.stream( - { name: 'hello' }, - { - next: (value) => result.push(value), - error: (reason) => reject(reason), - complete: () => resolve(result), - } - ); - }) - ).resolves.toEqual([1, 2, 3]); - - await expect( - new Promise((resolve, reject) => { - const result: unknown[] = []; - sender.stream( - { name: 'greet' }, - { - next: (value) => result.push(value), - error: (reason) => reject(reason), - complete: () => resolve(result), - } - ); - }) - ).rejects.toThrow('Unknown method'); - - expectMessagingIsNotLeaked(destination); - expectMessagingIsNotLeaked(relay1); - expectMessagingIsNotLeaked(relay2); - expectMessagingIsNotLeaked(sender); -}); - it('should serialize error message', async () => { const { port1, port2 } = new MessageChannel(); diff --git a/packages/messaging/src/core/__tests__/test-utils.ts b/packages/messaging/src/core/__tests__/test-utils.ts new file mode 100644 index 00000000..945dc5cd --- /dev/null +++ b/packages/messaging/src/core/__tests__/test-utils.ts @@ -0,0 +1,20 @@ +import { expect } from 'vitest'; +import { Messaging, Port } from '../index'; + +export function fromMessagePort(port: MessagePort): Port { + return { + send: (message) => port.postMessage(message), + onMessage(listener) { + const ac = new AbortController(); + port.addEventListener('message', (ev) => listener(ev.data), { signal: ac.signal }); + return ac.abort.bind(ac); + }, + }; +} + +export function expectMessagingIsNotLeaked(messaging: Messaging) { + // @ts-expect-error + expect(messaging.ongoingRequestResolvers).toHaveLength(0); + // @ts-expect-error + expect(messaging.ongoingStreamObservers).toHaveLength(0); +} diff --git a/packages/messaging/src/core/__tests__/trpc.spec.ts b/packages/messaging/src/core/__tests__/trpc.spec.ts index 5383b74e..5eb3c16e 100644 --- a/packages/messaging/src/core/__tests__/trpc.spec.ts +++ b/packages/messaging/src/core/__tests__/trpc.spec.ts @@ -4,16 +4,10 @@ import { initTRPC } from '@trpc/server'; import { afterEach, describe, expect, it, vi } from 'vitest'; import { z } from 'zod'; import { applyMessagingHandler, messagingLink } from '../trpc'; -import { Messaging, fromMessagePort } from '../index'; import { observable } from '@trpc/server/observable'; import { withResolvers } from '../utils'; - -function expectMessagingIsNotLeaked(messaging: Messaging) { - // @ts-expect-error - expect(messaging.ongoingRequestResolvers).toHaveLength(0); - // @ts-expect-error - expect(messaging.ongoingStreamObservers).toHaveLength(0); -} +import { expectMessagingIsNotLeaked, fromMessagePort } from './test-utils'; +import { createMessaging } from '..'; describe('Basic', () => { const streamCleanupFn = vi.fn(); @@ -70,14 +64,15 @@ describe('Basic', () => { const server = applyMessagingHandler({ port: fromMessagePort(port1), router: appRouter }); // Client - const link = messagingLink({ port: fromMessagePort(port2) }); + const messaging = createMessaging(fromMessagePort(port2)); + const link = messagingLink({ messaging }); const client = createTRPCClient({ links: [link], }); afterEach(() => { expectMessagingIsNotLeaked(server); - expectMessagingIsNotLeaked(link.messaging); + expectMessagingIsNotLeaked(messaging); }); it('should support query', async () => { diff --git a/packages/messaging/src/core/index.ts b/packages/messaging/src/core/index.ts index 18130ae0..d0417968 100644 --- a/packages/messaging/src/core/index.ts +++ b/packages/messaging/src/core/index.ts @@ -5,10 +5,9 @@ export type SendMessageFunction = (message: any) => void; export type CleanupFunction = VoidFunction; export interface Port { - name: string; - postMessage: SendMessageFunction; - onMessage: (listener: SendMessageFunction) => CleanupFunction; - onDispose: (listener: VoidCallback) => CleanupFunction; + name?: string; + onMessage(listener: (message: any, ...rest: unknown[]) => Promisable): VoidFunction; + send(message: any, originMessage?: [message: any, ...unknown[]]): void; } /** @@ -35,64 +34,66 @@ interface Packet { d: any; } -export interface RequestContext { - relay: (to: Messaging) => Promisable; -} +export type RequestHandler = (message: any) => Promisable; +export type StreamHandler = (message: any, subscriber: Observer) => Promisable; -export interface StreamContext { - relay: (to: Messaging) => Promisable; -} +const INTERCEPT_ABORT = Symbol(); export interface CreateMessagingOptions { - onRequest?: (this: RequestContext, message: any) => Promisable; - onStream?: (this: StreamContext, message: any, subscriber: Observer) => Promisable; - onEvent?: (message: any) => any; - onDispose?: VoidCallback; - on?: (message: any) => any; + intercept?(data: unknown, abort: symbol): unknown | symbol; + onRequest?: RequestHandler; + onStream?: StreamHandler; } function isPacket(message: any): message is Packet { return typeof message === 'object' && message !== null && 't' in message && 'i' in message && 'd' in message; } +function identity(v: T) { + return v; +} + function noop() {} type PromiseResolvers = ReturnType>; export interface Messaging { - name: string; - request(data: unknown): Promise; + name?: string; + request(data: unknown): Promise; stream(data: unknown, observer: Partial>): VoidCallback; dispose(): void; } export function createMessaging(port: Port, options?: CreateMessagingOptions): Messaging { - const { on, onRequest, onStream, onDispose } = options || {}; + const { intercept = identity, onRequest, onStream } = options || {}; // sender side const ongoingRequestResolvers = new Map>(); const ongoingStreamObservers = new Map>>(); // receiver side const processingStreamCleanups = new Map(); - async function handleMessage(message: any) { - on?.(message); + async function handleMessage(...originMessage: [message: any, ...rest: unknown[]]) { + const [message] = originMessage; + if (!isPacket(message)) return; + + function reply(message: Packet) { + port.send(message, originMessage); + } + switch (message.t) { case 'r': { if (!onRequest) return; + const data = intercept(message.d, INTERCEPT_ABORT); + if (data === INTERCEPT_ABORT) return; let d; try { - const context: RequestContext = { - relay(to) { - return to.request(message.d); - }, - }; - const response = await onRequest.call(context, message.d); + const response = await onRequest(data); d = { data: response }; } catch (err) { d = { error: err }; } - port.postMessage({ t: 'R', i: message.i, d } satisfies Packet); + reply({ t: 'R', i: message.i, d }); break; } case 'R': { @@ -108,8 +109,11 @@ export function createMessaging(port: Port, options?: CreateMessagingOptions): M } case 's': { if (!onStream) return; + const data = intercept(message.d, INTERCEPT_ABORT); + if (data === INTERCEPT_ABORT) return; + function terminate(d: { error: unknown } | { complete: boolean }) { - port.postMessage({ t: 'S', i: message.i, d } satisfies Packet); + reply({ t: 'S', i: message.i, d }); processingStreamCleanups.get(message.i)?.(); processingStreamCleanups.delete(message.i); } @@ -122,7 +126,7 @@ export function createMessaging(port: Port, options?: CreateMessagingOptions): M const observer: Observer = { next(value) { - port.postMessage({ t: 'S', i: message.i, d: { next: value } } satisfies Packet); + reply({ t: 'S', i: message.i, d: { next: value } }); }, error(error) { terminate({ error }); @@ -133,12 +137,7 @@ export function createMessaging(port: Port, options?: CreateMessagingOptions): M }; try { - const context: StreamContext = { - relay(to) { - return to.stream(message.d, observer); - }, - }; - const cleanup = await onStream.call(context, message.d, observer); + const cleanup = await onStream(data, observer); processingStreamCleanups.set(message.i, cleanup); } catch (error) { // Error is not serializable in `chrome.runtime.Port` @@ -164,44 +163,7 @@ export function createMessaging(port: Port, options?: CreateMessagingOptions): M } } - let active = false; - const offListeners: VoidFunction[] = []; - - function ensureActive() { - if (active) return; - active = true; - offListeners.push(port.onMessage(handleMessage)); - offListeners.push(port.onDispose(dispose)); - } - function dispose() { - active = false; - onDispose?.(); - offListeners.forEach((offFn) => offFn()); - offListeners.length = 0; - } - ensureActive(); - - function reconnectIfDisconnected any>(fn: T, errorHandler: (error: unknown) => ReturnType) { - try { - // chrome.runtime.Port will [auto disconnect](https://developer.chrome.com/docs/extensions/develop/concepts/messaging?hl=en#port-lifetime) - // so that we need to ensure that there is an active connection - ensureActive(); - return fn(); - } catch (err) { - try { - // When the background is inactive while the page is paused (e.g. during debugging) - // onDisconnect will not be triggered - if (err instanceof Error && err.message.includes('disconnected')) { - dispose(); - ensureActive(); - return fn(); - } - throw err; - } catch (finalErr) { - return errorHandler(finalErr); - } - } - } + const offMessage = port.onMessage(handleMessage); const messaging: Messaging = { name: port.name, @@ -209,31 +171,32 @@ export function createMessaging(port: Port, options?: CreateMessagingOptions): M const resolvers = withResolvers(); const id = randomID(); ongoingRequestResolvers.set(id, resolvers); - reconnectIfDisconnected(() => { - port.postMessage({ t: 'r', i: id, d: data } satisfies Packet); - }, resolvers.reject); + try { + port.send({ t: 'r', i: id, d: data } satisfies Packet); + } catch (err) { + resolvers.reject(err); + } return resolvers.promise; }, stream(data, observer) { const id = randomID(); ongoingStreamObservers.set(id, observer); - return reconnectIfDisconnected( - () => { - port.postMessage({ t: 's', i: id, d: data } satisfies Packet); - return () => { - if (!ongoingStreamObservers.has(id)) return; - port.postMessage({ t: 's', i: id, d: null } satisfies Packet); - }; - }, - (err) => { - queueMicrotask(() => { - observer.error?.(err); - }); - return noop; - } - ); + try { + port.send({ t: 's', i: id, d: data } satisfies Packet); + return () => { + if (!ongoingStreamObservers.has(id)) return; + port.send({ t: 's', i: id, d: null } satisfies Packet); + }; + } catch (err) { + queueMicrotask(() => { + observer.error?.(err); + }); + return noop; + } + }, + dispose() { + offMessage(); }, - dispose, }; if (process.env.NODE_ENV === 'test') { @@ -245,40 +208,3 @@ export function createMessaging(port: Port, options?: CreateMessagingOptions): M return messaging; } - -export function fromMessagePort(port: MessagePort): Port { - return { - name: randomID(), - postMessage: port.postMessage.bind(port), - onMessage(listener) { - const ac = new AbortController(); - port.addEventListener('message', (ev) => listener(ev.data), { signal: ac.signal }); - return ac.abort.bind(ac); - }, - onDispose() { - return noop; - }, - }; -} - -export function fromChromePort(portResolver: chrome.runtime.Port | (() => chrome.runtime.Port)): Port { - const getPort = () => (typeof portResolver === 'function' ? portResolver() : portResolver); - return { - name: getPort().name, - postMessage: (message) => getPort().postMessage(message), - onMessage(listener) { - const port = getPort(); - port.onMessage.addListener(listener); - return () => { - port.onMessage.removeListener(listener); - }; - }, - onDispose(listener) { - const port = getPort(); - port.onDisconnect.addListener(listener); - return () => { - port.onDisconnect.removeListener(listener); - }; - }, - }; -} diff --git a/packages/messaging/src/core/trpc.ts b/packages/messaging/src/core/trpc.ts index 42866b88..1ad3b844 100644 --- a/packages/messaging/src/core/trpc.ts +++ b/packages/messaging/src/core/trpc.ts @@ -8,20 +8,21 @@ import { } from '@trpc/server'; import { TRPCResponseMessage, transformResult } from '@trpc/server/unstable-core-do-not-import'; import { isObservable, observable } from '@trpc/server/observable'; -import { CreateMessagingOptions, Messaging, Port, createMessaging } from './index'; import { Operation, TRPCClientError, TRPCLink } from '@trpc/client'; +import { CreateMessagingOptions, Messaging, Port, createMessaging } from './index'; export interface MessagingHandlerOptions { port: Port; router: TRouter; - onDispose?: VoidCallback; + intercept?(data: unknown, abort: symbol): unknown | symbol; } export function applyMessagingHandler(options: MessagingHandlerOptions) { - const { port, router, onDispose } = options; + const { port, router, intercept } = options; const { procedures, _config: rootConfig } = router._def; const server = createMessaging(port, { + intercept, async onRequest(message) { const { type, path, input, context: ctx } = message as Operation; try { @@ -91,30 +92,28 @@ export function applyMessagingHandler(options: Me ); } }, - onDispose, }); return server; } export interface MessagingLinkOptions { - port: Port; - messagingOptions?: CreateMessagingOptions; + messaging: Messaging; } export function messagingLink( options: MessagingLinkOptions -): TRPCLink & { messaging: Messaging } { - const { port, messagingOptions } = options; - const client = createMessaging(port, messagingOptions); - const link: TRPCLink & { messaging: Messaging } = (runtime) => { +): TRPCLink { + const { messaging } = options; + + return (runtime) => { return ({ op }) => { const { type, path, id, context } = op; const input = runtime.transformer.serialize(op.input); if (op.type !== 'subscription') { return observable((observer) => { - client + messaging .request({ type, path, input, id, context }) .then((response) => { const transformed = transformResult(response as any, runtime.transformer); @@ -133,7 +132,7 @@ export function messagingLink( } return observable((observer) => { - const unsub = client.stream( + const unsub = messaging.stream( { type, path, input, id, context }, { error(err) { @@ -160,7 +159,4 @@ export function messagingLink( }); }; }; - link.messaging = client; - - return link; } diff --git a/packages/messaging/src/shared.ts b/packages/messaging/src/shared.ts index adfabfb5..ac037e76 100644 --- a/packages/messaging/src/shared.ts +++ b/packages/messaging/src/shared.ts @@ -1,19 +1,14 @@ import type { LiteralUnion, Observer } from 'type-fest'; - -export const NAMESPACE = 'webx:'; +import { Messaging } from './core'; export type ClientType = 'devtools' | 'popup' | 'options' | 'content-script'; -export type MessageTarget = LiteralUnion; +export type MessageTarget = LiteralUnion | number; export type RequestHandler = (message: WebxMessage) => any; export type StreamHandler = (message: WebxMessage, subscriber: Observer) => any; export interface WebxMessage { - /** Source ID */ - from?: string; - /** Source tab ID */ - tabId?: number; /** Target */ to?: MessageTarget; /** Structural cloneable data */ @@ -23,3 +18,23 @@ export interface WebxMessage { export function isWebxMessage(message: unknown): message is WebxMessage { return typeof message === 'object' && message !== null && 'data' in message; } + +export function wrapMessaging(messaging: Messaging) { + return { + request(data: unknown) { + return messaging.request({ data } satisfies WebxMessage); + }, + stream(data: unknown, observer: Partial>) { + return messaging.stream({ data } satisfies WebxMessage, observer); + }, + requestTo(to: MessageTarget, data: unknown) { + return messaging.request({ data, to } satisfies WebxMessage); + }, + streamTo(to: MessageTarget, data: unknown, observer: Partial>) { + return messaging.stream({ data, to } satisfies WebxMessage, observer); + }, + dispose: messaging.dispose, + }; +} + +export type WrappedMessaging = ReturnType;