diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a5e8b11c..5777ecef 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,7 +29,7 @@ Please make sure to include tests with all pull requests. You can create a new service using the `npm run service:create` command. You need to pass in the name of the entity you want the service to interact with. -For example, if you are adding a "Global Stats" service, you would run: `npm run service:create global-stat` (note that the entity name is singular and not a plural). +For example, if you are adding a "Global Stats" service, you would run: `npm run service:create -- global-stat` (note that the entity name is singular and not a plural). This will create a policy, entity and REST API for your new entity. If you want to expose API endpoints (so that it can be used by the Unity SDK), add `--api` to the end of the command. @@ -42,3 +42,7 @@ Modify the default name of the file from `Migration[Timestamp].ts` to `[Timestam You should also rename the exported class to be `[PascalCaseDescriptionOfTheMigration]`. You will then need to import and add that migration class to the end of the list of migrations inside `index.ts` in the same folder. + +### ClickHouse migrations + +ClickHouse migrations are created in the `src/migrations/clickhouse` folder. These are manually created and should be added to the `src/migrations/clickhouse/index.ts` file. The migration script will automatically run the migration if it hasn't already been applied. diff --git a/_templates/service/new/api-service-test.ejs.t b/_templates/service/new/api-service-test.ejs.t index a9596359..fe592f0f 100644 --- a/_templates/service/new/api-service-test.ejs.t +++ b/_templates/service/new/api-service-test.ejs.t @@ -1,7 +1,6 @@ --- to: "<%= (typeof api !== 'undefined') ? `tests/services/_api/${name}-api/post.test.ts` : null %>" --- -import { EntityManager } from '@mikro-orm/mysql' import request from 'supertest' import { APIKeyScope } from '../../../../src/entities/api-key' import createAPIKeyAndToken from '../../../utils/createAPIKeyAndToken' diff --git a/_templates/service/new/service-test.ejs.t b/_templates/service/new/service-test.ejs.t index 3a00534e..dbf4f925 100644 --- a/_templates/service/new/service-test.ejs.t +++ b/_templates/service/new/service-test.ejs.t @@ -4,14 +4,16 @@ to: tests/services/<%= name %>/get.test.ts import request from 'supertest' import createUserAndToken from '../../utils/createUserAndToken' import <%= h.changeCase.pascal(name) %>Factory from '../../fixtures/<%= h.changeCase.pascal(name) %>Factory' +import { EntityManager } from '@mikro-orm/mysql' describe('<%= h.changeCase.sentenceCase(name) %> service - get', () => { - it('should return a list of <%= h.changeCase.noCase(name) %>s', async () => { + it('should return a of <%= h.changeCase.noCase(name) %>s', async () => { const [token] = await createUserAndToken() - const <%= name %> = await new <%= h.changeCase.pascal(name) %>Factory().one() + const <%= h.changeCase.camel(name) %> = await new <%= h.changeCase.pascal(name) %>Factory().one() + await (global.em).persistAndFlush(<%= h.changeCase.camel(name) %>) await request(global.app) - .get(`/<%= name %>/<%= name %>.id`) + .get(`/<%= name %>/${<%= h.changeCase.camel(name) %>.id}`) .auth(token, { type: 'bearer' }) .expect(200) }) diff --git a/package-lock.json b/package-lock.json index 9213b58b..fcf79f32 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "game-services", - "version": "0.54.0", + "version": "0.55.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "game-services", - "version": "0.54.0", + "version": "0.55.0", "license": "MIT", "dependencies": { "@clickhouse/client": "^1.4.1", @@ -65,7 +65,6 @@ "hygen": "^6.2.11", "lint-staged": ">=10", "supertest": "^7.0.0", - "superwstest": "^2.0.4", "ts-node": "^10.7.0", "tsx": "^4.11.0", "typescript": "^5.4.5", @@ -7944,25 +7943,6 @@ "node": ">=14.18.0" } }, - "node_modules/superwstest": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/superwstest/-/superwstest-2.0.4.tgz", - "integrity": "sha512-7u9H76yvLMdjwdrD0BFdc2JN6m2dcEQ8h7+nERrIFGADDw0HBA+clG1Yx/aQ0B/RqKzrHNkVVkGzvVBeknoCeg==", - "dev": true, - "dependencies": { - "@types/supertest": "*", - "@types/ws": "7.x || 8.x", - "ws": "7.x || 8.x" - }, - "peerDependencies": { - "supertest": "*" - }, - "peerDependenciesMeta": { - "supertest": { - "optional": true - } - } - }, "node_modules/supports-color": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", diff --git a/package.json b/package.json index 7ee8e66f..a0b04130 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "game-services", - "version": "0.54.0", + "version": "0.55.0", "description": "", "main": "src/index.ts", "scripts": { @@ -40,7 +40,6 @@ "hygen": "^6.2.11", "lint-staged": ">=10", "supertest": "^7.0.0", - "superwstest": "^2.0.4", "ts-node": "^10.7.0", "tsx": "^4.11.0", "typescript": "^5.4.5", diff --git a/src/config/api-routes.ts b/src/config/api-routes.ts index c2923b2e..e1946a19 100644 --- a/src/config/api-routes.ts +++ b/src/config/api-routes.ts @@ -17,6 +17,7 @@ import playerAuthMiddleware from '../middlewares/player-auth-middleware' import PlayerAuthAPIService from '../services/api/player-auth-api.service' import continunityMiddleware from '../middlewares/continunity-middleware' import PlayerGroupAPIService from '../services/api/player-group-api.service' +import SocketTicketAPIService from '../services/api/socket-ticket-api.service' export default function configureAPIRoutes(app: Koa) { app.use(apiKeyMiddleware) @@ -33,6 +34,7 @@ export default function configureAPIRoutes(app: Koa) { app.use(playerAuthMiddleware) app.use(continunityMiddleware) + app.use(service('/v1/socket-tickets', new SocketTicketAPIService())) app.use(service('/v1/game-channels', new GameChannelAPIService())) app.use(service('/v1/player-groups', new PlayerGroupAPIService())) app.use(service('/v1/health-check', new HealthCheckAPIService())) diff --git a/src/docs/socket-tickets-api.docs.ts b/src/docs/socket-tickets-api.docs.ts new file mode 100644 index 00000000..15c3bc1a --- /dev/null +++ b/src/docs/socket-tickets-api.docs.ts @@ -0,0 +1,18 @@ +import SocketTicketAPIService from '../services/api/socket-ticket-api.service' +import APIDocs from './api-docs' + +const SocketTicketAPIDocs: APIDocs = { + post: { + description: 'Create a socket ticket (expires after 5 minutes)', + samples: [ + { + title: 'Sample response', + sample: { + ticket: '6c6ef345-0ac3-4edc-a221-b807fbbac4ac' + } + } + ] + } +} + +export default SocketTicketAPIDocs diff --git a/src/services/api/health-check-api.service.ts b/src/services/api/health-check-api.service.ts index 8d83bef0..4c797822 100644 --- a/src/services/api/health-check-api.service.ts +++ b/src/services/api/health-check-api.service.ts @@ -1,11 +1,6 @@ -import { Response, Routes } from 'koa-clay' +import { Response } from 'koa-clay' import APIService from './api-service' -@Routes([ - { - method: 'GET' - } -]) export default class HealthCheckAPIService extends APIService { async index(): Promise { return { diff --git a/src/services/api/socket-ticket-api.service.ts b/src/services/api/socket-ticket-api.service.ts new file mode 100644 index 00000000..e9fae24f --- /dev/null +++ b/src/services/api/socket-ticket-api.service.ts @@ -0,0 +1,31 @@ +import { Docs, Request, Response } from 'koa-clay' +import APIService from './api-service' +import { createRedisConnection } from '../../config/redis.config' +import { v4 } from 'uuid' +import Redis from 'ioredis' +import APIKey from '../../entities/api-key' +import SocketTicketAPIDocs from '../../docs/socket-tickets-api.docs' + +export async function createSocketTicket(redis: Redis, key: APIKey, devBuild: boolean): Promise { + const ticket = v4() + const payload = `${key.id}:${devBuild ? '1' : '0'}` + await redis.set(`socketTickets.${ticket}`, payload, 'EX', 300) + + return ticket +} + +export default class SocketTicketAPIService extends APIService { + @Docs(SocketTicketAPIDocs.post) + async post(req: Request): Promise { + const redis = createRedisConnection(req.ctx) + + const ticket = await createSocketTicket(redis, req.ctx.state.key, req.headers['x-talo-dev-build'] === '1') + + return { + status: 200, + body: { + ticket + } + } + } +} diff --git a/src/socket/authenticateSocket.ts b/src/socket/authenticateSocket.ts deleted file mode 100644 index 62a82c97..00000000 --- a/src/socket/authenticateSocket.ts +++ /dev/null @@ -1,25 +0,0 @@ -import getAPIKeyFromToken from '../lib/auth/getAPIKeyFromToken' -import { promisify } from 'util' -import jwt from 'jsonwebtoken' -import { RequestContext } from '@mikro-orm/core' -import APIKey from '../entities/api-key' - -export default async function authenticateSocket(authHeader: string): Promise { - const apiKey = await getAPIKeyFromToken(authHeader) - if (!apiKey || apiKey.revokedAt) { - return - } - - apiKey.lastUsedAt = new Date() - await RequestContext.getEntityManager().flush() - - try { - const token = authHeader.split('Bearer ')[1] - const secret = apiKey.game.apiSecret.getPlainSecret() - await promisify(jwt.verify)(token, secret) - } catch (err) { - return - } - - return apiKey -} diff --git a/src/socket/index.ts b/src/socket/index.ts index 9c63ff90..90836f35 100644 --- a/src/socket/index.ts +++ b/src/socket/index.ts @@ -2,7 +2,6 @@ import { IncomingMessage, Server } from 'http' import { RawData, WebSocket, WebSocketServer } from 'ws' import { captureException } from '@sentry/node' import { EntityManager, RequestContext } from '@mikro-orm/mysql' -import authenticateSocket from './authenticateSocket' import SocketConnection from './socketConnection' import SocketRouter from './router/socketRouter' import { sendMessage } from './messages/socketMessage' @@ -11,6 +10,9 @@ import { Queue } from 'bullmq' import { createSocketEventQueue, SocketEventData } from './socketEvent' import { ClickHouseClient } from '@clickhouse/client' import createClickhouseClient from '../lib/clickhouse/createClient' +import Redis from 'ioredis' +import redisConfig from '../config/redis.config' +import SocketTicket from './socketTicket' type CloseConnectionOptions = { code?: number @@ -73,10 +75,14 @@ export default class Socket { async handleConnection(ws: WebSocket, req: IncomingMessage): Promise { logConnection(req) + const redis = new Redis(redisConfig) + await RequestContext.create(this.em, async () => { - const key = await authenticateSocket(req.headers?.authorization ?? '') - if (key) { - const connection = new SocketConnection(this, ws, key, req) + const url = new URL(req.url, 'http://localhost') + const ticket = new SocketTicket(url.searchParams.get('ticket') ?? '') + + if (await ticket.validate(redis)) { + const connection = new SocketConnection(this, ws, ticket, req.socket.remoteAddress) this.connections.set(ws, connection) await this.trackEvent('open', { @@ -85,7 +91,7 @@ export default class Socket { code: null, gameId: connection.game.id, playerAliasId: null, - devBuild: req.headers['x-talo-dev-build'] === '1' + devBuild: ticket.devBuild }) await sendMessage(connection, 'v1.connected', {}) @@ -93,6 +99,8 @@ export default class Socket { await this.closeConnection(ws) } }) + + await redis.quit() } async handleMessage(ws: WebSocket, data: RawData): Promise { diff --git a/src/socket/socketConnection.ts b/src/socket/socketConnection.ts index 9a9cc254..e98bca09 100644 --- a/src/socket/socketConnection.ts +++ b/src/socket/socketConnection.ts @@ -2,9 +2,7 @@ import { WebSocket } from 'ws' import PlayerAlias from '../entities/player-alias' import Game from '../entities/game' import APIKey, { APIKeyScope } from '../entities/api-key' -import { IncomingHttpHeaders, IncomingMessage } from 'http' import { RequestContext } from '@mikro-orm/core' -import jwt from 'jsonwebtoken' import { v4 } from 'uuid' import Redis from 'ioredis' import redisConfig from '../config/redis.config' @@ -13,14 +11,13 @@ import Socket from '.' import { SocketMessageResponse } from './messages/socketMessage' import { logResponse } from './messages/socketLogger' import { SocketErrorCode } from './messages/socketError' +import SocketTicket from './socketTicket' export default class SocketConnection { alive: boolean = true - playerAliasId: number | null = null - game: Game | null = null - private scopes: APIKeyScope[] = [] - private headers: IncomingHttpHeaders = {} - private remoteAddress: string = 'unknown' + playerAliasId: number + game: Game + private apiKey: APIKey rateLimitKey: string = v4() rateLimitWarnings: number = 0 @@ -28,13 +25,11 @@ export default class SocketConnection { constructor( private readonly wss: Socket, private readonly ws: WebSocket, - apiKey: APIKey, - req: IncomingMessage + private readonly ticket: SocketTicket, + private readonly remoteAddress: string ) { - this.game = apiKey.game - this.scopes = apiKey.scopes - this.headers = req.headers - this.remoteAddress = req.socket.remoteAddress + this.game = this.ticket.apiKey.game + this.apiKey = this.ticket.apiKey } async getPlayerAlias(): Promise { @@ -44,27 +39,19 @@ export default class SocketConnection { } getAPIKeyId(): number { - const token = this.headers.authorization.split('Bearer ')[1] - const decodedToken = jwt.decode(token) - return decodedToken.sub + return this.ticket.apiKey.id } hasScope(scope: APIKeyScope): boolean { - return this.scopes.includes(APIKeyScope.FULL_ACCESS) || this.scopes.includes(scope) + return this.apiKey.scopes.includes(APIKeyScope.FULL_ACCESS) || this.apiKey.scopes.includes(scope) } hasScopes(scopes: APIKeyScope[]): boolean { - if (this.hasScope(APIKeyScope.FULL_ACCESS)) { - return true - } - return scopes.every((scope) => this.hasScope(scope)) + return this.hasScope(APIKeyScope.FULL_ACCESS) || scopes.every((scope) => this.hasScope(scope)) } getRateLimitMaxRequests(): number { - if (this.playerAliasId) { - return 100 - } - return 10 + return this.playerAliasId ? 100 : 10 } async checkRateLimitExceeded(): Promise { @@ -88,7 +75,7 @@ export default class SocketConnection { } isDevBuild(): boolean { - return this.headers['x-talo-dev-build'] === '1' + return this.ticket.devBuild } async sendMessage(res: SocketMessageResponse, data: T): Promise { diff --git a/src/socket/socketTicket.ts b/src/socket/socketTicket.ts new file mode 100644 index 00000000..c7cce61f --- /dev/null +++ b/src/socket/socketTicket.ts @@ -0,0 +1,35 @@ +import { Redis } from 'ioredis' +import APIKey from '../entities/api-key' +import { RequestContext } from '@mikro-orm/mysql' + +export default class SocketTicket { + apiKey: APIKey + devBuild: boolean + + constructor(private readonly ticket: string) { } + + async validate(redis: Redis): Promise { + const ticketValue = await redis.get(`socketTickets.${this.ticket}`) + if (ticketValue) { + await redis.del(`socketTickets.${this.ticket}`) + const [keyId, devBuild] = ticketValue.split(':') + + try { + this.devBuild = devBuild === '1' + + const em = RequestContext.getEntityManager() + this.apiKey = await em.getRepository(APIKey).findOneOrFail({ + id: Number(keyId), + revokedAt: null + }, { + populate: ['game'] + }) + + return true + } catch (error) { + return false + } + } + return false + } +} diff --git a/tests/services/_api/game-channel-api/delete.test.ts b/tests/services/_api/game-channel-api/delete.test.ts index 2aaf13c5..7f61c2fe 100644 --- a/tests/services/_api/game-channel-api/delete.test.ts +++ b/tests/services/_api/game-channel-api/delete.test.ts @@ -1,26 +1,14 @@ import request from 'supertest' -import requestWs from 'superwstest' import { EntityManager } from '@mikro-orm/mysql' import GameChannelFactory from '../../../fixtures/GameChannelFactory' import { APIKeyScope } from '../../../../src/entities/api-key' import createAPIKeyAndToken from '../../../utils/createAPIKeyAndToken' import PlayerFactory from '../../../fixtures/PlayerFactory' import GameChannel from '../../../../src/entities/game-channel' -import createSocketIdentifyMessage from '../../../utils/requestAuthedSocket' -import Socket from '../../../../src/socket' +import createSocketIdentifyMessage from '../../../utils/createSocketIdentifyMessage' +import createTestSocket from '../../../utils/createTestSocket' describe('Game channel API service - delete', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - global.ctx.wss = socket - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should delete a channel if the scope is valid', async () => { const em: EntityManager = global.em @@ -138,7 +126,7 @@ describe('Game channel API service - delete', () => { }) it('should notify players in the channel when the channel is deleted', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player, token } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -147,28 +135,21 @@ describe('Game channel API service - delete', () => { const em: EntityManager = global.em const channel = await new GameChannelFactory(player.game).one() - channel.owner = player.aliases[0] channel.members.add(player.aliases[0]) - await em.persistAndFlush(channel) - await requestWs(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .exec(async () => { - await request(global.app) - .delete(`/v1/game-channels/${channel.id}`) - .auth(token, { type: 'bearer' }) - .set('x-talo-alias', String(player.aliases[0].id)) - .expect(204) - }) - .expectJson((actual) => { + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + await request(global.app) + .delete(`/v1/game-channels/${channel.id}`) + .auth(token, { type: 'bearer' }) + .set('x-talo-alias', String(player.aliases[0].id)) + .expect(204) + await client.expectJson((actual) => { expect(actual.res).toBe('v1.channels.deleted') expect(actual.data.channel.id).toBe(channel.id) }) + }) }) }) diff --git a/tests/services/_api/game-channel-api/join.test.ts b/tests/services/_api/game-channel-api/join.test.ts index 6f8aff0e..597b00ba 100644 --- a/tests/services/_api/game-channel-api/join.test.ts +++ b/tests/services/_api/game-channel-api/join.test.ts @@ -1,25 +1,13 @@ import request from 'supertest' -import requestWs from 'superwstest' import { EntityManager } from '@mikro-orm/mysql' import GameChannelFactory from '../../../fixtures/GameChannelFactory' import { APIKeyScope } from '../../../../src/entities/api-key' import createAPIKeyAndToken from '../../../utils/createAPIKeyAndToken' import PlayerFactory from '../../../fixtures/PlayerFactory' -import createSocketIdentifyMessage from '../../../utils/requestAuthedSocket' -import Socket from '../../../../src/socket' +import createSocketIdentifyMessage from '../../../utils/createSocketIdentifyMessage' +import createTestSocket from '../../../utils/createTestSocket' describe('Game channel API service - join', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - global.ctx.wss = socket - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should join a channel if the scope is valid', async () => { const [apiKey, token] = await createAPIKeyAndToken([APIKeyScope.WRITE_GAME_CHANNELS]) @@ -105,7 +93,7 @@ describe('Game channel API service - join', () => { }) it('should notify players in the channel when a new player joins', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player, token } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -114,23 +102,18 @@ describe('Game channel API service - join', () => { const channel = await new GameChannelFactory(player.game).one() await (global.em).persistAndFlush([channel, player]) - await requestWs(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .exec(async () => { - await request(global.app) - .post(`/v1/game-channels/${channel.id}/join`) - .auth(token, { type: 'bearer' }) - .set('x-talo-alias', String(player.aliases[0].id)) - .expect(200) - }) - .expectJson((actual) => { + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + await request(global.app) + .post(`/v1/game-channels/${channel.id}/join`) + .auth(token, { type: 'bearer' }) + .set('x-talo-alias', String(player.aliases[0].id)) + .expect(200) + await client.expectJson((actual) => { expect(actual.res).toBe('v1.channels.player-joined') expect(actual.data.channel.id).toBe(channel.id) expect(actual.data.playerAlias.id).toBe(player.aliases[0].id) }) + }) }) }) diff --git a/tests/services/_api/game-channel-api/leave.test.ts b/tests/services/_api/game-channel-api/leave.test.ts index 1b683f2b..e748897f 100644 --- a/tests/services/_api/game-channel-api/leave.test.ts +++ b/tests/services/_api/game-channel-api/leave.test.ts @@ -1,26 +1,14 @@ import request from 'supertest' -import requestWs from 'superwstest' import { EntityManager } from '@mikro-orm/mysql' import GameChannelFactory from '../../../fixtures/GameChannelFactory' import { APIKeyScope } from '../../../../src/entities/api-key' import createAPIKeyAndToken from '../../../utils/createAPIKeyAndToken' import PlayerFactory from '../../../fixtures/PlayerFactory' import GameChannel from '../../../../src/entities/game-channel' -import Socket from '../../../../src/socket' -import createSocketIdentifyMessage from '../../../utils/requestAuthedSocket' +import createSocketIdentifyMessage from '../../../utils/createSocketIdentifyMessage' +import createTestSocket from '../../../utils/createTestSocket' describe('Game channel API service - leave', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - global.ctx.wss = socket - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should leave a channel if the scope is valid', async () => { const [apiKey, token] = await createAPIKeyAndToken([APIKeyScope.WRITE_GAME_CHANNELS]) @@ -178,7 +166,7 @@ describe('Game channel API service - leave', () => { }) it('should notify players in the channel when a player leaves', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player, token } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -188,23 +176,18 @@ describe('Game channel API service - leave', () => { channel.members.add(player.aliases[0]) await (global.em).persistAndFlush(channel) - await requestWs(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .exec(async () => { - await request(global.app) - .post(`/v1/game-channels/${channel.id}/leave`) - .auth(token, { type: 'bearer' }) - .set('x-talo-alias', String(player.aliases[0].id)) - .expect(204) - }) - .expectJson((actual) => { + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + await request(global.app) + .post(`/v1/game-channels/${channel.id}/leave`) + .auth(token, { type: 'bearer' }) + .set('x-talo-alias', String(player.aliases[0].id)) + .expect(204) + await client.expectJson((actual) => { expect(actual.res).toBe('v1.channels.player-left') expect(actual.data.channel.id).toBe(channel.id) expect(actual.data.playerAlias.id).toBe(player.aliases[0].id) }) + }) }) }) diff --git a/tests/services/_api/game-channel-api/put.test.ts b/tests/services/_api/game-channel-api/put.test.ts index ac1190a8..aeae70bf 100644 --- a/tests/services/_api/game-channel-api/put.test.ts +++ b/tests/services/_api/game-channel-api/put.test.ts @@ -1,25 +1,13 @@ import request from 'supertest' -import requestWs from 'superwstest' import { EntityManager } from '@mikro-orm/mysql' import GameChannelFactory from '../../../fixtures/GameChannelFactory' import { APIKeyScope } from '../../../../src/entities/api-key' import createAPIKeyAndToken from '../../../utils/createAPIKeyAndToken' import PlayerFactory from '../../../fixtures/PlayerFactory' -import createSocketIdentifyMessage from '../../../utils/requestAuthedSocket' -import Socket from '../../../../src/socket' +import createSocketIdentifyMessage from '../../../utils/createSocketIdentifyMessage' +import createTestSocket from '../../../utils/createTestSocket' describe('Game channel API service - put', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - global.ctx.wss = socket - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should update a channel if the scope is valid', async () => { const em: EntityManager = global.em @@ -279,7 +267,7 @@ describe('Game channel API service - put', () => { }) it('should notify players in the channel when ownership is transferred', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player, token } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -295,24 +283,19 @@ describe('Game channel API service - put', () => { await em.persistAndFlush(channel) - await requestWs(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .exec(async () => { - await request(global.app) - .put(`/v1/game-channels/${channel.id}`) - .send({ ownerAliasId: newOwner.aliases[0].id }) - .auth(token, { type: 'bearer' }) - .set('x-talo-alias', String(player.aliases[0].id)) - .expect(200) - }) - .expectJson((actual) => { + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + await request(global.app) + .put(`/v1/game-channels/${channel.id}`) + .send({ ownerAliasId: newOwner.aliases[0].id }) + .auth(token, { type: 'bearer' }) + .set('x-talo-alias', String(player.aliases[0].id)) + .expect(200) + await client.expectJson((actual) => { expect(actual.res).toBe('v1.channels.ownership-transferred') expect(actual.data.channel.id).toBe(channel.id) expect(actual.data.newOwner.id).toBe(newOwner.aliases[0].id) }) + }) }) }) diff --git a/tests/services/_api/socket-ticket-api/post.test.ts b/tests/services/_api/socket-ticket-api/post.test.ts new file mode 100644 index 00000000..c2928ab9 --- /dev/null +++ b/tests/services/_api/socket-ticket-api/post.test.ts @@ -0,0 +1,15 @@ +import request from 'supertest' +import createAPIKeyAndToken from '../../../utils/createAPIKeyAndToken' + +describe('Socket ticket API service - post', () => { + it('should return a valid socket ticket', async () => { + const [, token] = await createAPIKeyAndToken([]) + + const res = await request(global.app) + .post('/v1/socket-tickets') + .auth(token, { type: 'bearer' }) + .expect(200) + + expect(res.body.ticket).toHaveLength(36) + }) +}) diff --git a/tests/services/api-key/delete.test.ts b/tests/services/api-key/delete.test.ts index ed474205..734fb53e 100644 --- a/tests/services/api-key/delete.test.ts +++ b/tests/services/api-key/delete.test.ts @@ -1,6 +1,5 @@ import { EntityManager } from '@mikro-orm/mysql' import request from 'supertest' -import requestWs from 'superwstest' import { UserType } from '../../../src/entities/user' import APIKey from '../../../src/entities/api-key' import UserFactory from '../../fixtures/UserFactory' @@ -8,21 +7,12 @@ import GameActivity, { GameActivityType } from '../../../src/entities/game-activ import userPermissionProvider from '../../utils/userPermissionProvider' import createUserAndToken from '../../utils/createUserAndToken' import createOrganisationAndGame from '../../utils/createOrganisationAndGame' -import Socket from '../../../src/socket' -import { createToken } from '../../../src/services/api-key.service' +import { createSocketTicket } from '../../../src/services/api/socket-ticket-api.service' +import redisConfig from '../../../src/config/redis.config' +import { Redis } from 'ioredis' +import createTestSocket from '../../utils/createTestSocket' describe('API key service - delete', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - global.ctx.wss = socket - }) - - afterAll(() => { - socket.getServer().close() - }) - it.each(userPermissionProvider([ UserType.ADMIN ], 204))('should return a %i for a %s user', async (statusCode, _, type) => { @@ -92,21 +82,18 @@ describe('API key service - delete', () => { const key = new APIKey(game, user) await (global.em).persistAndFlush(key) - const apiToken = await createToken(global.em, key) - - await requestWs(global.server) - .ws('/') - .set('authorization', `Bearer ${apiToken}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .exec(async () => { - await request(global.app) - .delete(`/games/${game.id}/api-keys/${key.id}`) - .auth(token, { type: 'bearer' }) - .expect(204) - }) - .expectClosed(3000) + + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, key, false) + await redis.quit() + + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await request(global.app) + .delete(`/games/${game.id}/api-keys/${key.id}`) + .auth(token, { type: 'bearer' }) + .expect(204) + + await client.expectClosed(3000) + }) }) }) diff --git a/tests/setupTest.ts b/tests/setupTest.ts index d17494ff..ddcc9a7f 100644 --- a/tests/setupTest.ts +++ b/tests/setupTest.ts @@ -2,11 +2,11 @@ import { EntityManager, MikroORM } from '@mikro-orm/mysql' import init from '../src' import ormConfig from '../src/config/mikro-orm.config' import { ClickHouseClient } from '@clickhouse/client' -import { createServer, Server } from 'http' beforeAll(async () => { vi.mock('@sendgrid/mail') vi.mock('bullmq') + vi.stubEnv('DISABLE_SOCKET_EVENTS', '1') const orm = await MikroORM.init(ormConfig) await orm.getSchemaGenerator().clearDatabase() @@ -17,10 +17,6 @@ beforeAll(async () => { global.ctx = app.context global.em = app.context.em - vi.stubEnv('DISABLE_SOCKET_EVENTS', '1') - global.server = createServer() - global.server.listen(0) - global.clickhouse = app.context.clickhouse await (global.clickhouse as ClickHouseClient).command({ query: `TRUNCATE ALL TABLES from ${process.env.CLICKHOUSE_DB}` @@ -31,15 +27,11 @@ afterAll(async () => { const em: EntityManager = global.em await em.getConnection().close(true) - const server: Server = global.server - server.close() - const clickhouse: ClickHouseClient = global.clickhouse await clickhouse.close() delete global.app delete global.ctx delete global.em - delete global.server delete global.clickhouse }) diff --git a/tests/socket/listeners/gameChannelListeners/message.test.ts b/tests/socket/listeners/gameChannelListeners/message.test.ts index 86bf6113..974a0672 100644 --- a/tests/socket/listeners/gameChannelListeners/message.test.ts +++ b/tests/socket/listeners/gameChannelListeners/message.test.ts @@ -1,23 +1,12 @@ -import request from 'superwstest' -import Socket from '../../../../src/socket' import { APIKeyScope } from '../../../../src/entities/api-key' -import createSocketIdentifyMessage from '../../../utils/requestAuthedSocket' +import createSocketIdentifyMessage from '../../../utils/createSocketIdentifyMessage' import GameChannelFactory from '../../../fixtures/GameChannelFactory' import { EntityManager } from '@mikro-orm/mysql' +import createTestSocket from '../../../utils/createTestSocket' describe('Game channel listeners - message', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should successfully send a message', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -26,13 +15,9 @@ describe('Game channel listeners - message', () => { channel.members.add(player.aliases[0]) await (global.em).persistAndFlush(channel) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -41,16 +26,17 @@ describe('Game channel listeners - message', () => { message: 'Hello world' } }) - .expectJson((actual) => { + await client.expectJson((actual) => { expect(actual.res).toBe('v1.channels.message') expect(actual.data.channel.id).toBe(channel.id) expect(actual.data.message).toBe('Hello world') expect(actual.data.playerAlias.id).toBe(player.aliases[0].id) }) + }) }) it('should receive an error if the player is not in the channel', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -58,13 +44,9 @@ describe('Game channel listeners - message', () => { const channel = await new GameChannelFactory(player.game).one() await (global.em).persistAndFlush(channel) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -73,7 +55,7 @@ describe('Game channel listeners - message', () => { message: 'Hello world' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.channels.message', @@ -82,23 +64,19 @@ describe('Game channel listeners - message', () => { cause: 'Player not in channel' } }) - .close() + }) }) it('should receive an error if the channel does not exist', async () => { - const [identifyMessage, token] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS ]) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -107,7 +85,7 @@ describe('Game channel listeners - message', () => { message: 'Hello world' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.channels.message', @@ -116,6 +94,6 @@ describe('Game channel listeners - message', () => { cause: 'Channel not found' } }) - .close() + }) }) }) diff --git a/tests/socket/listeners/playerListeners/identify.test.ts b/tests/socket/listeners/playerListeners/identify.test.ts index ff79ac92..6d8c6479 100644 --- a/tests/socket/listeners/playerListeners/identify.test.ts +++ b/tests/socket/listeners/playerListeners/identify.test.ts @@ -1,60 +1,38 @@ -import request from 'superwstest' -import Socket from '../../../../src/socket' import { APIKeyScope } from '../../../../src/entities/api-key' -import createSocketIdentifyMessage from '../../../utils/requestAuthedSocket' +import createSocketIdentifyMessage from '../../../utils/createSocketIdentifyMessage' import { EntityManager } from '@mikro-orm/mysql' import createAPIKeyAndToken from '../../../utils/createAPIKeyAndToken' import PlayerFactory from '../../../fixtures/PlayerFactory' import Redis from 'ioredis' import redisConfig from '../../../../src/config/redis.config' +import { createSocketTicket } from '../../../../src/services/api/socket-ticket-api.service' +import createTestSocket from '../../../utils/createTestSocket' describe('Player listeners - identify', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should successfully identify a player', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson(identifyMessage) - .expectJson((actual) => { + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson(identifyMessage) + await client.expectJson((actual) => { expect(actual.res).toBe('v1.players.identify.success') expect(actual.data.id).toBe(player.aliases[0].id) }) - .close() + }) }) it('should require the socket token to be valid', async () => { - const [identifyMessage, token] = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) + const { identifyMessage, ticket } = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson({ ...identifyMessage, data: { ...identifyMessage.data, socketToken: 'invalid' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.players.identify', @@ -62,31 +40,26 @@ describe('Player listeners - identify', () => { errorCode: 'INVALID_SOCKET_TOKEN' } }) - .close() + }) }) it('should require a valid session token to identify Talo aliases', async () => { const em: EntityManager = global.em - const [apiKey, token] = await createAPIKeyAndToken([APIKeyScope.READ_PLAYERS]) + const [apiKey] = await createAPIKeyAndToken([APIKeyScope.READ_PLAYERS]) const player = await new PlayerFactory([apiKey.game]).withTaloAlias().one() await em.persistAndFlush(player) const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) const socketToken = await player.aliases[0].createSocketToken(redis) await redis.quit() const sessionToken = await player.auth.createSession(player.aliases[0]) await em.flush() - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson({ req: 'v1.players.identify', data: { playerAliasId: player.aliases[0].id, @@ -94,32 +67,27 @@ describe('Player listeners - identify', () => { sessionToken: sessionToken } }) - .expectJson((actual) => { + await client.expectJson((actual) => { expect(actual.res).toBe('v1.players.identify.success') expect(actual.data.id).toBe(player.aliases[0].id) }) - .close() + }) }) it('should reject identify for Talo aliases without a valid session token', async () => { const em: EntityManager = global.em - const [apiKey, token] = await createAPIKeyAndToken([APIKeyScope.READ_PLAYERS]) + const [apiKey] = await createAPIKeyAndToken([APIKeyScope.READ_PLAYERS]) const player = await new PlayerFactory([apiKey.game]).withTaloAlias().one() await em.persistAndFlush(player) const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) const socketToken = await player.aliases[0].createSocketToken(redis) await redis.quit() - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson({ req: 'v1.players.identify', data: { playerAliasId: player.aliases[0].id, @@ -127,7 +95,7 @@ describe('Player listeners - identify', () => { sessionToken: 'blah' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.players.identify', @@ -135,6 +103,6 @@ describe('Player listeners - identify', () => { errorCode: 'INVALID_SESSION_TOKEN' } }) - .close() + }) }) }) diff --git a/tests/socket/messages/socketLogger.test.ts b/tests/socket/messages/socketLogger.test.ts index 24d284bb..9acc9b30 100644 --- a/tests/socket/messages/socketLogger.test.ts +++ b/tests/socket/messages/socketLogger.test.ts @@ -2,10 +2,11 @@ import { WebSocket } from 'ws' import TaloSocket from '../../../src/socket' import SocketConnection from '../../../src/socket/socketConnection' import createAPIKeyAndToken from '../../utils/createAPIKeyAndToken' -import { IncomingMessage } from 'http' +import { createServer, IncomingMessage } from 'http' import { Socket } from 'net' import { logConnection, logConnectionClosed, logRequest, logResponse } from '../../../src/socket/messages/socketLogger' import { EntityManager } from '@mikro-orm/mysql' +import SocketTicket from '../../../src/socket/socketTicket' describe('Socket logger', () => { const consoleMock = vi.spyOn(console, 'log').mockImplementation(() => undefined) @@ -22,31 +23,29 @@ describe('Socket logger', () => { vi.unstubAllEnvs() }) - async function createSocketConnection(): Promise<[TaloSocket, SocketConnection, () => void]> { + async function createSocketConnection(): Promise<[SocketConnection, () => void]> { const [apiKey] = await createAPIKeyAndToken([]) await (global.em).persistAndFlush(apiKey) - const socket = new TaloSocket(global.server, global.em) - const conn = new SocketConnection( - socket, - new WebSocket(null, [], {}), - apiKey, - new IncomingMessage(new Socket()) - ) + const ticket = new SocketTicket('') + ticket.apiKey = apiKey + ticket.devBuild = false - vi.spyOn(conn, 'getRemoteAddress').mockReturnValue('0.0.0.0') + const server = createServer() + server.listen(0) + + const wss = new TaloSocket(server, global.em) + const ws = new WebSocket(null, [], {}) + const conn = new SocketConnection(wss, ws, ticket, '0.0.0.0') return [ - socket, conn, - () => { - socket.getServer().close() - } + () => server.close() ] } it('should log requests', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() logRequest(conn, JSON.stringify({ req: 'v1.fake', data: {} })) @@ -57,7 +56,7 @@ describe('Socket logger', () => { }) it('should log requests with aliases', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() conn.playerAliasId = 2 logRequest(conn, JSON.stringify({ req: 'v1.fake', data: {} })) @@ -69,7 +68,7 @@ describe('Socket logger', () => { }) it('should log requests without valid json', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() logRequest(conn, 'v1.fake') @@ -80,7 +79,7 @@ describe('Socket logger', () => { }) it('should log requests without a req', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() logRequest(conn, JSON.stringify({ wrong: 'v1.fake' })) @@ -91,7 +90,7 @@ describe('Socket logger', () => { }) it('should log responses', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() logResponse(conn, 'v1.players.identify.success', JSON.stringify({ res: 'v1.players.identify.success', data: {} })) @@ -109,7 +108,7 @@ describe('Socket logger', () => { }) it('should log pre-closed connections', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() logConnectionClosed(conn, true, 3000) @@ -120,7 +119,7 @@ describe('Socket logger', () => { }) it('should log manually-closed connections', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() logConnectionClosed(conn, false, 3000, 'Unauthorised') @@ -131,7 +130,7 @@ describe('Socket logger', () => { }) it('should log manually-closed connections without a reason', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() logConnectionClosed(conn, false, 3000) @@ -149,7 +148,7 @@ describe('Socket logger', () => { }) it('should log pre-closed connection with aliases', async () => { - const [, conn, cleanup] = await createSocketConnection() + const [conn, cleanup] = await createSocketConnection() conn.playerAliasId = 2 logConnectionClosed(conn, true, 3000) diff --git a/tests/socket/rateLimiting.test.ts b/tests/socket/rateLimiting.test.ts index 2079bfc2..32b4802c 100644 --- a/tests/socket/rateLimiting.test.ts +++ b/tests/socket/rateLimiting.test.ts @@ -1,26 +1,14 @@ -import request from 'superwstest' -import Socket from '../../src/socket' import { APIKeyScope } from '../../src/entities/api-key' -import createSocketIdentifyMessage from '../utils/requestAuthedSocket' +import createSocketIdentifyMessage from '../utils/createSocketIdentifyMessage' import GameChannelFactory from '../fixtures/GameChannelFactory' import { EntityManager } from '@mikro-orm/mysql' import Redis from 'ioredis' import redisConfig from '../../src/config/redis.config' +import createTestSocket from '../utils/createTestSocket' describe('Socket rate limiting', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - global.ctx.wss = socket - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should return a rate limiting error', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -29,20 +17,15 @@ describe('Socket rate limiting', () => { channel.members.add(player.aliases[0]) await (global.em).persistAndFlush(channel) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .exec(async () => { - const conn = socket.findConnections((conn) => conn.playerAliasId === player.aliases[0].id)[0] + await createTestSocket(`/?ticket=${ticket}`, async (client, socket) => { + await client.identify(identifyMessage) - const redis = new Redis(redisConfig) - await redis.set(`requests.${conn.rateLimitKey}`, 999) - await redis.quit() - }) - .sendJson({ + const conn = socket.findConnections((conn) => conn.playerAliasId === player.aliases[0].id)[0] + const redis = new Redis(redisConfig) + await redis.set(`requests.${conn.rateLimitKey}`, 999) + await redis.quit() + + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -51,7 +34,7 @@ describe('Socket rate limiting', () => { message: 'Hello world' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'unknown', @@ -59,11 +42,11 @@ describe('Socket rate limiting', () => { errorCode: 'RATE_LIMIT_EXCEEDED' } }) - .close() + }) }) it('should disconnect connections after 3 warnings', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([ + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([ APIKeyScope.READ_PLAYERS, APIKeyScope.READ_GAME_CHANNELS, APIKeyScope.WRITE_GAME_CHANNELS @@ -72,21 +55,17 @@ describe('Socket rate limiting', () => { channel.members.add(player.aliases[0]) await (global.em).persistAndFlush(channel) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson() - .sendJson(identifyMessage) - .expectJson() - .exec(async () => { - const conn = socket.findConnections((conn) => conn.playerAliasId === player.aliases[0].id)[0] - conn.rateLimitWarnings = 3 + await createTestSocket(`/?ticket=${ticket}`, async (client, socket) => { + await client.identify(identifyMessage) - const redis = new Redis(redisConfig) - await redis.set(`requests.${conn.rateLimitKey}`, 999) - await redis.quit() - }) - .sendJson({ + const conn = socket.findConnections((conn) => conn.playerAliasId === player.aliases[0].id)[0] + conn.rateLimitWarnings = 3 + + const redis = new Redis(redisConfig) + await redis.set(`requests.${conn.rateLimitKey}`, 999) + await redis.quit() + + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -95,6 +74,7 @@ describe('Socket rate limiting', () => { message: 'Hello world' } }) - .expectClosed(1008, 'RATE_LIMIT_EXCEEDED') + await client.expectClosed(1008, 'RATE_LIMIT_EXCEEDED') + }) }) }) diff --git a/tests/socket/router.test.ts b/tests/socket/router.test.ts index 26f7a4e6..b1fa2cfa 100644 --- a/tests/socket/router.test.ts +++ b/tests/socket/router.test.ts @@ -1,36 +1,25 @@ -import request from 'superwstest' -import Socket from '../../src/socket' import createAPIKeyAndToken from '../utils/createAPIKeyAndToken' import { APIKeyScope } from '../../src/entities/api-key' -import createSocketIdentifyMessage from '../utils/requestAuthedSocket' +import createSocketIdentifyMessage from '../utils/createSocketIdentifyMessage' import { EntityManager } from '@mikro-orm/mysql' import GameChannelFactory from '../fixtures/GameChannelFactory' +import { createSocketTicket } from '../../src/services/api/socket-ticket-api.service' +import redisConfig from '../../src/config/redis.config' +import Redis from 'ioredis' +import createTestSocket from '../utils/createTestSocket' describe('Socket router', () => { - let socket: Socket - - beforeAll(() => { - socket = new Socket(global.server, global.em) - }) - - afterAll(() => { - socket.getServer().close() - }) - it('should reject invalid messages', async () => { - const [, token] = await createAPIKeyAndToken([]) + const [apiKey] = await createAPIKeyAndToken([]) + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) + await redis.quit() - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson({ blah: 'blah' }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'unknown', @@ -39,24 +28,21 @@ describe('Socket router', () => { cause: '{"blah":"blah"}' } }) - .close() + }) }) it('should reject unknown requests', async () => { - const [, token] = await createAPIKeyAndToken([]) + const [apiKey] = await createAPIKeyAndToken([]) + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) + await redis.quit() - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson({ req: 'v1.magic', data: {} }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'unknown', @@ -65,20 +51,17 @@ describe('Socket router', () => { cause: '{"req":"v1.magic","data":{}}' } }) - .close() + }) }) it('should reject requests where a player is required but one hasn\'t been identified yet', async () => { - const [, token] = await createAPIKeyAndToken([]) + const [apiKey] = await createAPIKeyAndToken([]) + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) + await redis.quit() - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -87,7 +70,7 @@ describe('Socket router', () => { message: 'Hello world' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.channels.message', @@ -95,22 +78,16 @@ describe('Socket router', () => { errorCode: 'NO_PLAYER_FOUND' } }) - .close() + }) }) it('should reject requests where a scope is required but is not present', async () => { - const [identifyMessage, token] = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) + const { identifyMessage, ticket } = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson(identifyMessage) - .expectJson() - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -119,7 +96,7 @@ describe('Socket router', () => { message: 'Hello world' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.channels.message', @@ -127,25 +104,19 @@ describe('Socket router', () => { errorCode: 'MISSING_ACCESS_KEY_SCOPES' } }) - .close() + }) }) it('should be able to accept requests where a scope is required and the key has the full access scope', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([APIKeyScope.FULL_ACCESS]) + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([APIKeyScope.FULL_ACCESS]) const channel = await new GameChannelFactory(player.game).one() channel.members.add(player.aliases[0]) await (global.em).persistAndFlush(channel) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson(identifyMessage) - .expectJson() - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -154,28 +125,25 @@ describe('Socket router', () => { message: 'Hello world' } }) - .expectJson((actual) => { + await client.expectJson((actual) => { expect(actual.res).toBe('v1.channels.message') expect(actual.data.channel.id).toBe(channel.id) expect(actual.data.message).toBe('Hello world') expect(actual.data.playerAlias.id).toBe(player.aliases[0].id) }) - .close() + }) }) it('should reject requests where the payload fails the listener\'s validation', async () => { - const [identifyMessage, token] = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS, APIKeyScope.WRITE_GAME_CHANNELS]) + const { identifyMessage, ticket } = await createSocketIdentifyMessage([ + APIKeyScope.READ_PLAYERS, + APIKeyScope.WRITE_GAME_CHANNELS + ]) - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson(identifyMessage) - .expectJson() - .sendJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + + client.sendJson({ req: 'v1.channels.message', data: { channel: { @@ -184,7 +152,7 @@ describe('Socket router', () => { myMessageToTheChannelIsGoingToBeThis: 'Hello world' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.channels.message', @@ -193,6 +161,6 @@ describe('Socket router', () => { cause: '{"channel":{"id":1},"myMessageToTheChannelIsGoingToBeThis":"Hello world"}' } }) - .close() + }) }) }) diff --git a/tests/socket/server.test.ts b/tests/socket/server.test.ts index 413f611e..efdf96f7 100644 --- a/tests/socket/server.test.ts +++ b/tests/socket/server.test.ts @@ -1,60 +1,57 @@ -import request from 'superwstest' -import Socket from '../../src/socket' import createAPIKeyAndToken from '../utils/createAPIKeyAndToken' -import { isToday, subDays } from 'date-fns' +import { Redis } from 'ioredis' +import redisConfig from '../../src/config/redis.config' +import { createSocketTicket } from '../../src/services/api/socket-ticket-api.service' +import createTestSocket from '../utils/createTestSocket' import { EntityManager } from '@mikro-orm/mysql' -import { promisify } from 'util' -import jwt from 'jsonwebtoken' describe('Socket server', () => { - let socket: Socket + it('should send a connected message when sending an auth ticket', async () => { + const [apiKey] = await createAPIKeyAndToken([]) - beforeAll(() => { - socket = new Socket(global.server, global.em) - }) + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) + await redis.quit() - afterAll(() => { - socket.getServer().close() - }) - - it('should send a connected message when sending an auth header', async () => { - const [apiKey, token] = await createAPIKeyAndToken([]) - apiKey.lastUsedAt = subDays(new Date(), 1) - await (global.em).flush() - - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.expectJsonToStrictEqual({ res: 'v1.connected', data: {} }) - .close() + }, { + waitForReady: false + }) + }) - await (global.em).refresh(apiKey) - expect(isToday(apiKey.lastUsedAt)).toBe(true) + it('should close connections without an auth ticket', async () => { + await createTestSocket('/', async (client) => { + await client.expectClosed(3000) + }, { + waitForReady: false + }) }) - it('should close connections without an auth header', async () => { - await request(global.server) - .ws('/') - .expectClosed(3000) + it('should close connections message when sending an invalid auth ticket', async () => { + await createTestSocket('/?ticket=abc123', async (client) => { + await client.expectClosed(3000) + }, { + waitForReady: false + }) }) - it('should close connections message when sending an invalid auth header', async () => { + it('should close connections where the socket ticket has a revoked api key', async () => { const [apiKey] = await createAPIKeyAndToken([]) + apiKey.revokedAt = new Date() + await (global.em).flush() - const payload = { - sub: apiKey.id, - api: true, - iat: Math.floor(new Date(apiKey.createdAt).getTime() / 1000) - } - - const token = await promisify(jwt.sign)(payload, 'not_a_real_signature') + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) + await redis.quit() - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectClosed(3000) + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.expectClosed(3000) + }, { + waitForReady: false + }) }) }) diff --git a/tests/socket/socketEvents.test.ts b/tests/socket/socketEvents.test.ts index 2b3851a5..3cd9a0cc 100644 --- a/tests/socket/socketEvents.test.ts +++ b/tests/socket/socketEvents.test.ts @@ -1,38 +1,25 @@ -import request from 'superwstest' -import Socket from '../../src/socket' import createAPIKeyAndToken from '../utils/createAPIKeyAndToken' -import { subDays } from 'date-fns' -import { EntityManager } from '@mikro-orm/mysql' import { ClickHouseClient } from '@clickhouse/client' import { ClickhouseSocketEvent } from '../../src/socket/socketEvent' -import createSocketIdentifyMessage from '../utils/requestAuthedSocket' +import createSocketIdentifyMessage from '../utils/createSocketIdentifyMessage' import { APIKeyScope } from '../../src/entities/api-key' +import redisConfig from '../../src/config/redis.config' +import { createSocketTicket } from '../../src/services/api/socket-ticket-api.service' +import Redis from 'ioredis' +import createTestSocket from '../utils/createTestSocket' describe('Socket events', () => { - let socket: Socket - beforeAll(() => { vi.stubEnv('DISABLE_SOCKET_EVENTS', '0') - socket = new Socket(global.server, global.em) - }) - - afterAll(() => { - socket.getServer().close() }) it('should track open, connected and close events', async () => { - const [apiKey, token] = await createAPIKeyAndToken([]) - apiKey.lastUsedAt = subDays(new Date(), 1) - await (global.em).flush() - - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .close() + const [apiKey] = await createAPIKeyAndToken([]) + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) + await redis.quit() + + await createTestSocket(`/?ticket=${ticket}`, async () => {}) let events: ClickhouseSocketEvent[] = [] await vi.waitUntil(async () => { @@ -66,18 +53,11 @@ describe('Socket events', () => { }) it('should track requests and responses', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) - - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson(identifyMessage) - .expectJson() - .close() + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) + + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + await client.identify(identifyMessage) + }) let events: ClickhouseSocketEvent[] = [] await vi.waitUntil(async () => { @@ -125,23 +105,17 @@ describe('Socket events', () => { }) it('should track errors', async () => { - const [identifyMessage, token, player] = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) - - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .expectJson({ - res: 'v1.connected', - data: {} - }) - .sendJson({ + const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([APIKeyScope.READ_PLAYERS]) + + await createTestSocket(`/?ticket=${ticket}`, async (client) => { + client.sendJson({ ...identifyMessage, data: { ...identifyMessage.data, socketToken: 'invalid' } }) - .expectJson({ + await client.expectJsonToStrictEqual({ res: 'v1.error', data: { req: 'v1.players.identify', @@ -149,7 +123,7 @@ describe('Socket events', () => { errorCode: 'INVALID_SOCKET_TOKEN' } }) - .close() + }) let events: ClickhouseSocketEvent[] = [] await vi.waitUntil(async () => { @@ -197,19 +171,12 @@ describe('Socket events', () => { }) it('should track dev build events', async () => { - const [apiKey, token] = await createAPIKeyAndToken([]) - apiKey.lastUsedAt = subDays(new Date(), 1) - await (global.em).flush() - - await request(global.server) - .ws('/') - .set('authorization', `Bearer ${token}`) - .set('x-talo-dev-build', '1') - .expectJson({ - res: 'v1.connected', - data: {} - }) - .close() + const [apiKey] = await createAPIKeyAndToken([]) + const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, true) + await redis.quit() + + await createTestSocket(`/?ticket=${ticket}`, async () => {}) let events: ClickhouseSocketEvent[] = [] await vi.waitUntil(async () => { diff --git a/tests/utils/requestAuthedSocket.ts b/tests/utils/createSocketIdentifyMessage.ts similarity index 68% rename from tests/utils/requestAuthedSocket.ts rename to tests/utils/createSocketIdentifyMessage.ts index 0432d852..83f3e313 100644 --- a/tests/utils/requestAuthedSocket.ts +++ b/tests/utils/createSocketIdentifyMessage.ts @@ -5,8 +5,9 @@ import createAPIKeyAndToken from './createAPIKeyAndToken' import Redis from 'ioredis' import redisConfig from '../../src/config/redis.config' import Player from '../../src/entities/player' +import { createSocketTicket } from '../../src/services/api/socket-ticket-api.service' -type IdentifyMessage = { +export type IdentifyMessage = { req: 'v1.players.identify' data: { playerAliasId: number @@ -14,24 +15,33 @@ type IdentifyMessage = { } } -export default async function createSocketIdentifyMessage(scopes: APIKeyScope[]): Promise<[IdentifyMessage, string, Player]> { +type SocketIdentifyData = { + identifyMessage: IdentifyMessage + ticket: string + player: Player + token: string +} + +export default async function createSocketIdentifyMessage(scopes: APIKeyScope[]): Promise { const [apiKey, token] = await createAPIKeyAndToken(scopes) const player = await new PlayerFactory([apiKey.game]).one() await (global.em).persistAndFlush(player) const redis = new Redis(redisConfig) + const ticket = await createSocketTicket(redis, apiKey, false) const socketToken = await player.aliases[0].createSocketToken(redis) await redis.quit() - return [ - { + return { + identifyMessage: { req: 'v1.players.identify', data: { playerAliasId: player.aliases[0].id, - socketToken: socketToken + socketToken } }, - token, - player - ] + ticket, + player, + token + } } diff --git a/tests/utils/createTestSocket.ts b/tests/utils/createTestSocket.ts new file mode 100644 index 00000000..5f8884e9 --- /dev/null +++ b/tests/utils/createTestSocket.ts @@ -0,0 +1,111 @@ + +import { createServer } from 'http' +import Socket from '../../src/socket' +import { randNumber } from '@ngneat/falso' +import { WebSocket } from 'ws' +import { IdentifyMessage } from './createSocketIdentifyMessage' + +class TestClient extends WebSocket { + private messages: string[] = [] + private closed = false + private closeCode: number + private closeReason: string + + constructor(url: string) { + super(url) + this.on('message', (message) => { + this.messages.push(message.toString()) + }) + this.on('close', (code, reason) => { + this.closed = true + this.closeCode = code + this.closeReason = reason.toString() + }) + } + + async identify(message: IdentifyMessage) { + this.sendJson(message) + + await this.expectJson((json) => { + expect(json.res).toBe('v1.players.identify.success') + expect(json.data.id).toBe(message.data.playerAliasId) + }) + } + + expectReady() { + return this.expectJsonToStrictEqual({ + res: 'v1.connected', + data: {} + }) + } + + expectClosed(code?: number, reason?: string) { + return vi.waitUntil(() => { + return this.closed && + (code !== undefined ? this.closeCode === code : true) && + (reason !== undefined ? this.closeReason === reason : true) + }) + } + + sendJson(json: T) { + this.messages = [] + this.send(JSON.stringify(json)) + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async expectJson(cb: (json: any) => void) { + try { + await vi.waitUntil(async () => { + for (const message of this.messages.reverse()) { + try { + cb(JSON.parse(message)) + return true + } catch { + continue + } + } + return false + }) + } catch (err) { + throw new Error('Message not found') + } + } + + async expectJsonToStrictEqual(json: object) { + try { + await this.expectJson((actual) => { + expect(actual).toStrictEqual(json) + }) + } catch (err) { + throw new Error(`Message not found: ${JSON.stringify(json)}`) + } + } +} + +type TestSocketOptions = { + waitForReady?: boolean +} + +export default async function createTestSocket( + url: string, + cb: (client: TestClient, wss: Socket) => Promise, + opts: TestSocketOptions = { + waitForReady: true + } +) { + const server = createServer() + const port = randNumber({ min: 999, max: 65535 }) + await new Promise((resolve) => server.listen(port, resolve)) + + const wss = new Socket(server, global.em) + global.ctx.wss = wss + + const client = new TestClient(`ws://localhost:${port}${url}`) + if (opts.waitForReady) { + await client.expectReady() + } + await cb(client, wss) + client.close() + + await new Promise((resolve) => server.close(() => resolve())) +}