Skip to content

Commit

Permalink
Add input validation for all polices and additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
neelp03 committed Oct 16, 2024
1 parent 8aac6cc commit d5b69fc
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 58 deletions.
13 changes: 12 additions & 1 deletion ratelimiter/sliding_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ func (l *SlidingWindowLimiter) getMutex(key string) *keyMutex {

// Allow checks whether a request associated with the given key is allowed.
func (l *SlidingWindowLimiter) Allow(key string) (bool, error) {
// Input validation
if key == "" {
return false, errors.New("invalid key: key cannot be empty")
}
if len(key) > 256 {
return false, errors.New("invalid key: key length exceeds maximum allowed length")
}

Check warning on line 63 in ratelimiter/sliding_window.go

View check run for this annotation

Codecov / codecov/patch

ratelimiter/sliding_window.go#L62-L63

Added lines #L62 - L63 were not covered by tests
if !validKeyRegex.MatchString(key) {
return false, errors.New("invalid key: key contains invalid characters")
}

km := l.getMutex(key)
km.mu.Lock()
defer km.mu.Unlock()
Expand All @@ -75,10 +86,10 @@ func (l *SlidingWindowLimiter) Allow(key string) (bool, error) {
}

allowed := count <= int64(l.limit)

return allowed, nil
}


// startMutexCleanup runs a background goroutine to clean up unused mutexes.
func (l *SlidingWindowLimiter) startMutexCleanup() {
l.cleanupTicker = time.NewTicker(l.cleanupInterval)
Expand Down
132 changes: 78 additions & 54 deletions ratelimiter/sliding_window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,39 @@ import (
"github.com/neelp03/throttlex/store"
)

// TestSlidingWindowLimiter tests the SlidingWindowLimiter with various scenarios.
func TestSlidingWindowLimiter(t *testing.T) {
// Initialize the MemoryStore
// TestSlidingWindowLimiterInvalidKeys checks edge cases with invalid key inputs.
func TestSlidingWindowLimiterInvalidKeys(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewSlidingWindowLimiter(memStore, 5, time.Second*1)
if err != nil {
t.Errorf("Failed to create rate limiter: %v", err)
t.Fatalf("Failed to create SlidingWindowLimiter: %v", err)
}
key := "user1"

// Simulate 5 allowed requests
// Empty key
allowed, err := limiter.Allow("")
if err == nil || allowed {
t.Error("Expected error or disallowed access for empty key")
}

// Invalid key format
invalidKey := "invalid!key@format"
allowed, err = limiter.Allow(invalidKey)
if err == nil || allowed {
t.Error("Expected error or disallowed access for invalid key format")
}
}

// TestSlidingWindowLimiterHighFrequency tests frequent requests within the same window.
func TestSlidingWindowLimiterHighFrequency(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewSlidingWindowLimiter(memStore, 5, time.Second*2)
if err != nil {
t.Fatalf("Failed to create SlidingWindowLimiter: %v", err)
}

key := "highFrequencyUser"

// Make 5 requests quickly
for i := 0; i < 5; i++ {
allowed, err := limiter.Allow(key)
if err != nil {
Expand All @@ -28,85 +50,87 @@ func TestSlidingWindowLimiter(t *testing.T) {
}
}

// 6th request should be blocked
// All subsequent requests should be blocked until window partially resets
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error on 6th request: %v", err)
t.Errorf("Unexpected error on blocked request: %v", err)
}
if allowed {
t.Errorf("6th request should not be allowed")
t.Error("Request should be blocked as the rate limit has been reached")
}

// Wait for half the window to pass
time.Sleep(time.Millisecond * 500)
// Wait for half of the window duration to pass
time.Sleep(time.Second)

// 7th request should still be blocked
// Next request should still be blocked
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error on 7th request: %v", err)
t.Errorf("Unexpected error on partially reset window: %v", err)
}
if allowed {
t.Errorf("7th request should not be allowed")
t.Error("Request should be blocked, as the partial window reset is not complete")
}

// Wait for the window to expire
time.Sleep(time.Millisecond * 600)
// Wait for the rest of the window to expire
time.Sleep(time.Second)

// Next request should be allowed after window resets
// Next request should be allowed after full window reset
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error after window reset: %v", err)
}
if !allowed {
t.Errorf("Request after window reset should be allowed")
t.Error("Request after window reset should be allowed")
}
}

// TestSlidingWindowLimiterEdgeCases checks edge cases for invalid parameters.
func TestSlidingWindowLimiterEdgeCases(t *testing.T) {
// TestSlidingWindowLimiterVariableRequests simulates requests at different intervals.
func TestSlidingWindowLimiterVariableRequests(t *testing.T) {
memStore := store.NewMemoryStore()

// Test with negative limit
_, err := NewSlidingWindowLimiter(memStore, -5, time.Second*1)
if err == nil {
t.Error("Expected error with negative limit, but got none")
limiter, err := NewSlidingWindowLimiter(memStore, 3, time.Second*2)
if err != nil {
t.Fatalf("Failed to create SlidingWindowLimiter: %v", err)
}

// Test with zero window duration
_, err = NewSlidingWindowLimiter(memStore, 5, 0)
if err == nil {
t.Error("Expected error with zero window duration, but got none")
}
}
key := "variableUser"

// TestSlidingWindowLimiterMultipleClients simulates rate limiting for multiple clients.
func TestSlidingWindowLimiterMultipleClients(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewSlidingWindowLimiter(memStore, 3, time.Second*1)
// First request - should be allowed
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Failed to create rate limiter: %v", err)
}

// Simulate requests for multiple clients
keys := []string{"client1", "client2", "client3"}
for _, key := range keys {
for i := 0; i < 3; i++ {
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error for key %s on request %d: %v", key, i+1, err)
}
if !allowed {
t.Errorf("Request %d for key %s should be allowed", i+1, key)
}
}
t.Errorf("Unexpected error on 1st request: %v", err)
}
if !allowed {
t.Error("1st request should be allowed")
}

// 4th request should be blocked for each key
// Wait a short time and make two more requests within the limit
time.Sleep(time.Millisecond * 500)
for i := 0; i < 2; i++ {
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error for key %s on 4th request: %v", key, err)
t.Errorf("Unexpected error on request %d: %v", i+2, err)
}
if allowed {
t.Errorf("4th request for key %s should not be allowed", key)
if !allowed {
t.Errorf("Request %d should be allowed", i+2)
}
}

// 4th request should be blocked
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error on blocked request: %v", err)
}
if allowed {
t.Error("4th request should be blocked, limit reached")
}

// Wait for full window duration and reattempt
time.Sleep(time.Second * 2)
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error after full window reset: %v", err)
}
if !allowed {
t.Error("Request after window reset should be allowed")
}
}
40 changes: 37 additions & 3 deletions ratelimiter/token_bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ import (

// TestTokenBucketLimiter tests the TokenBucketLimiter using the MemoryStore.
func TestTokenBucketLimiter(t *testing.T) {
// Initialize the MemoryStore
memStore := store.NewMemoryStore()
capacity := 5.0 // Maximum of 5 tokens
refillRate := 1.0 // Refill 1 token per second
limiter, err := NewTokenBucketLimiter(memStore, capacity, refillRate)
if err != nil {
t.Errorf("Failed to create rate limiter: %v", err)
t.Fatalf("Failed to create rate limiter: %v", err)
}
key := "user1"

Expand Down Expand Up @@ -95,7 +94,7 @@ func TestTokenBucketLimiterRefill(t *testing.T) {
refillRate := 1.0 // 1 token per second
limiter, err := NewTokenBucketLimiter(memStore, capacity, refillRate)
if err != nil {
t.Errorf("Failed to create rate limiter: %v", err)
t.Fatalf("Failed to create rate limiter: %v", err)
}
key := "user2"

Expand Down Expand Up @@ -140,3 +139,38 @@ func TestTokenBucketLimiterRefill(t *testing.T) {
t.Errorf("Request should be blocked after single token is consumed")
}
}

// TestTokenBucketLimiterMultipleClients tests multiple clients with separate token buckets.
func TestTokenBucketLimiterMultipleClients(t *testing.T) {
memStore := store.NewMemoryStore()
capacity := 2.0 // Maximum of 2 tokens per client
refillRate := 1.0 // Refill 1 token per second per client
limiter, err := NewTokenBucketLimiter(memStore, capacity, refillRate)
if err != nil {
t.Fatalf("Failed to create rate limiter: %v", err)
}

clients := []string{"client1", "client2", "client3"}

// Each client consumes 2 tokens initially
for _, client := range clients {
for i := 0; i < 2; i++ {
allowed, err := limiter.Allow(client)
if err != nil {
t.Errorf("Unexpected error for client %s on request %d: %v", client, i+1, err)
}
if !allowed {
t.Errorf("Request %d for client %s should be allowed", i+1, client)
}
}

// The next request for each client should be blocked
allowed, err := limiter.Allow(client)
if err != nil {
t.Errorf("Unexpected error for client %s on blocked request: %v", client, err)
}
if allowed {
t.Errorf("Request exceeding capacity should not be allowed for client %s", client)
}
}
}

0 comments on commit d5b69fc

Please sign in to comment.