diff --git a/.changeset/metal-snails-prove.md b/.changeset/metal-snails-prove.md new file mode 100644 index 00000000..d1f972c1 --- /dev/null +++ b/.changeset/metal-snails-prove.md @@ -0,0 +1,5 @@ +--- +"nostream": patch +--- + +fix: resolve TOCTOU race condition and key collisions in SlidingWindowRateLimiter diff --git a/src/@types/adapters.ts b/src/@types/adapters.ts index 130d7853..3b6077eb 100644 --- a/src/@types/adapters.ts +++ b/src/@types/adapters.ts @@ -25,8 +25,11 @@ export interface ICacheAdapter { removeRangeByScoreFromSortedSet(key: string, min: number, max: number): Promise getRangeFromSortedSet(key: string, start: number, stop: number): Promise setKeyExpiry(key: string, expiry: number): Promise + deleteKey(key: string): Promise getHKey(key: string, field: string): Promise setHKey(key: string, fields: Record): Promise + + eval(script: string, keys: string[], args: string[]): Promise } diff --git a/src/adapters/redis-adapter.ts b/src/adapters/redis-adapter.ts index 3b8e062f..0203d02b 100644 --- a/src/adapters/redis-adapter.ts +++ b/src/adapters/redis-adapter.ts @@ -96,6 +96,7 @@ export class RedisAdapter implements ICacheAdapter { return this.client.zAdd(key, members) } + public async deleteKey(key: string): Promise { await this.connection logger('delete %s key', key) @@ -123,4 +124,5 @@ export class RedisAdapter implements ICacheAdapter { return await this.client.evalSha(this.scriptShas.get(script)!, { keys, arguments: args }) } + } diff --git a/src/utils/sliding-window-rate-limiter.ts b/src/utils/sliding-window-rate-limiter.ts index 44d94432..c91efae4 100644 --- a/src/utils/sliding-window-rate-limiter.ts +++ b/src/utils/sliding-window-rate-limiter.ts @@ -4,24 +4,67 @@ import { ICacheAdapter } from '../@types/adapters' const logger = createLogger('sliding-window-rate-limiter') +const SLIDING_WINDOW_RATE_LIMITER_LUA_SCRIPT = ` + local key = KEYS[1] + local timestamp = tonumber(ARGV[1]) + local period = tonumber(ARGV[2]) + local step = tonumber(ARGV[3]) + local max_rate = tonumber(ARGV[4]) + + local windowStart = timestamp - period + + redis.call('ZREMRANGEBYSCORE', key, 0, windowStart) + + local entries = redis.call('ZRANGE', key, 0, -1) + local hits = 0 + for i=1, #entries do + local step_str = string.match(entries[i], "^[^:]+:([^:]+)") + if step_str then + local entry_step = tonumber(step_str) + if entry_step then + hits = hits + entry_step + end + end + end + + if hits + step > max_rate then + return 1 + end + + local base_member = timestamp .. ':' .. step + local member = base_member + local counter = 0 + while redis.call('ZSCORE', key, member) do + counter = counter + 1 + member = base_member .. ':' .. counter + end + + redis.call('ZADD', key, timestamp, member) + redis.call('PEXPIRE', key, period) + + return 0 +` + export class SlidingWindowRateLimiter implements IRateLimiter { - public constructor(private readonly cache: ICacheAdapter) {} + public constructor( + private readonly cache: ICacheAdapter, + ) { } public async hit(key: string, step: number, options: IRateLimiterOptions): Promise { const timestamp = Date.now() - const { period } = options + const { period, rate } = options - const [, , entries] = await Promise.all([ - this.cache.removeRangeByScoreFromSortedSet(key, 0, timestamp - period), - this.cache.addToSortedSet(key, { [`${timestamp}:${step}`]: timestamp.toString() }), - this.cache.getRangeFromSortedSet(key, 0, -1), - this.cache.setKeyExpiry(key, period), + const result = await this.cache.eval(SLIDING_WINDOW_RATE_LIMITER_LUA_SCRIPT, [key], [ + timestamp.toString(), + period.toString(), + step.toString(), + rate.toString(), ]) - const hits = entries.reduce((acc, timestampAndStep) => acc + Number(timestampAndStep.split(':')[1]), 0) + const isRateLimited = result === 1 || result === '1' - logger('hit count on %s bucket: %d', key, hits) + logger('hit on %s bucket: is rate limited? %s', key, isRateLimited) - return hits > options.rate + return isRateLimited } } diff --git a/test/unit/utils/sliding-window-rate-limiter.spec.ts b/test/unit/utils/sliding-window-rate-limiter.spec.ts index 87cb75a4..a864a93d 100644 --- a/test/unit/utils/sliding-window-rate-limiter.spec.ts +++ b/test/unit/utils/sliding-window-rate-limiter.spec.ts @@ -17,6 +17,7 @@ describe('SlidingWindowRateLimiter', () => { let getKeyStub: Sinon.SinonStub let hasKeyStub: Sinon.SinonStub let setKeyStub: Sinon.SinonStub + let evalStub: Sinon.SinonStub let sandbox: Sinon.SinonSandbox @@ -30,6 +31,7 @@ describe('SlidingWindowRateLimiter', () => { getKeyStub = sandbox.stub() hasKeyStub = sandbox.stub() setKeyStub = sandbox.stub() + evalStub = sandbox.stub() cache = { removeRangeByScoreFromSortedSet: removeRangeByScoreFromSortedSetStub, addToSortedSet: addToSortedSetStub, @@ -38,7 +40,10 @@ describe('SlidingWindowRateLimiter', () => { getKey: getKeyStub, hasKey: hasKeyStub, setKey: setKeyStub, + eval: evalStub, } as unknown as ICacheAdapter + + rateLimiter = new SlidingWindowRateLimiter(cache) }) @@ -48,20 +53,32 @@ describe('SlidingWindowRateLimiter', () => { }) it('returns true if rate limited', async () => { - const now = Date.now() - getRangeFromSortedSetStub.resolves([`${now}:6`, `${now}:4`, `${now}:1`]) + evalStub.resolves(1) const actualResult = await rateLimiter.hit('key', 1, { period: 60000, rate: 10 }) expect(actualResult).to.be.true + expect(evalStub).to.have.been.calledOnce + const args = evalStub.firstCall.args + expect(args[1]).to.deep.equal(['key']) + expect(args[2][1]).to.equal('60000') // period + expect(args[2][2]).to.equal('1') // step + expect(args[2][3]).to.equal('10') // max_rate }) it('returns false if not rate limited', async () => { - const now = Date.now() - getRangeFromSortedSetStub.resolves([`${now}:10`]) + evalStub.resolves(0) const actualResult = await rateLimiter.hit('key', 1, { period: 60000, rate: 10 }) expect(actualResult).to.be.false }) + + it('robustly handles string return types from Redis', async () => { + evalStub.resolves('1') + + const actualResult = await rateLimiter.hit('key', 1, { period: 60000, rate: 10 }) + + expect(actualResult).to.be.true + }) })