diff --git a/lib/RateLimiterRedis.js b/lib/RateLimiterRedis.js index a045d47..98d52b1 100644 --- a/lib/RateLimiterRedis.js +++ b/lib/RateLimiterRedis.js @@ -27,13 +27,14 @@ class RateLimiterRedis extends RateLimiterStoreAbstract { this.client = opts.storeClient; this._rejectIfRedisNotReady = !!opts.rejectIfRedisNotReady; + this._incrTtlLuaScript = opts.customIncrTtlLuaScript || incrTtlLuaScript; this.useRedisPackage = opts.useRedisPackage || this.client.constructor.name === 'Commander' || false; this.useRedis3AndLowerPackage = opts.useRedis3AndLowerPackage; if (typeof this.client.defineCommand === 'function') { this.client.defineCommand("rlflxIncr", { numberOfKeys: 1, - lua: incrTtlLuaScript, + lua: this._incrTtlLuaScript, }); } } @@ -105,7 +106,7 @@ class RateLimiterRedis extends RateLimiterStoreAbstract { if (secDuration > 0) { if(!this.useRedisPackage && !this.useRedis3AndLowerPackage){ return this.client.rlflxIncr( - [rlKey].concat([String(points), String(secDuration)])); + [rlKey].concat([String(points), String(secDuration), String(this.points)])); } if (this.useRedis3AndLowerPackage) { return new Promise((resolve, reject) => { @@ -118,15 +119,15 @@ class RateLimiterRedis extends RateLimiterStoreAbstract { }; if (typeof this.client.rlflxIncr === 'function') { - this.client.rlflxIncr(rlKey, points, secDuration, incrCallback); + this.client.rlflxIncr(rlKey, points, secDuration, this.points, incrCallback); } else { - this.client.eval(incrTtlLuaScript, 1, rlKey, points, secDuration, incrCallback); + this.client.eval(this._incrTtlLuaScript, 1, rlKey, points, secDuration, this.points, incrCallback); } }); } else { - return this.client.eval(incrTtlLuaScript, { + return this.client.eval(this._incrTtlLuaScript, { keys: [rlKey], - arguments: [String(points), String(secDuration)], + arguments: [String(points), String(secDuration), String(this.points)], }); } } else { diff --git a/lib/index.d.ts b/lib/index.d.ts index 8f0f7f7..9246f69 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -255,6 +255,7 @@ interface IRateLimiterRedisOptions extends IRateLimiterStoreOptions { rejectIfRedisNotReady?: boolean; useRedisPackage?: boolean; useRedis3AndLowerPackage?: boolean; + customIncrTtlLuaScript?: string; } interface ICallbackReady { diff --git a/test/RateLimiterRedis.ioredis.test.js b/test/RateLimiterRedis.ioredis.test.js index cac5a5b..247ae0d 100644 --- a/test/RateLimiterRedis.ioredis.test.js +++ b/test/RateLimiterRedis.ioredis.test.js @@ -59,6 +59,53 @@ describe('RateLimiterRedis with fixed window', function RateLimiterRedisTest() { }); }); + describe('when customIncrTtlLuaScript is provided', () => { + it('rejected when consume more than maximum points and multiply delay', (done) => { + const testKey = 'consume2'; + const rateLimiter = new RateLimiterRedis({ + storeClient: redisMockClient, + points: 1, + duration: 5, + customIncrTtlLuaScript: `local ok = redis.call('set', KEYS[1], 0, 'EX', ARGV[2], 'NX') \ + local consumed = redis.call('incrby', KEYS[1], ARGV[1]) \ + local ttl = redis.call('pttl', KEYS[1]) \ + if ttl == -1 then \ + redis.call('expire', KEYS[1], ARGV[2]) \ + ttl = 1000 * ARGV[2] \ + else \ + local maxPoints = tonumber(ARGV[3]) \ + if maxPoints > 0 and (consumed-1) % maxPoints == 0 and not ok then \ + local expireTime = ttl + tonumber(ARGV[2]) * 1000 \ + redis.call('pexpire', KEYS[1], expireTime) \ + return {consumed, expireTime} \ + end \ + end \ + return {consumed, ttl} \ + ` + }); + rateLimiter + .consume(testKey) + .then(() => { + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes) => { + expect(rejRes.msBeforeNext >= 5000).to.equal(true); + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes2) => { + expect(rejRes2.msBeforeNext >= 10000).to.equal(true); + done(); + }); + }); + }) + .catch((err) => { + done(err); + }); + }); + }); + it('execute evenly over duration', (done) => { const testKey = 'consumeEvenly'; const rateLimiter = new RateLimiterRedis({ diff --git a/test/RateLimiterRedis.redis.test.js b/test/RateLimiterRedis.redis.test.js index 7151e25..c1c5cb6 100644 --- a/test/RateLimiterRedis.redis.test.js +++ b/test/RateLimiterRedis.redis.test.js @@ -58,6 +58,54 @@ describe('RateLimiterRedis with fixed window', function RateLimiterRedisTest() { }); }); + describe('when customIncrTtlLuaScript is provided', () => { + it('rejected when consume more than maximum points and multiply delay', (done) => { + const testKey = 'consume2'; + const rateLimiter = new RateLimiterRedis({ + storeClient: redisMockClient, + points: 1, + duration: 5, + customIncrTtlLuaScript: `local ok = redis.call('set', KEYS[1], 0, 'EX', ARGV[2], 'NX') \ + local consumed = redis.call('incrby', KEYS[1], ARGV[1]) \ + local ttl = redis.call('pttl', KEYS[1]) \ + if ttl == -1 then \ + redis.call('expire', KEYS[1], ARGV[2]) \ + ttl = 1000 * ARGV[2] \ + else \ + local maxPoints = tonumber(ARGV[3]) \ + if maxPoints > 0 and (consumed-1) % maxPoints == 0 and not ok then \ + local expireTime = ttl + tonumber(ARGV[2]) * 1000 \ + redis.call('pexpire', KEYS[1], expireTime) \ + return {consumed, expireTime} \ + end \ + end \ + return {consumed, ttl} \ + `, + useRedisPackage: true, + }); + rateLimiter + .consume(testKey) + .then(() => { + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes) => { + expect(rejRes.msBeforeNext >= 5000).to.equal(true); + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes2) => { + expect(rejRes2.msBeforeNext >= 10000).to.equal(true); + done(); + }); + }); + }) + .catch((err) => { + done(err); + }); + }); + }); + it('execute evenly over duration', (done) => { const testKey = 'consumeEvenly'; const rateLimiter = new RateLimiterRedis({