Skip to content

Commit

Permalink
Merge pull request #120 from upstash/DX-1116
Browse files Browse the repository at this point in the history
Hardcode script hash values and deprecate cacheScripts
  • Loading branch information
CahidArda committed Sep 10, 2024
2 parents bd64e78 + 6477b03 commit 377f72c
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 95 deletions.
65 changes: 20 additions & 45 deletions src/hash.ts
Original file line number Diff line number Diff line change
@@ -1,63 +1,38 @@
import { ScriptInfo } from "./lua-scripts/hash";
import { Context, RegionContext } from "./types"

type ScriptKind = "limitHash" | "getRemainingHash" | "resetHash"

/**
* Loads the scripts to redises with SCRIPT LOAD if the first region context
* doesn't have the kind of script hash in it.
*
* @param ctx Regional or multi region context
* @param script script to load
* @param kind script kind
*/
const setHash = async (
ctx: Context,
script: string,
kind: ScriptKind
) => {
const regionContexts = "redis" in ctx ? [ctx] : ctx.regionContexts
const hashSample = regionContexts[0].scriptHashes[kind]
if (!hashSample) {
await Promise.all(regionContexts.map(async (context) => {
context.scriptHashes[kind] = await context.redis.scriptLoad(script)
}));
};
}

/**
* Runds the specified script with EVALSHA if ctx.cacheScripts or EVAL
* otherwise.
* Runs the specified script with EVALSHA using the scriptHash parameter.
*
* If the script is not found when EVALSHA is used, it submits the script
* with LOAD SCRIPT, then calls EVALSHA again.
* If the EVALSHA fails, loads the script to redis and runs again with the
* hash returned from Redis.
*
* @param ctx Regional or multi region context
* @param script script to run
* @param kind script kind
* @param keys
* @param args
* @param script ScriptInfo of script to run. Contains the script and its hash
* @param keys eval keys
* @param args eval args
*/
export const safeEval = async (
ctx: RegionContext,
script: string,
kind: ScriptKind,
script: ScriptInfo,
keys: any[],
args: any[],
) => {
if (!ctx.cacheScripts) {
return await ctx.redis.eval(script, keys, args);
};

await setHash(ctx, script, kind);
try {
return await ctx.redis.evalsha(ctx.scriptHashes[kind]!, keys, args)
return await ctx.redis.evalsha(script.hash, keys, args)
} catch (error) {
if (`${error}`.includes("NOSCRIPT")) {
console.log("Script with the expected hash was not found in redis db. It is probably flushed. Will load another scipt before continuing.");
ctx.scriptHashes[kind] = undefined;
await setHash(ctx, script, kind)
console.log(" New script successfully loaded.")
return await ctx.redis.evalsha(ctx.scriptHashes[kind]!, keys, args)
const hash = await ctx.redis.scriptLoad(script.script)

if (hash !== script.hash) {
console.warn(
"Upstash Ratelimit: Expected hash and the hash received from Redis"
+ " are different. Ratelimit will work as usual but performance will"
+ " be reduced."
);
}

return await ctx.redis.evalsha(hash, keys, args)
}
throw error;
}
Expand Down
32 changes: 32 additions & 0 deletions src/lua-scripts/hash.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { Redis } from "@upstash/redis";
import { describe, expect, test } from "bun:test";
import { RESET_SCRIPT, SCRIPTS } from "./hash";

describe("should use correct hash for lua scripts", () => {
const redis = Redis.fromEnv();

const validateHash = async (script: string, expectedHash: string) => {
const hash = await redis.scriptLoad(script)
expect(hash).toBe(expectedHash)
}

const algorithms = [
...Object.entries(SCRIPTS.singleRegion), ...Object.entries(SCRIPTS.multiRegion)
]

// for each algorithm (fixedWindow, slidingWindow etc)
for (const [algorithm, scripts] of algorithms) {
describe(`${algorithm}`, () => {
// for each method (limit & getRemaining)
for (const [method, scriptInfo] of Object.entries(scripts)) {
test(method, async () => {
await validateHash(scriptInfo.script, scriptInfo.hash)
})
}
})
}

test("reset script", async () => {
await validateHash(RESET_SCRIPT.script, RESET_SCRIPT.hash)
})
})
95 changes: 95 additions & 0 deletions src/lua-scripts/hash.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import * as Single from "./single"
import * as Multi from "./multi"
import { resetScript } from "./reset"

export type ScriptInfo = {
script: string,
hash: string
}

type Algorithm = {
limit: ScriptInfo,
getRemaining: ScriptInfo,
}

type AlgorithmKind =
| "fixedWindow"
| "slidingWindow"
| "tokenBucket"
| "cachedFixedWindow"

export const SCRIPTS: {
singleRegion: Record<AlgorithmKind, Algorithm>,
multiRegion: Record<Exclude<AlgorithmKind, "tokenBucket" | "cachedFixedWindow">, Algorithm>,
} = {
singleRegion: {
fixedWindow: {
limit: {
script: Single.fixedWindowLimitScript,
hash: "b13943e359636db027ad280f1def143f02158c13"
},
getRemaining: {
script: Single.fixedWindowRemainingTokensScript,
hash: "8c4c341934502aee132643ffbe58ead3450e5208"
},
},
slidingWindow: {
limit: {
script: Single.slidingWindowLimitScript,
hash: "e1391e429b699c780eb0480350cd5b7280fd9213"
},
getRemaining: {
script: Single.slidingWindowRemainingTokensScript,
hash: "65a73ac5a05bf9712903bc304b77268980c1c417"
},
},
tokenBucket: {
limit: {
script: Single.tokenBucketLimitScript,
hash: "5bece90aeef8189a8cfd28995b479529e270b3c6"
},
getRemaining: {
script: Single.tokenBucketRemainingTokensScript,
hash: "a15be2bb1db2a15f7c82db06146f9d08983900d0"
},
},
cachedFixedWindow: {
limit: {
script: Single.cachedFixedWindowLimitScript,
hash: "c26b12703dd137939b9a69a3a9b18e906a2d940f"
},
getRemaining: {
script: Single.cachedFixedWindowRemainingTokenScript,
hash: "8e8f222ccae68b595ee6e3f3bf2199629a62b91a"
},
}
},
multiRegion: {
fixedWindow: {
limit: {
script: Multi.fixedWindowLimitScript,
hash: "a8c14f3835aa87bd70e5e2116081b81664abcf5c"
},
getRemaining: {
script: Multi.fixedWindowRemainingTokensScript,
hash: "8ab8322d0ed5fe5ac8eb08f0c2e4557f1b4816fd"
},
},
slidingWindow: {
limit: {
script: Multi.slidingWindowLimitScript,
hash: "cb4fdc2575056df7c6d422764df0de3a08d6753b"
},
getRemaining: {
script: Multi.slidingWindowRemainingTokensScript,
hash: "558c9306b7ec54abb50747fe0b17e5d44bd24868"
},
},
}
}

/** COMMON */
export const RESET_SCRIPT: ScriptInfo = {
script: resetScript,
hash: "54bd274ddc59fb3be0f42deee2f64322a10e2b50"
}
21 changes: 7 additions & 14 deletions src/multi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Cache } from "./cache";
import type { Duration } from "./duration";
import { ms } from "./duration";
import { safeEval } from "./hash";
import { RESET_SCRIPT, SCRIPTS } from "./lua-scripts/hash";
import {
fixedWindowLimitScript,
fixedWindowRemainingTokensScript,
Expand Down Expand Up @@ -115,8 +116,6 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
ctx: {
regionContexts: config.redis.map(redis => ({
redis: redis,
scriptHashes: {},
cacheScripts: config.cacheScripts ?? true,
})),
cache: config.ephemeralCache ? new Cache(config.ephemeralCache) : undefined,
},
Expand Down Expand Up @@ -178,8 +177,7 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
redis: regionContext.redis,
request: safeEval(
regionContext,
fixedWindowLimitScript,
"limitHash",
SCRIPTS.multiRegion.fixedWindow.limit,
[key],
[requestId, windowDuration, incrementBy],
) as Promise<string[]>,
Expand Down Expand Up @@ -284,8 +282,7 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
redis: regionContext.redis,
request: safeEval(
regionContext,
fixedWindowRemainingTokensScript,
"getRemainingHash",
SCRIPTS.multiRegion.fixedWindow.getRemaining,
[key],
[null]
) as Promise<string[]>,
Expand Down Expand Up @@ -316,8 +313,7 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
await Promise.all(ctx.regionContexts.map((regionContext) => {
safeEval(
regionContext,
resetScript,
"resetHash",
RESET_SCRIPT,
[pattern],
[null]
);
Expand Down Expand Up @@ -385,8 +381,7 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
redis: regionContext.redis,
request: safeEval(
regionContext,
slidingWindowLimitScript,
"limitHash",
SCRIPTS.multiRegion.slidingWindow.limit,
[currentKey, previousKey],
[tokens, now, windowDuration, requestId, incrementBy],
// lua seems to return `1` for true and `null` for false
Expand Down Expand Up @@ -508,8 +503,7 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
redis: regionContext.redis,
request: safeEval(
regionContext,
slidingWindowRemainingTokensScript,
"getRemainingHash",
SCRIPTS.multiRegion.slidingWindow.getRemaining,
[currentKey, previousKey],
[now, windowSize],
// lua seems to return `1` for true and `null` for false
Expand All @@ -532,8 +526,7 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
await Promise.all(ctx.regionContexts.map((regionContext) => {
safeEval(
regionContext,
resetScript,
"resetHash",
RESET_SCRIPT,
[pattern],
[null]
);
Expand Down
Loading

0 comments on commit 377f72c

Please sign in to comment.