From 0bc5c3f9308ca9d256d5aca204c9457851d00003 Mon Sep 17 00:00:00 2001 From: Sam Sussman Date: Sat, 16 Sep 2023 15:59:13 -0500 Subject: [PATCH] fix: fix tests, local, and socket middleware interface (#447) --- apps/tests/aws-runtime-cdk/src/app.mts | 4 + apps/tests/aws-runtime/test/env.ts | 3 + apps/tests/aws-runtime/test/test-service.ts | 123 +++++------------- apps/tests/aws-runtime/test/tester.test.ts | 78 ++++++++++- packages/@eventual/cli/src/commands/local.ts | 2 +- packages/@eventual/compiler/src/ast-util.ts | 12 ++ .../@eventual/compiler/src/eventual-infer.ts | 2 + .../__snapshots__/infer-plugin.test.ts.snap | 23 +++- packages/@eventual/core/src/socket/socket.ts | 53 ++++++-- 9 files changed, 188 insertions(+), 112 deletions(-) diff --git a/apps/tests/aws-runtime-cdk/src/app.mts b/apps/tests/aws-runtime-cdk/src/app.mts index 9354eee58..cb0f336b1 100644 --- a/apps/tests/aws-runtime-cdk/src/app.mts +++ b/apps/tests/aws-runtime-cdk/src/app.mts @@ -154,6 +154,10 @@ new CfnOutput(stack, "serviceUrl", { value: testService.gateway.apiEndpoint, }); +new CfnOutput(stack, "testSocketUrl", { + value: testService.sockets.socket1.gatewayStage.url, +}); + new CfnOutput(stack, "chaosParamName", { value: chaosExtension.ssm.parameterName, }); diff --git a/apps/tests/aws-runtime/test/env.ts b/apps/tests/aws-runtime/test/env.ts index 86ea917ec..a4ef651e7 100644 --- a/apps/tests/aws-runtime/test/env.ts +++ b/apps/tests/aws-runtime/test/env.ts @@ -10,5 +10,8 @@ const outputs = fs.existsSync(path.resolve(outputsFile)) export const awsRegion = () => process.env.AWS_REGION ?? "us-east-1"; export const serviceUrl = () => outputs?.["eventual-tests"]?.serviceUrl ?? "http://localhost:3111"; +export const testSocketUrl = () => + outputs?.["eventual-tests"]?.testSocketUrl ?? + "http://localhost:3111/__ws/socket1"; export const chaosSSMParamName = () => outputs?.["eventual-tests"]?.chaosParamName; diff --git a/apps/tests/aws-runtime/test/test-service.ts b/apps/tests/aws-runtime/test/test-service.ts index a01fd6571..0de13fa96 100644 --- a/apps/tests/aws-runtime/test/test-service.ts +++ b/apps/tests/aws-runtime/test/test-service.ts @@ -37,7 +37,6 @@ import { } from "@eventual/core"; import type openapi from "openapi3-ts"; import { Readable } from "stream"; -import { WebSocket } from "ws"; import z from "zod"; import { AsyncWriterTestEvent } from "./async-writer-handler.js"; @@ -1392,7 +1391,7 @@ export const searchBlog = command( * 7. the socket will resolve the value in the connection, completing the test */ -interface SocketMessage { +export interface SocketMessage { id: string; v: number; } @@ -1413,31 +1412,33 @@ const jsonSocket = socket.use({ }, }); -export const socket1 = jsonSocket.use(({ request, context, next }) => { - const { id, n } = (request.query ?? {}) as { n?: string; id?: string }; - if (!id || !n) { - throw new Error("Missing ID"); - } - return next({ ...context, id, n }); -})("socket1", { - $connect: async ({ connectionId }, { id, n }) => { - console.log("sending signal to", id); - await socketConnectSignal.sendSignal(id, { - connectionId, - n: Number(n), - }); - console.log("signal sent to", id); - }, - $disconnect: async () => undefined, - $default: async ({ connectionId }, { data }) => { - console.log("sending signal to", data.id); - await socketMessageSignal.sendSignal(data.id, { - ...data, - connectionId, - }); - console.log("signal sent to", data.id); - }, -}); +export const socket1 = jsonSocket + .use(({ request, context, next }) => { + const { id, n } = (request.query ?? {}) as { n?: string; id?: string }; + if (!id || !n) { + throw new Error("Missing ID"); + } + return next({ ...context, id, n }); + }) + .socket("socket1", { + $connect: async ({ connectionId }, { id, n }) => { + console.log("sending signal to", id); + await socketConnectSignal.sendSignal(id, { + connectionId, + n: Number(n), + }); + console.log("signal sent to", id); + }, + $disconnect: async () => undefined, + $default: async ({ connectionId }, { data }) => { + console.log("sending signal to", data.id); + await socketMessageSignal.sendSignal(data.id, { + ...data, + connectionId, + }); + console.log("signal sent to", data.id); + }, + }); export const socketConnectSignal = signal<{ connectionId: string; n: number }>( "socketConnectSignal" @@ -1447,13 +1448,13 @@ export const socketMessageSignal = signal<{ v: number; }>("socketMessageSignal"); -interface StartSocketEvent { +export interface StartSocketEvent { type: "start"; n: number; v: number; } -interface DataSocketEvent { +export interface DataSocketEvent { type: "data"; n: number; v: number; @@ -1520,69 +1521,7 @@ export const socketTest = command( const { executionId } = await socketWorkflow.startExecution({ input: undefined, }); - const encodedId = encodeURIComponent(executionId); - - console.log("pre-socket"); - - const ws1 = new WebSocket(`${socket1.wssEndpoint}?id=${encodedId}&n=0`); - const ws2 = new WebSocket(`${socket1.wssEndpoint}?id=${encodedId}&n=1`); - - console.log("setup-socket"); - const running1 = setupWS(executionId, ws1); - const running2 = setupWS(executionId, ws2); - - console.log("waiting..."); - - return await Promise.all([running1, running2]); + return executionId; } ); - -function setupWS(executionId: string, ws: WebSocket) { - let n: number | undefined; - let v: number | undefined; - return new Promise((resolve, reject) => { - ws.on("error", (err) => { - console.log("error", err); - reject(err); - }); - ws.on("message", (data) => { - try { - console.log(n, "message"); - const d = (data as Buffer).toString("utf8"); - console.log(d); - const event = JSON.parse(d) as StartSocketEvent | DataSocketEvent; - if (event.type === "start") { - n = event.n; - // after connecting, we will send our "n" and incremented "value" back. - ws.send( - JSON.stringify({ - id: executionId, - v: event.v + 1, - } satisfies SocketMessage) - ); - } else if (event.type === "data") { - v = event.v; - } else { - console.log("unexpected event", event); - reject(event); - } - } catch (err) { - console.error(err); - reject(err); - } - }); - ws.on("close", (code, reason) => { - try { - console.log(code, reason.toString("utf-8")); - console.log(n, "close", v); - if (n === undefined) { - throw new Error("n was not set"); - } - resolve(v ?? -1); - } catch (err) { - reject(err); - } - }); - }); -} diff --git a/apps/tests/aws-runtime/test/tester.test.ts b/apps/tests/aws-runtime/test/tester.test.ts index 426f4e5e1..bb9d0dc48 100644 --- a/apps/tests/aws-runtime/test/tester.test.ts +++ b/apps/tests/aws-runtime/test/tester.test.ts @@ -6,8 +6,9 @@ import { ServiceContext, } from "@eventual/core"; import { jest } from "@jest/globals"; +import { WebSocket } from "ws"; import { ChaosEffects, ChaosTargets } from "./chaos-extension/chaos-engine.js"; -import { serviceUrl } from "./env.js"; +import { serviceUrl, testSocketUrl } from "./env.js"; import { eventualRuntimeTestHarness } from "./runtime-test-harness.js"; import type * as TestService from "./test-service.js"; import { @@ -15,12 +16,15 @@ import { asyncWorkflow, bucketWorkflow, createAndDestroyWorkflow, + DataSocketEvent, entityWorkflow, eventDrivenWorkflow, failedWorkflow, heartbeatWorkflow, parentWorkflow, queueWorkflow, + SocketMessage, + StartSocketEvent, timedOutWorkflow, timedWorkflow, transactionWorkflow, @@ -302,6 +306,7 @@ eventualRuntimeTestHarness( ); const url = serviceUrl(); +const socketUrl = testSocketUrl(); test("hello API should route and return OK response", async () => { const restResponse = await (await fetch(`${url}/hello`)).json(); @@ -445,15 +450,80 @@ test("test service context", async () => { }); test("socket test", async () => { - const rpcResponse = await ( + const executionId = (await ( await fetch(`${url}/${commandRpcPath({ name: "socketTest" })}`, { method: "POST", }) - ).json(); + ).json()) as string; + + const encodedId = encodeURIComponent(executionId); + + console.log("pre-socket"); + + const ws1 = new WebSocket(`${socketUrl}?id=${encodedId}&n=0`); + const ws2 = new WebSocket(`${socketUrl}?id=${encodedId}&n=1`); + + console.log("setup-socket"); + + const running1 = setupWS(executionId, ws1); + const running2 = setupWS(executionId, ws2); + + console.log("waiting..."); + + const result = await Promise.all([running1, running2]); - expect(rpcResponse).toEqual([3, 4]); + expect(result).toEqual([3, 4]); }); +function setupWS(executionId: string, ws: WebSocket) { + let n: number | undefined; + let v: number | undefined; + return new Promise((resolve, reject) => { + ws.on("error", (err) => { + console.log("error", err); + reject(err); + }); + ws.on("message", (data) => { + try { + console.log(n, "message"); + const d = (data as Buffer).toString("utf8"); + console.log(d); + const event = JSON.parse(d) as StartSocketEvent | DataSocketEvent; + if (event.type === "start") { + n = event.n; + // after connecting, we will send our "n" and incremented "value" back. + ws.send( + JSON.stringify({ + id: executionId, + v: event.v + 1, + } satisfies SocketMessage) + ); + } else if (event.type === "data") { + v = event.v; + } else { + console.log("unexpected event", event); + reject(event); + } + } catch (err) { + console.error(err); + reject(err); + } + }); + ws.on("close", (code, reason) => { + try { + console.log(code, reason.toString("utf-8")); + console.log(n, "close", v); + if (n === undefined) { + throw new Error("n was not set"); + } + resolve(v ?? -1); + } catch (err) { + reject(err); + } + }); + }); +} + if (!process.env.TEST_LOCAL) { test("index.search", async () => { const serviceClient = new ServiceClient({ diff --git a/packages/@eventual/cli/src/commands/local.ts b/packages/@eventual/cli/src/commands/local.ts index 81a900fc3..43ca125f6 100644 --- a/packages/@eventual/cli/src/commands/local.ts +++ b/packages/@eventual/cli/src/commands/local.ts @@ -190,7 +190,7 @@ export const local = (yargs: Argv) => const hasSockets = serviceSpec.sockets.length > 0; if (hasSockets) { - const wss = new WebSocketServer({ server }); + const wss = new WebSocketServer({ noServer: true }); server.on("upgrade", (request, socket, head) => { if (request.url?.startsWith("/__ws/")) { diff --git a/packages/@eventual/compiler/src/ast-util.ts b/packages/@eventual/compiler/src/ast-util.ts index 200f41e69..1d8256905 100644 --- a/packages/@eventual/compiler/src/ast-util.ts +++ b/packages/@eventual/compiler/src/ast-util.ts @@ -118,6 +118,18 @@ export function isSocketResourceCall(call: CallExpression): boolean { return false; } +export function isSocketMemberCall(call: CallExpression): boolean { + const c = call.callee; + if (c.type === "MemberExpression") { + if (isId(c.property, "socket")) { + // socket.use().socket("handlerName", async () => { }) + // socket.use().socket("handlerName", options, async () => { }) + return call.arguments.length === 2 || call.arguments.length === 3; + } + } + return false; +} + /** * A heuristic for identifying a {@link CallExpression} that is a call to an `subscription` handler. * diff --git a/packages/@eventual/compiler/src/eventual-infer.ts b/packages/@eventual/compiler/src/eventual-infer.ts index ec30ce0e5..6d0ca627d 100644 --- a/packages/@eventual/compiler/src/eventual-infer.ts +++ b/packages/@eventual/compiler/src/eventual-infer.ts @@ -33,6 +33,7 @@ import { isEntityStreamMemberCall, isOnEventCall, isQueueResourceCall, + isSocketMemberCall, isSocketResourceCall, isSubscriptionCall, isTaskCall, @@ -279,6 +280,7 @@ export class InferVisitor extends Visitor { isBucketHandlerMemberCall, isQueueResourceCall, isSocketResourceCall, + isSocketMemberCall, ].some((op) => op(call)) ) { this.didMutate = true; diff --git a/packages/@eventual/compiler/test/__snapshots__/infer-plugin.test.ts.snap b/packages/@eventual/compiler/test/__snapshots__/infer-plugin.test.ts.snap index 05545377e..40a2f116e 100644 --- a/packages/@eventual/compiler/test/__snapshots__/infer-plugin.test.ts.snap +++ b/packages/@eventual/compiler/test/__snapshots__/infer-plugin.test.ts.snap @@ -312,8 +312,8 @@ globalThis.tryGetEventualHook ??= () => { return void 0; }; -function createSocketBuilder(middlewares) { - const socketFunction = (...args) => { +function createSocketFunction(middlewares) { + return (...args) => { const { sourceLocation, name, options, handlers } = parseSocketArgs(args); const socket2 = { middlewares, @@ -350,14 +350,27 @@ function createSocketBuilder(middlewares) { }; return registerEventualResource("Socket", socket2); }; - const useFunction = (socketMiddleware) => { +} +function createUseFunction(middlewares) { + return (socketMiddleware) => { const middleware = typeof socketMiddleware === "function" ? { connect: socketMiddleware } : socketMiddleware; - return createSocketBuilder([...middlewares, middleware]); + return createSocketRouter([...middlewares, middleware]); + }; +} +function createSocketRouter(middlewares) { + return { + middlewares, + use: createUseFunction(middlewares), + socket: createSocketFunction(middlewares) }; +} +function createSocketBuilder() { + const socketFunction = createSocketFunction([]); + const useFunction = createUseFunction([]); socketFunction.use = useFunction; return socketFunction; } -var socket = createSocketBuilder([]); +var socket = createSocketBuilder(); function parseSocketArgs(args) { return parseArgs(args, { sourceLocation: isSourceLocation, diff --git a/packages/@eventual/core/src/socket/socket.ts b/packages/@eventual/core/src/socket/socket.ts index 8ea5dff09..631cac390 100644 --- a/packages/@eventual/core/src/socket/socket.ts +++ b/packages/@eventual/core/src/socket/socket.ts @@ -96,7 +96,7 @@ export type Socket< export type SocketOptions = FunctionRuntimeProps; -export interface socket { +export interface SocketRouter { middlewares: SocketMiddleware[]; use< NextConnectContext extends SocketHandlerContext = Context["connect"], @@ -117,9 +117,22 @@ export interface socket { Context["connect"], NextConnectContext > - ): socket< + ): SocketRouter< SocketContext >; + socket( + name: Name, + options: SocketOptions, + handlers: SocketHandlers + ): Socket; + socket( + name: Name, + handlers: SocketHandlers + ): Socket; +} + +export interface socket + extends SocketRouter { ( name: Name, options: SocketOptions, @@ -131,10 +144,10 @@ export interface socket { >; } -function createSocketBuilder( +function createSocketFunction( middlewares: SocketMiddleware[] -): socket { - const socketFunction = ( +) { + return ( ...args: | [name: Name, options: SocketOptions, handlers: SocketHandlers] | [name: Name, handlers: SocketHandlers] @@ -185,9 +198,12 @@ function createSocketBuilder( return registerEventualResource("Socket", socket as any) as Socket; }; - const useFunction: socket["use"] = < - NextContext extends SocketContext = Context - >( +} + +function createUseFunction( + middlewares: SocketMiddleware[] +): SocketRouter["use"] { + return ( socketMiddleware: | SocketMiddleware | SocketMiddlewareFunction< @@ -200,14 +216,31 @@ function createSocketBuilder( typeof socketMiddleware === "function" ? { connect: socketMiddleware } : socketMiddleware; - return createSocketBuilder([...middlewares, middleware]); + return createSocketRouter([...middlewares, middleware]); + }; +} + +function createSocketRouter( + middlewares: SocketMiddleware[] +): SocketRouter { + return { + middlewares, + use: createUseFunction(middlewares), + socket: createSocketFunction(middlewares), }; +} + +function createSocketBuilder< + Context extends SocketContext = SocketContext +>(): socket { + const socketFunction = createSocketFunction([]); + const useFunction = createUseFunction([]); (socketFunction as unknown as socket).use = useFunction; return socketFunction as unknown as socket; } -export const socket = createSocketBuilder([]); +export const socket = createSocketBuilder(); export function parseSocketArgs(args: any[]) { return parseArgs(args, {