Skip to content

Commit

Permalink
Merge pull request #8 from neelp03/tests/input-validation
Browse files Browse the repository at this point in the history
Add Input Validation and Additional Tests for Coverage
  • Loading branch information
neelp03 authored Oct 16, 2024
2 parents 6916bd4 + d5b69fc commit cf95d33
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 63 deletions.
29 changes: 24 additions & 5 deletions ratelimiter/fixed_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"strconv"
"time"
"regexp"

"github.com/neelp03/throttlex/store"
)
Expand Down Expand Up @@ -46,6 +47,10 @@ func NewFixedWindowLimiter(store store.Store, limit int, window time.Duration) (
}, nil
}

// validKeyRegex is a compiled regular expression that matches valid keys.
// A valid key consists of alphanumeric characters, periods, underscores, and hyphens.
var validKeyRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)

// Allow checks whether a request associated with the given key is allowed under the rate limit.
// It increments the count for the current window and determines if the request should be allowed.
//
Expand All @@ -55,17 +60,31 @@ func NewFixedWindowLimiter(store store.Store, limit int, window time.Duration) (
// Returns:
// - allowed: A boolean indicating whether the request is allowed (true) or should be rate-limited (false)
// - err: An error if there was a problem accessing the storage backend
func (l *FixedWindowLimiter) Allow(key string) (allowed bool, err error) {
// Generate a key for the current window
windowKey := l.getWindowKey(key)
func (l *FixedWindowLimiter) Allow(key string) (bool, error) {
// Input validation

// Increment the counter in the storage backend
// Check for empty key
if key == "" {
return false, errors.New("invalid key: key cannot be empty")
}

// Check for overly long key (256 characters max)
if len(key) > 256 {
return false, errors.New("invalid key: key length exceeds maximum allowed length")
}

// Check for valid key format (alphanumeric, ".", "_", "-")
if !validKeyRegex.MatchString(key) {
return false, errors.New("invalid key: key contains invalid characters")
}

// Proceed with rate limiting if input validation passes
windowKey := l.getWindowKey(key)
count, err := l.store.Increment(windowKey, l.window)
if err != nil {
return false, err
}

// Determine if the count exceeds the limit
if count > int64(l.limit) {
return false, nil // Rate limit exceeded
}
Expand Down
124 changes: 124 additions & 0 deletions ratelimiter/fixed_window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,127 @@ func TestFixedWindowLimiter(t *testing.T) {
t.Error("Expected error when window duration is set to zero")
}
}

func TestFixedWindowLimiter_MultipleIPs(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewFixedWindowLimiter(memStore, 10, time.Second*5)
if err != nil {
t.Fatalf("Failed to create FixedWindowLimiter: %v", err)
}

// Define a set of IPs to test
ips := []string{
"192.168.1.1", "192.168.1.2", "192.168.1.3",
"192.168.1.4", "192.168.1.5", "192.168.1.6",
"192.168.1.7", "192.168.1.8", "192.168.1.9",
"192.168.1.10",
}

// Allow 10 requests for each IP and ensure each IP is allowed up to the limit
for _, ip := range ips {
for i := 0; i < 10; i++ {
allowed, err := limiter.Allow(ip)
if err != nil {
t.Errorf("Unexpected error for IP %s: %v", ip, err)
}
if !allowed {
t.Errorf("Request %d should be allowed for IP %s", i+1, ip)
}
}
// 11th request should be blocked for each IP
allowed, err := limiter.Allow(ip)
if err != nil {
t.Errorf("Unexpected error for IP %s: %v", ip, err)
}
if allowed {
t.Errorf("11th request should not be allowed for IP %s", ip)
}
}

// Wait for the window to expire, then check again
time.Sleep(time.Second * 5)
for _, ip := range ips {
allowed, err := limiter.Allow(ip)
if err != nil {
t.Errorf("Unexpected error after window reset for IP %s: %v", ip, err)
}
if !allowed {
t.Errorf("Request after window reset should be allowed for IP %s", ip)
}
}

// Edge cases:
// 1. Check behavior with an empty IP string
allowed, err := limiter.Allow("")
if err == nil || allowed {
t.Errorf("Expected error or disallowed access for empty IP")
}

// 2. Check behavior with non-standard IP format
invalidIP := "invalidIP"
allowed, err = limiter.Allow(invalidIP)
if err != nil {
t.Errorf("Unexpected error for invalid IP %s: %v", invalidIP, err)
}
if !allowed {
t.Errorf("Expected first request to be allowed for invalid IP format %s", invalidIP)
}

// 3. Check large number of requests for a single IP beyond the threshold
singleIP := "192.168.2.1"
for i := 0; i < 20; i++ {
allowed, err := limiter.Allow(singleIP)
if err != nil {
t.Errorf("Unexpected error for IP %s: %v", singleIP, err)
}
if i < 10 && !allowed {
t.Errorf("Request %d should be allowed for IP %s", i+1, singleIP)
}
if i >= 10 && allowed {
t.Errorf("Request %d should be denied for IP %s after limit is reached", i+1, singleIP)
}
}
}

func TestFixedWindowLimiter_InvalidKey(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewFixedWindowLimiter(memStore, 5, time.Second*1)
if err != nil {
t.Fatalf("Failed to create FixedWindowLimiter: %v", err)
}

// Test cases for invalid keys
testCases := []struct {
key string
name string
wantErr bool
}{
{"", "Empty Key", true},
{"invalidKey!", "Invalid Format Key", true},
{
"thisisaverylongkeythatshouldnotbeallowedbecauseitexceedsthe256characterlimitandweareusingitjustfortestpurposesbutintheenditshoulddefinitelytriggeranerrorsoheregoesthelongstringwiththemaximumlengthofcharactersallowedbytheapplicationaddedextracharacterstomakeitsureitisover256characters",
"Overly Long Key", true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
allowed, err := limiter.Allow(tc.key)
if tc.wantErr {
if err == nil {
t.Errorf("Expected error for %s but got none", tc.name)
}
if allowed {
t.Errorf("Expected disallowed access for %s, but got allowed", tc.name)
}
} else {
if err != nil {
t.Errorf("Unexpected error for %s: %v", tc.name, err)
}
if !allowed {
t.Errorf("Expected allowed access for %s, but got disallowed", tc.name)
}
}
})
}
}
13 changes: 12 additions & 1 deletion ratelimiter/sliding_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ func (l *SlidingWindowLimiter) getMutex(key string) *keyMutex {

// Allow checks whether a request associated with the given key is allowed.
func (l *SlidingWindowLimiter) Allow(key string) (bool, error) {
// Input validation
if key == "" {
return false, errors.New("invalid key: key cannot be empty")
}
if len(key) > 256 {
return false, errors.New("invalid key: key length exceeds maximum allowed length")
}
if !validKeyRegex.MatchString(key) {
return false, errors.New("invalid key: key contains invalid characters")
}

km := l.getMutex(key)
km.mu.Lock()
defer km.mu.Unlock()
Expand All @@ -75,10 +86,10 @@ func (l *SlidingWindowLimiter) Allow(key string) (bool, error) {
}

allowed := count <= int64(l.limit)

return allowed, nil
}


// startMutexCleanup runs a background goroutine to clean up unused mutexes.
func (l *SlidingWindowLimiter) startMutexCleanup() {
l.cleanupTicker = time.NewTicker(l.cleanupInterval)
Expand Down
132 changes: 78 additions & 54 deletions ratelimiter/sliding_window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,39 @@ import (
"github.com/neelp03/throttlex/store"
)

// TestSlidingWindowLimiter tests the SlidingWindowLimiter with various scenarios.
func TestSlidingWindowLimiter(t *testing.T) {
// Initialize the MemoryStore
// TestSlidingWindowLimiterInvalidKeys checks edge cases with invalid key inputs.
func TestSlidingWindowLimiterInvalidKeys(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewSlidingWindowLimiter(memStore, 5, time.Second*1)
if err != nil {
t.Errorf("Failed to create rate limiter: %v", err)
t.Fatalf("Failed to create SlidingWindowLimiter: %v", err)
}
key := "user1"

// Simulate 5 allowed requests
// Empty key
allowed, err := limiter.Allow("")
if err == nil || allowed {
t.Error("Expected error or disallowed access for empty key")
}

// Invalid key format
invalidKey := "invalid!key@format"
allowed, err = limiter.Allow(invalidKey)
if err == nil || allowed {
t.Error("Expected error or disallowed access for invalid key format")
}
}

// TestSlidingWindowLimiterHighFrequency tests frequent requests within the same window.
func TestSlidingWindowLimiterHighFrequency(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewSlidingWindowLimiter(memStore, 5, time.Second*2)
if err != nil {
t.Fatalf("Failed to create SlidingWindowLimiter: %v", err)
}

key := "highFrequencyUser"

// Make 5 requests quickly
for i := 0; i < 5; i++ {
allowed, err := limiter.Allow(key)
if err != nil {
Expand All @@ -28,85 +50,87 @@ func TestSlidingWindowLimiter(t *testing.T) {
}
}

// 6th request should be blocked
// All subsequent requests should be blocked until window partially resets
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error on 6th request: %v", err)
t.Errorf("Unexpected error on blocked request: %v", err)
}
if allowed {
t.Errorf("6th request should not be allowed")
t.Error("Request should be blocked as the rate limit has been reached")
}

// Wait for half the window to pass
time.Sleep(time.Millisecond * 500)
// Wait for half of the window duration to pass
time.Sleep(time.Second)

// 7th request should still be blocked
// Next request should still be blocked
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error on 7th request: %v", err)
t.Errorf("Unexpected error on partially reset window: %v", err)
}
if allowed {
t.Errorf("7th request should not be allowed")
t.Error("Request should be blocked, as the partial window reset is not complete")
}

// Wait for the window to expire
time.Sleep(time.Millisecond * 600)
// Wait for the rest of the window to expire
time.Sleep(time.Second)

// Next request should be allowed after window resets
// Next request should be allowed after full window reset
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error after window reset: %v", err)
}
if !allowed {
t.Errorf("Request after window reset should be allowed")
t.Error("Request after window reset should be allowed")
}
}

// TestSlidingWindowLimiterEdgeCases checks edge cases for invalid parameters.
func TestSlidingWindowLimiterEdgeCases(t *testing.T) {
// TestSlidingWindowLimiterVariableRequests simulates requests at different intervals.
func TestSlidingWindowLimiterVariableRequests(t *testing.T) {
memStore := store.NewMemoryStore()

// Test with negative limit
_, err := NewSlidingWindowLimiter(memStore, -5, time.Second*1)
if err == nil {
t.Error("Expected error with negative limit, but got none")
limiter, err := NewSlidingWindowLimiter(memStore, 3, time.Second*2)
if err != nil {
t.Fatalf("Failed to create SlidingWindowLimiter: %v", err)
}

// Test with zero window duration
_, err = NewSlidingWindowLimiter(memStore, 5, 0)
if err == nil {
t.Error("Expected error with zero window duration, but got none")
}
}
key := "variableUser"

// TestSlidingWindowLimiterMultipleClients simulates rate limiting for multiple clients.
func TestSlidingWindowLimiterMultipleClients(t *testing.T) {
memStore := store.NewMemoryStore()
limiter, err := NewSlidingWindowLimiter(memStore, 3, time.Second*1)
// First request - should be allowed
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Failed to create rate limiter: %v", err)
}

// Simulate requests for multiple clients
keys := []string{"client1", "client2", "client3"}
for _, key := range keys {
for i := 0; i < 3; i++ {
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error for key %s on request %d: %v", key, i+1, err)
}
if !allowed {
t.Errorf("Request %d for key %s should be allowed", i+1, key)
}
}
t.Errorf("Unexpected error on 1st request: %v", err)
}
if !allowed {
t.Error("1st request should be allowed")
}

// 4th request should be blocked for each key
// Wait a short time and make two more requests within the limit
time.Sleep(time.Millisecond * 500)
for i := 0; i < 2; i++ {
allowed, err := limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error for key %s on 4th request: %v", key, err)
t.Errorf("Unexpected error on request %d: %v", i+2, err)
}
if allowed {
t.Errorf("4th request for key %s should not be allowed", key)
if !allowed {
t.Errorf("Request %d should be allowed", i+2)
}
}

// 4th request should be blocked
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error on blocked request: %v", err)
}
if allowed {
t.Error("4th request should be blocked, limit reached")
}

// Wait for full window duration and reattempt
time.Sleep(time.Second * 2)
allowed, err = limiter.Allow(key)
if err != nil {
t.Errorf("Unexpected error after full window reset: %v", err)
}
if !allowed {
t.Error("Request after window reset should be allowed")
}
}
Loading

0 comments on commit cf95d33

Please sign in to comment.