diff --git a/src/lib/rate-limiter.test.ts b/src/lib/rate-limiter.test.ts new file mode 100644 index 0000000..14a0246 --- /dev/null +++ b/src/lib/rate-limiter.test.ts @@ -0,0 +1,56 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import { RateLimiter } from './rate-limiter' + +describe('RateLimiter.throttle', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.setSystemTime(new Date(0)) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('returns null when the window is saturated for medium priority', async () => { + const limiter = new RateLimiter({ + maxRequests: 1, + windowMs: 1000, + retryDelay: 10, + maxRetries: 2 + }) + const fn = vi.fn(async () => 'ok') + + await limiter.throttle('key', fn, 'medium') + const result = await limiter.throttle('key', fn, 'medium') + + expect(result).toBeNull() + expect(fn).toHaveBeenCalledTimes(1) + }) + + it('bounds high-priority retries without recursion when the window is saturated', async () => { + const limiter = new RateLimiter({ + maxRequests: 1, + windowMs: 1000, + retryDelay: 10, + maxRetries: 3 + }) + const fn = vi.fn(async () => 'ok') + + await limiter.throttle('key', fn, 'high') + + const spy = vi.spyOn(limiter, 'throttle') + let resolved: unknown = 'pending' + const pending = limiter.throttle('key', fn, 'high').then(result => { + resolved = result + return result + }) + + await vi.advanceTimersByTimeAsync(30) + await pending + + expect(resolved).toBeNull() + expect(fn).toHaveBeenCalledTimes(1) + expect(spy).toHaveBeenCalledTimes(1) + }) +}) diff --git a/src/lib/rate-limiter.ts b/src/lib/rate-limiter.ts index 3f17ebe..e761def 100644 --- a/src/lib/rate-limiter.ts +++ b/src/lib/rate-limiter.ts @@ -2,6 +2,7 @@ interface RateLimitConfig { maxRequests: number windowMs: number retryDelay: number + maxRetries?: number } interface RequestRecord { @@ -9,14 +10,15 @@ interface RequestRecord { count: number } -class RateLimiter { +export class RateLimiter { private requests: Map = new Map() private config: RateLimitConfig constructor(config: RateLimitConfig = { maxRequests: 5, windowMs: 60000, - retryDelay: 2000 + retryDelay: 2000, + maxRetries: 3 }) { this.config = config } @@ -26,13 +28,13 @@ class RateLimiter { fn: () => Promise, priority: 'low' | 'medium' | 'high' = 'medium' ): Promise { - const maxHighPriorityRetries = 5 - let retryCount = 0 - let record: RequestRecord | undefined + const maxRetries = this.config.maxRetries ?? 3 + let attempts = 0 while (true) { const now = Date.now() - record = this.requests.get(key) + const record = this.requests.get(key) + let isLimited = false if (record) { const timeElapsed = now - record.timestamp @@ -40,17 +42,10 @@ class RateLimiter { if (timeElapsed < this.config.windowMs) { if (record.count >= this.config.maxRequests) { console.warn(`Rate limit exceeded for ${key}. Try again in ${Math.ceil((this.config.windowMs - timeElapsed) / 1000)}s`) - - if (priority === 'high' && retryCount < maxHighPriorityRetries) { - retryCount++ - await new Promise(resolve => setTimeout(resolve, this.config.retryDelay)) - continue - } - - return null + isLimited = true + } else { + record.count++ } - - record.count++ } else { this.requests.set(key, { timestamp: now, count: 1 }) } @@ -58,26 +53,35 @@ class RateLimiter { this.requests.set(key, { timestamp: now, count: 1 }) } - break - } + this.cleanup() - this.cleanup() - - try { - return await fn() - } catch (error) { - if (error instanceof Error && ( - error.message.includes('502') || - error.message.includes('Bad Gateway') || - error.message.includes('429') || - error.message.includes('rate limit') - )) { - console.error(`Gateway error for ${key}:`, error.message) - if (record) { - record.count = this.config.maxRequests + if (isLimited) { + if (priority === 'high' && attempts < maxRetries) { + attempts += 1 + await new Promise(resolve => setTimeout(resolve, this.config.retryDelay)) + continue } + + return null + } + + try { + return await fn() + } catch (error) { + if (error instanceof Error && ( + error.message.includes('502') || + error.message.includes('Bad Gateway') || + error.message.includes('429') || + error.message.includes('rate limit') + )) { + console.error(`Gateway error for ${key}:`, error.message) + const updatedRecord = this.requests.get(key) + if (updatedRecord) { + updatedRecord.count = this.config.maxRequests + } + } + throw error } - throw error } }