Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Ensure websocket parity with API Gateway #1301

Merged
merged 5 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 56 additions & 31 deletions src/events/websocket/WebSocketClients.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,43 @@ export default class WebSocketClients {
clearTimeout(timeoutId)
}

async _processEvent(websocketClient, connectionId, route, event) {
let functionKey = this.#webSocketRoutes.get(route)
async verifyClient(connectionId, request) {
const route = this.#webSocketRoutes.get('$connect')
if (!route) {
return { verified: false, statusCode: 502 }
}

const connectEvent = new WebSocketConnectEvent(
connectionId,
request,
this.#options,
).create()

const lambdaFunction = this.#lambda.get(route.functionKey)
lambdaFunction.setEvent(connectEvent)

try {
const { statusCode } = await lambdaFunction.runHandler()
const verified = statusCode >= 200 && statusCode < 300
return { verified, statusCode }
} catch (err) {
if (this.log) {
this.log.debug(`Error in route handler '${route.functionKey}'`, err)
} else {
debugLog(`Error in route handler '${route.functionKey}'`, err)
}
return { verified: false, statusCode: 502 }
}
}

async _processEvent(websocketClient, connectionId, routeKey, event) {
let route = this.#webSocketRoutes.get(routeKey)

if (!functionKey && route !== '$connect' && route !== '$disconnect') {
functionKey = this.#webSocketRoutes.get('$default')
if (!route && routeKey !== '$disconnect') {
route = this.#webSocketRoutes.get('$default')
}

if (!functionKey) {
if (!route) {
return
}

Expand All @@ -123,28 +152,29 @@ export default class WebSocketClients {
)
}

// mimic AWS behaviour (close connection) when the $connect route handler throws
if (route === '$connect') {
websocketClient.close()
}

if (this.log) {
this.log.debug(`Error in route handler '${functionKey}'`, err)
this.log.debug(`Error in route handler '${route.functionKey}'`, err)
} else {
debugLog(`Error in route handler '${functionKey}'`, err)
debugLog(`Error in route handler '${route.functionKey}'`, err)
}
}

const lambdaFunction = this.#lambda.get(functionKey)
const lambdaFunction = this.#lambda.get(route.functionKey)

lambdaFunction.setEvent(event)

// let result

try {
/* result = */ await lambdaFunction.runHandler()

// TODO what to do with "result"?
const { body } = await lambdaFunction.runHandler()
if (
body &&
routeKey !== '$disconnect' &&
route.definition.routeResponseSelectionExpression === '$default'
) {
// https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-selection-expressions.html#apigateway-websocket-api-route-response-selection-expressions
// TODO: Once API gateway supports RouteResponses, this will need to change to support that functionality
// For now, send body back to the client
this.send(connectionId, body)
}
} catch (err) {
if (this.log) {
this.log.error(err)
Expand Down Expand Up @@ -176,17 +206,9 @@ export default class WebSocketClients {
return route || DEFAULT_WEBSOCKETS_ROUTE
}

addClient(webSocketClient, request, connectionId) {
addClient(webSocketClient, connectionId) {
this._addWebSocketClient(webSocketClient, connectionId)

const connectEvent = new WebSocketConnectEvent(
connectionId,
request,
this.#options,
).create()

this._processEvent(webSocketClient, connectionId, '$connect', connectEvent)

webSocketClient.on('close', () => {
if (this.log) {
this.log.debug(`disconnect:${connectionId}`)
Expand Down Expand Up @@ -233,14 +255,17 @@ export default class WebSocketClients {
})
}

addRoute(functionKey, route) {
addRoute(functionKey, definition) {
// set the route name
this.#webSocketRoutes.set(route, functionKey)
this.#webSocketRoutes.set(definition.route, {
functionKey,
definition,
})

if (this.log) {
this.log.notice(`route '${route}'`)
this.log.notice(`route '${definition}'`)
} else {
serverlessLog(`route '${route}'`)
serverlessLog(`route '${definition}'`)
}
}

Expand Down
39 changes: 36 additions & 3 deletions src/events/websocket/WebSocketServer.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { createUniqueId } from '../../utils/index.js'
export default class WebSocketServer {
#options = null
#webSocketClients = null
#connectionIds = new Map()

constructor(options, webSocketClients, sharedServer, v3Utils) {
this.#options = options
Expand All @@ -20,6 +21,35 @@ export default class WebSocketServer {

const server = new Server({
server: sharedServer,
verifyClient: ({ req }, cb) => {
const connectionId = createUniqueId()
const { headers } = req
const key = headers['sec-websocket-key']

if (this.log) {
this.log.debug(`verifyClient:${key} ${connectionId}`)
} else {
debugLog(`verifyClient:${key} ${connectionId}`)
}

// Use the websocket key to coorelate connection IDs
this.#connectionIds[key] = connectionId

this.#webSocketClients
.verifyClient(connectionId, req)
.then(({ verified, statusCode }) => {
try {
if (!verified) {
cb(false, statusCode)
return
}
cb(true)
} catch (e) {
debugLog(`Error verifying`, e)
cb(false)
}
})
},
})

server.on('connection', (webSocketClient, request) => {
Expand All @@ -29,15 +59,18 @@ export default class WebSocketServer {
console.log('received connection')
}

const connectionId = createUniqueId()
const { headers } = request
const key = headers['sec-websocket-key']

const connectionId = this.#connectionIds[key]

if (this.log) {
this.log.debug(`connect:${connectionId}`)
} else {
debugLog(`connect:${connectionId}`)
}

this.#webSocketClients.addClient(webSocketClient, request, connectionId)
this.#webSocketClients.addClient(webSocketClient, connectionId)
})
}

Expand All @@ -63,7 +96,7 @@ export default class WebSocketServer {
stop() {}

addRoute(functionKey, webSocketEvent) {
this.#webSocketClients.addRoute(functionKey, webSocketEvent.route)
this.#webSocketClients.addRoute(functionKey, webSocketEvent)
// serverlessLog(`route '${route}'`)
}
}
24 changes: 24 additions & 0 deletions tests/integration/_testHelpers/websocketPromise.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
const websocketSend = (ws, data) =>
new Promise((res) => {
ws.on('open', () => {
ws.send(data, (e) => {
if (e) {
res({ err: e })
}
})
})
ws.on('close', (c) => {
res({ code: c })
})
ws.on('message', (d) => {
res({ data: d })
})
ws.on('error', (e) => {
res({ err: e })
})
setTimeout(() => {
res({})
}, 5000)
})

export default websocketSend
19 changes: 19 additions & 0 deletions tests/integration/websocket-oneway/handler.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
'use strict'

exports.handler = async (event) => {
const { body, requestContext } = event

if (
body &&
JSON.parse(body).throwError &&
requestContext &&
requestContext.routeKey === '$default'
) {
throw new Error('Throwing error from incoming message')
}

return {
statusCode: 200,
body: body || undefined,
}
}
26 changes: 26 additions & 0 deletions tests/integration/websocket-oneway/serverless.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
service: oneway-websocket-tests

plugins:
- ../../../

provider:
memorySize: 128
name: aws
region: us-east-1 # default
runtime: nodejs12.x
stage: dev
versionFunctions: false

functions:
handler:
handler: handler.handler
events:
- http:
path: echo
method: get
- websocket:
route: $connect
- websocket:
route: $disconnect
- websocket:
route: $default
55 changes: 55 additions & 0 deletions tests/integration/websocket-oneway/websocket-oneway.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { resolve } from 'path'
import WebSocket from 'ws'
import { joinUrl, setup, teardown } from '../_testHelpers/index.js'
import websocketSend from '../_testHelpers/websocketPromise.js'

jest.setTimeout(30000)

describe('one way websocket tests', () => {
// init
beforeAll(() =>
setup({
servicePath: resolve(__dirname),
}),
)

// cleanup
afterAll(() => teardown())

test('websocket echos nothing', async () => {
const url = new URL(joinUrl(TEST_BASE_URL, '/dev'))
url.port = url.port ? '3001' : url.port
url.protocol = 'ws'

const payload = JSON.stringify({
hello: 'world',
now: new Date().toISOString(),
})

const ws = new WebSocket(url.toString())
const { data, code, err } = await websocketSend(ws, payload)

expect(code).toBeUndefined()
expect(err).toBeUndefined()
expect(data).toBeUndefined()
})

test('execution error emits Internal Server Error', async () => {
const url = new URL(joinUrl(TEST_BASE_URL, '/dev'))
url.port = url.port ? '3001' : url.port
url.protocol = 'ws'

const payload = JSON.stringify({
hello: 'world',
now: new Date().toISOString(),
throwError: true,
})

const ws = new WebSocket(url.toString())
const { data, code, err } = await websocketSend(ws, payload)

expect(code).toBeUndefined()
expect(err).toBeUndefined()
expect(JSON.parse(data).message).toEqual('Internal server error')
})
})
32 changes: 32 additions & 0 deletions tests/integration/websocket-twoway/handler.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
'use strict'

exports.handler = async (event) => {
const { body, queryStringParameters, requestContext } = event
const statusCode =
queryStringParameters && queryStringParameters.statusCode
? Number(queryStringParameters.statusCode)
: 200

if (
queryStringParameters &&
queryStringParameters.throwError &&
requestContext &&
requestContext.routeKey === '$connect'
) {
throw new Error('Throwing error during connect phase')
}

if (
body &&
JSON.parse(body).throwError &&
requestContext &&
requestContext.routeKey === '$default'
) {
throw new Error('Throwing error from incoming message')
}

return {
statusCode,
body: body || undefined,
}
}
28 changes: 28 additions & 0 deletions tests/integration/websocket-twoway/serverless.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
service: twoway-websocket-tests

plugins:
- ../../../

provider:
memorySize: 128
name: aws
region: us-east-1 # default
runtime: nodejs12.x
stage: dev
versionFunctions: false

functions:
handler:
handler: handler.handler
events:
- http:
path: echo
method: get
- websocket:
route: $connect
- websocket:
route: $disconnect
- websocket:
route: $default
# Enable 2-way comms
routeResponseSelectionExpression: $default
Loading