diff --git a/src/throttler.guard.ts b/src/throttler.guard.ts index ba1eb2d8..f5929e28 100644 --- a/src/throttler.guard.ts +++ b/src/throttler.guard.ts @@ -1,4 +1,4 @@ -import { CanActivate, ExecutionContext, Inject, Injectable } from '@nestjs/common'; +import { CanActivate, ExecutionContext, Inject, Injectable, ContextType } from '@nestjs/common'; import { Reflector } from '@nestjs/core'; import * as md5 from 'md5'; import { ThrottlerStorage } from './throttler-storage.interface'; @@ -61,11 +61,13 @@ export class ThrottlerGuard implements CanActivate { const limit = routeOrClassLimit || this.options.limit; const ttl = routeOrClassTtl || this.options.ttl; - switch (context.getType()) { + switch (context.getType()) { case 'http': return this.httpHandler(context, limit, ttl); case 'ws': return this.websocketHandler(context, limit, ttl); + case 'graphql': + return this.graphqlHandler(context, limit, ttl); default: return true; } @@ -133,6 +135,25 @@ export class ThrottlerGuard implements CanActivate { return true; } + private async graphqlHandler( + context: ExecutionContext, + limit: number, + ttl: number, + ): Promise { + const { req, res } = context.getArgByIndex(2); + const httpContext: ExecutionContext = { + ...context, + switchToHttp: () => ({ + getRequest: () => req, + getResponse: () => res, + getNext: context.switchToHttp().getNext, + }), + getClass: context.getClass, + getHandler: context.getHandler, + }; + return this.httpHandler(httpContext, limit, ttl); + } + /** * Generate a hashed key that will be used as a storage key. * The key will always be a combination of the current context and IP. diff --git a/test/app/resolvers/app.resolver.ts b/test/app/resolvers/app.resolver.ts index c5af58fa..8bcb3c60 100644 --- a/test/app/resolvers/app.resolver.ts +++ b/test/app/resolvers/app.resolver.ts @@ -1,7 +1,9 @@ -import { Resolver, Query, Mutation } from '@nestjs/graphql'; -import { ResolveType } from './resolve.model'; +import { Mutation, Query, Resolver } from '@nestjs/graphql'; import { AppService } from '../app.service'; +import { ResolveType } from './resolve.model'; +import { Throttle } from '../../../src'; +@Throttle(2, 10) @Resolver(ResolveType) export class AppResolver { constructor(private readonly appService: AppService) {} diff --git a/test/app/resolvers/default.resolver.ts b/test/app/resolvers/default.resolver.ts index 488a34e9..9cfc33fb 100644 --- a/test/app/resolvers/default.resolver.ts +++ b/test/app/resolvers/default.resolver.ts @@ -1,5 +1,18 @@ -import { Resolver } from '@nestjs/graphql'; +import { Mutation, Query, Resolver } from '@nestjs/graphql'; +import { AppService } from '../app.service'; import { ResolveType } from './resolve.model'; @Resolver(ResolveType) -export class DefaultResolver {} +export class DefaultResolver { + constructor(private readonly appService: AppService) {} + + @Query(() => ResolveType) + defaultQuery() { + return this.appService.success(); + } + + @Mutation(() => ResolveType) + defaultMutation() { + return this.appService.success(); + } +} diff --git a/test/app/resolvers/limit.resolver.ts b/test/app/resolvers/limit.resolver.ts index 652f9027..8df1b744 100644 --- a/test/app/resolvers/limit.resolver.ts +++ b/test/app/resolvers/limit.resolver.ts @@ -1,5 +1,20 @@ -import { Resolver } from '@nestjs/graphql'; +import { Mutation, Query, Resolver } from '@nestjs/graphql'; +import { Throttle } from '../../../src'; +import { AppService } from '../app.service'; import { ResolveType } from './resolve.model'; @Resolver(ResolveType) -export class LimitResolver {} +export class LimitResolver { + constructor(private readonly appService: AppService) {} + + @Query(() => ResolveType) + limitQuery() { + return this.appService.success(); + } + + @Throttle(2, 10) + @Mutation(() => ResolveType) + limitMutation() { + return this.appService.success(); + } +} diff --git a/test/resolver.e2e-spec.ts b/test/resolver.e2e-spec.ts index c08b6ab1..d4a200b9 100644 --- a/test/resolver.e2e-spec.ts +++ b/test/resolver.e2e-spec.ts @@ -2,28 +2,39 @@ import { INestApplication } from '@nestjs/common'; import { AbstractHttpAdapter, APP_GUARD } from '@nestjs/core'; import { GraphQLModule } from '@nestjs/graphql'; import { ExpressAdapter } from '@nestjs/platform-express'; -import { FastifyAdapter } from '@nestjs/platform-fastify'; +// import { FastifyAdapter } from '@nestjs/platform-fastify'; import { Test, TestingModule } from '@nestjs/testing'; import { ThrottlerGuard } from '../src'; import { ResolverModule } from './app/resolvers/resolver.module'; import { httPromise } from './utility/httpromise'; -function queryFactory(prefix: string): Record { - return { - query: `query ${prefix}Query{ ${prefix}Query{ success }}`, - }; -} - -function mutationFactory(prefix: string): Record { - return { - query: `mutation ${prefix}Mutation{ ${prefix}Mutation{ success }}`, - }; -} +const factories = { + query: (prefix: string): Record => { + return { + query: `query ${prefix}Query{ ${prefix}Query{ success }}`, + }; + }, + mutation: (prefix: string): Record => { + return { + query: `mutation ${prefix}Mutation{ ${prefix}Mutation{ success }}`, + }; + }, + data: (prefix: string, type: string): Record => { + type = type[0].toUpperCase() + type.substring(1, type.length); + return { + data: { + [prefix + type]: { + success: true, + }, + }, + }; + }, +}; +// ${new FastifyAdapter()} | ${'Fastify'} | ${() => ({})} describe.each` adapter | adapterName | context ${new ExpressAdapter()} | ${'Express'} | ${({ req, res }) => ({ req, res })} - ${new FastifyAdapter()} | ${'Fastify'} | ${({}) => ({})} `( '$adapterName Throttler', ({ adapter, context }: { adapter: AbstractHttpAdapter; context: () => any }) => { @@ -64,32 +75,64 @@ describe.each` * Tests for setting `@Throttle()` at the method level and for ignore routes */ describe('AppResolver', () => { - it.todo('Implement AppResolver tests'); it.each` type ${'query'} ${'mutation'} `('$type', async ({ type }: { type: string }) => { - const res = await httPromise( - appUrl, - 'POST', - {}, - type === 'query' ? queryFactory('app') : mutationFactory('app'), - ); - expect(res).toEqual({ success: true }); + const res = await httPromise(appUrl, 'POST', {}, factories[type]('app')); + expect(res.data).toEqual(factories.data('app', type)); + expect(res.headers).toMatchObject({ + 'x-ratelimit-limit': '2', + 'x-ratelimit-remaining': '1', + 'x-ratelimit-reset': /\d+/, + }); }); }); /** * Tests for setting `@Throttle()` at the class level and overriding at the method level */ describe('LimitResolver', () => { - it.todo('Implement LimitResolver test'); + it.each` + type | limit + ${'query'} | ${5} + ${'mutation'} | ${2} + `('$type', async ({ type, limit }: { type: string; limit: number }) => { + for (let i = 0; i < limit; i++) { + const res = await httPromise(appUrl, 'POST', {}, factories[type]('limit')); + expect(res.data).toEqual(factories.data('limit', type)); + expect(res.headers).toMatchObject({ + 'x-ratelimit-limit': limit.toString(), + 'x-ratelimit-remaining': (limit - (i + 1)).toString(), + 'x-ratelimit-reset': /\d+/, + }); + } + const errRes = await httPromise(appUrl, 'POST', {}, factories[type]('limit')); + expect(errRes.data).not.toEqual(factories.data('limit', type)); + expect(errRes.data.errors[0].message).toBe('ThrottlerException: Too Many Requests'); + expect(errRes.headers).toMatchObject({ + 'retry-after': /\d+/, + }); + expect(errRes.status).toBe(200); + }); }); /** * Tests for setting throttle values at the `forRoot` level */ describe('DefaultResolver', () => { - it.todo('implement DefaultResolver Test'); + it.each` + type + ${'query'} + ${'mutation'} + `('$type', async ({ type }: { type: string }) => { + const res = await httPromise(appUrl, 'POST', {}, factories[type]('default')); + expect(res.data).toEqual(factories.data('default', type)); + expect(res.headers).toMatchObject({ + 'x-ratelimit-limit': '5', + 'x-ratelimit-remaining': '4', + 'x-ratelimit-reset': /\d+/, + }); + }); }); }); },