diff --git a/cmd/backend/backend_main.go b/cmd/backend/backend_main.go index 82ea948..048acda 100644 --- a/cmd/backend/backend_main.go +++ b/cmd/backend/backend_main.go @@ -183,7 +183,14 @@ func main() { bMetrics := backend.CreateBackendMetrics(metricsRegistry) rMetrics := ratelimit.CreateRatelimiterMetrics(metricsRegistry) - rateLimiter := ratelimit.NewWindowRateLimiter(cfg.Backend.RateLimiter.RateWindow, cfg.Backend.RateLimiter.BucketDuration, cfg.Backend.RateLimiter.MaxRequestsPerWindow, cfg.Backend.RateLimiter.MaxRequestsPerBucket, rMetrics) + rateLimiter := ratelimit.NewWindowRateLimiter( + cfg.Backend.RateLimiter.RateWindow, + cfg.Backend.RateLimiter.BucketDuration, + cfg.Backend.RateLimiter.MaxIPRequestsPerWindow, + cfg.Backend.RateLimiter.MaxIPRequestsPerBucket, + cfg.Backend.RateLimiter.MaxURIRequestsPerWindow, + cfg.Backend.RateLimiter.MaxURIRequestsPerBucket, + rMetrics) rateLimiter.Start() var llmResponder responder.Responder diff --git a/config/backend-config.yaml b/config/backend-config.yaml index 1815dee..6184f89 100644 --- a/config/backend-config.yaml +++ b/config/backend-config.yaml @@ -86,9 +86,10 @@ whois_manager: ratelimiter: rate_window: 1h bucket_duration: 1m - max_requests_per_window: 500 - max_requests_per_bucket: 50 - + max_ip_requests_per_window: 500 + max_ip_requests_per_bucket: 50 + max_uri_requests_per_window: 1500 + max_uri_requests_per_bucket: 200 ai: # Whether to enable the responder. enable_responder: 1 diff --git a/config/database.sql b/config/database.sql index 610693e..7cd17e5 100644 --- a/config/database.sql +++ b/config/database.sql @@ -304,7 +304,7 @@ CREATE TABLE whois ( -- These need to be kept in sync with pkg/util/constants/shared_constants.go CREATE TYPE IP_EVENT_TYPE AS ENUM ('UNKNOWN', 'TRAFFIC_CLASS', 'HOSTED_MALWARE', 'SENT_MALWARE', 'RATELIMITED', 'HOST_C2'); -CREATE TYPE IP_EVENT_SUB_TYPE AS ENUM ('UNKNOWN', 'NONE', 'MALWARE_NEW', 'MALWARE_OLD', 'RATE_WINDOW', 'RATE_BUCKET', 'TC_SCANNED', 'TC_ATTACKED', 'TC_RECONNED', 'TC_BRUTEFORCED', 'TC_CRAWLED', 'TC_MALICIOUS'); +CREATE TYPE IP_EVENT_SUB_TYPE AS ENUM ('UNKNOWN', 'NONE', 'MALWARE_NEW', 'MALWARE_OLD', 'IP_RATE_WINDOW', 'IP_RATE_BUCKET', 'URI_RATE_WINDOW', 'URI_RATE_BUCKET', 'TC_SCANNED', 'TC_ATTACKED', 'TC_RECONNED', 'TC_BRUTEFORCED', 'TC_CRAWLED', 'TC_MALICIOUS'); CREATE TYPE IP_EVENT_SOURCE AS ENUM ('OTHER', 'VT', 'RULE', 'BACKEND', 'ANALYSIS', 'WHOIS', 'AI'); CREATE TYPE IP_EVENT_REF_TYPE AS ENUM ('UNKNOWN', 'NONE', 'REQUEST_ID', 'RULE_ID', 'CONTENT_ID', 'VT_ANALYSIS_ID', 'DOWNLOAD_ID', 'REQUEST_DESCRIPTION_ID', 'REQUEST_SOURCE_IP', 'SESSION_ID', 'APP_ID'); CREATE TABLE ip_event ( diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index 4a6d222..b459a27 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -731,12 +731,19 @@ func (s *BackendServer) HandleProbe(ctx context.Context, req *backend_service.Ha } switch err { - case ratelimit.ErrBucketLimitExceeded: - evt.Subtype = constants.IpEventSubTypeRateBucket - s.metrics.rateLimiterRejects.WithLabelValues(RatelimiterRejectReasonBucket).Add(1) - case ratelimit.ErrWindowLimitExceeded: - evt.Subtype = constants.IpEventSubTypeRateWindow - s.metrics.rateLimiterRejects.WithLabelValues(RatelimiterRejectReasonWindow).Add(1) + case ratelimit.ErrIPBucketLimitExceeded: + evt.Subtype = constants.IpEventSubTypeRateIPBucket + s.metrics.rateLimiterRejects.WithLabelValues(RatelimiterRejectReasonIPBucket).Add(1) + case ratelimit.ErrIPWindowLimitExceeded: + evt.Subtype = constants.IpEventSubTypeRateIPWindow + s.metrics.rateLimiterRejects.WithLabelValues(RatelimiterRejectReasonIPWindow).Add(1) + case ratelimit.ErrURIBucketLimitExceeded: + evt.Subtype = constants.IpEventSubTypeRateURIBucket + s.metrics.rateLimiterRejects.WithLabelValues(RatelimiterRejectReasonURIBucket).Add(1) + case ratelimit.ErrURIWindowLimitExceeded: + evt.Subtype = constants.IpEventSubTypeRateURIWindow + s.metrics.rateLimiterRejects.WithLabelValues(RatelimiterRejectReasonURIWindow).Add(1) + default: slog.Error("error happened in ratelimiter", slog.String("error", err.Error())) } diff --git a/pkg/backend/config.go b/pkg/backend/config.go index 79dbca6..6de4e56 100644 --- a/pkg/backend/config.go +++ b/pkg/backend/config.go @@ -41,10 +41,12 @@ type Config struct { MaxDownloadSizeMB int `fig:"max_download_size_mb" default:"200"` } `fig:"downloader"` RateLimiter struct { - RateWindow time.Duration `fig:"rate_window" default:"1h"` - BucketDuration time.Duration `fig:"bucket_duration" default:"1m"` - MaxRequestsPerWindow int `fig:"max_requests_per_window" default:"1000"` - MaxRequestsPerBucket int `fig:"max_requests_per_bucket" default:"50"` + RateWindow time.Duration `fig:"rate_window" default:"1h"` + BucketDuration time.Duration `fig:"bucket_duration" default:"1m"` + MaxIPRequestsPerWindow int `fig:"max_ip_requests_per_window" default:"1000"` + MaxIPRequestsPerBucket int `fig:"max_ip_requests_per_bucket" default:"50"` + MaxURIRequestsPerWindow int `fig:"max_uri_requests_per_window" default:"2000"` + MaxURIRequestsPerBucket int `fig:"max_uri_requests_per_bucket" default:"100"` } `fig:"ratelimiter"` Advanced struct { @@ -110,7 +112,7 @@ type Config struct { LLMCompletionTimeout time.Duration `fig:"llm_completion_timeout" default:"1m"` LLMConcurrentRequests int `fig:"llm_concurrent_requests" default:"5"` MaxInputCharacters int `fig:"max_input_characters" default:"4096"` - Triage struct { + Triage struct { Enable bool `fig:"enable"` LogFile string `fig:"log_file" default:"triage.log" ` LogLevel string `fig:"log_level" default:"debug" ` diff --git a/pkg/backend/metrics.go b/pkg/backend/metrics.go index ba89357..8e4d97f 100644 --- a/pkg/backend/metrics.go +++ b/pkg/backend/metrics.go @@ -23,8 +23,10 @@ import ( ) var ( - RatelimiterRejectReasonWindow = "window" - RatelimiterRejectReasonBucket = "bucket" + RatelimiterRejectReasonIPWindow = "ip_window" + RatelimiterRejectReasonIPBucket = "ip_bucket" + RatelimiterRejectReasonURIWindow = "uri_window" + RatelimiterRejectReasonURIBucket = "uri_bucket" ) type BackendMetrics struct { diff --git a/pkg/backend/ratelimit/metrics.go b/pkg/backend/ratelimit/metrics.go index f506bc1..16ee46e 100644 --- a/pkg/backend/ratelimit/metrics.go +++ b/pkg/backend/ratelimit/metrics.go @@ -14,7 +14,6 @@ // You should have received a copy of the GNU General Public License along // with this program; if not, write to the Free Software Foundation, Inc., // 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA -// package ratelimit import ( @@ -22,19 +21,26 @@ import ( ) type RatelimiterMetrics struct { - rateBucketsGauge prometheus.Gauge + ipRateBucketsGauge prometheus.Gauge + uriRateBucketsGauge prometheus.Gauge } // Register Metrics func CreateRatelimiterMetrics(reg prometheus.Registerer) *RatelimiterMetrics { m := &RatelimiterMetrics{ - rateBucketsGauge: prometheus.NewGauge( + ipRateBucketsGauge: prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "lophiid_backend_ratelimit_ip_buckets_gauge", + Help: "The amount of active IP ratelimit buckets"}, + ), + uriRateBucketsGauge: prometheus.NewGauge( prometheus.GaugeOpts{ - Name: "lophiid_backend_ratelimit_buckets_gauge", - Help: "The amount of active ratelimit buckets"}, + Name: "lophiid_backend_ratelimit_uri_buckets_gauge", + Help: "The amount of active URI ratelimit buckets"}, ), } - reg.MustRegister(m.rateBucketsGauge) + reg.MustRegister(m.ipRateBucketsGauge) + reg.MustRegister(m.uriRateBucketsGauge) return m } diff --git a/pkg/backend/ratelimit/ratelimit.go b/pkg/backend/ratelimit/ratelimit.go index 86ec7ab..20c9f33 100644 --- a/pkg/backend/ratelimit/ratelimit.go +++ b/pkg/backend/ratelimit/ratelimit.go @@ -26,8 +26,10 @@ import ( ) var ( - ErrBucketLimitExceeded = errors.New("bucket limit exceeded") - ErrWindowLimitExceeded = errors.New("window limit exceeded") + ErrIPBucketLimitExceeded = errors.New("IP bucket limit exceeded") + ErrIPWindowLimitExceeded = errors.New("IP window limit exceeded") + ErrURIBucketLimitExceeded = errors.New("URI bucket limit exceeded") + ErrURIWindowLimitExceeded = errors.New("URI window limit exceeded") ) type RateLimiter interface { @@ -46,28 +48,35 @@ type RateLimiter interface { // // Requires Start() to be called before usage. type WindowRateLimiter struct { - MaxRequestsPerWindow int - MaxRequestPerBucket int - RateWindow time.Duration - BucketDuration time.Duration - NumberBuckets int - RateBuckets map[string][]int - Metrics *RatelimiterMetrics - rateMu sync.Mutex - bgChan chan bool + MaxIPRequestsPerWindow int + MaxIPRequestPerBucket int + MaxURIRequestsPerWindow int + MaxURIRequestPerBucket int + RateWindow time.Duration + BucketDuration time.Duration + NumberBuckets int + IPRateBuckets map[string][]int + URIRateBuckets map[string][]int + Metrics *RatelimiterMetrics + rateIPMu sync.Mutex + rateURIMu sync.Mutex + bgChan chan bool } -func NewWindowRateLimiter(rateWindow time.Duration, bucketDuration time.Duration, maxRequestsPerWindow int, maxRequestPerBucket int, metrics *RatelimiterMetrics) *WindowRateLimiter { +func NewWindowRateLimiter(rateWindow time.Duration, bucketDuration time.Duration, maxIpRequestsPerWindow int, maxIpRequestPerBucket int, maxUriRequestsPerWindow int, maxUriRequestPerBucket int, metrics *RatelimiterMetrics) *WindowRateLimiter { slog.Info("Creating ratelimiter", slog.String("window_size", rateWindow.String()), slog.String("bucket_size", bucketDuration.String())) return &WindowRateLimiter{ - BucketDuration: bucketDuration, - MaxRequestPerBucket: maxRequestPerBucket, - MaxRequestsPerWindow: maxRequestsPerWindow, - RateWindow: rateWindow, - RateBuckets: make(map[string][]int), - NumberBuckets: int(rateWindow / bucketDuration), - Metrics: metrics, - bgChan: make(chan bool), + BucketDuration: bucketDuration, + MaxIPRequestPerBucket: maxIpRequestPerBucket, + MaxIPRequestsPerWindow: maxIpRequestsPerWindow, + MaxURIRequestPerBucket: maxUriRequestPerBucket, + MaxURIRequestsPerWindow: maxUriRequestsPerWindow, + RateWindow: rateWindow, + IPRateBuckets: make(map[string][]int), + URIRateBuckets: make(map[string][]int), + NumberBuckets: int(rateWindow / bucketDuration), + Metrics: metrics, + bgChan: make(chan bool), } } @@ -104,18 +113,30 @@ func GetSumOfWindow(window []int) int { // bucket while removing windows where all buckets are 0 (basically no traffic // seen). func (r *WindowRateLimiter) Tick() { - r.rateMu.Lock() - defer r.rateMu.Unlock() + r.rateIPMu.Lock() + for k := range r.IPRateBuckets { + r.IPRateBuckets[k] = r.IPRateBuckets[k][1:] + r.IPRateBuckets[k] = append(r.IPRateBuckets[k], 0) - for k := range r.RateBuckets { - r.RateBuckets[k] = r.RateBuckets[k][1:] - r.RateBuckets[k] = append(r.RateBuckets[k], 0) + if GetSumOfWindow(r.IPRateBuckets[k]) == 0 { + delete(r.IPRateBuckets, k) + } + } + r.rateIPMu.Unlock() - if GetSumOfWindow(r.RateBuckets[k]) == 0 { - delete(r.RateBuckets, k) + r.rateURIMu.Lock() + for k := range r.URIRateBuckets { + r.URIRateBuckets[k] = r.URIRateBuckets[k][1:] + r.URIRateBuckets[k] = append(r.URIRateBuckets[k], 0) + + if GetSumOfWindow(r.URIRateBuckets[k]) == 0 { + delete(r.URIRateBuckets, k) } } - r.Metrics.rateBucketsGauge.Set(float64(len(r.RateBuckets))) + r.rateURIMu.Unlock() + + r.Metrics.ipRateBucketsGauge.Set(float64(len(r.IPRateBuckets))) + r.Metrics.uriRateBucketsGauge.Set(float64(len(r.URIRateBuckets))) } // AllowRequest will return true if a request is allowed because the total @@ -123,33 +144,74 @@ func (r *WindowRateLimiter) Tick() { // then an error is returned with the reason why. // Requires that Start() has been called before usage. func (r *WindowRateLimiter) AllowRequest(req *models.Request) (bool, error) { - rKey := fmt.Sprintf("%s-%d-%s", req.HoneypotIP, req.Port, req.SourceIP) + ret, err := r.allowRequestForIP(req) + if !ret { + return ret, err + } + + return r.allowRequestForURI(req) +} - r.rateMu.Lock() - defer r.rateMu.Unlock() +func (r *WindowRateLimiter) allowRequestForIP(req *models.Request) (bool, error) { - _, ok := r.RateBuckets[rKey] + ipRateKey := fmt.Sprintf("%s-%d-%s", req.HoneypotIP, req.Port, req.SourceIP) + r.rateIPMu.Lock() + defer r.rateIPMu.Unlock() + + _, ok := r.IPRateBuckets[ipRateKey] // If the key is not present then this IP has no recent requests logged so we // create the buckets. if !ok { - r.RateBuckets[rKey] = make([]int, r.NumberBuckets) - r.RateBuckets[rKey][r.NumberBuckets-1] = 1 + r.IPRateBuckets[ipRateKey] = make([]int, r.NumberBuckets) + r.IPRateBuckets[ipRateKey][r.NumberBuckets-1] = 1 + return true, nil + } + + // Check how many requests there have been in this window. + if GetSumOfWindow(r.IPRateBuckets[ipRateKey]) >= r.MaxIPRequestsPerWindow { + r.IPRateBuckets[ipRateKey][r.NumberBuckets-1] += 1 + return false, ErrIPWindowLimitExceeded + } + + // Check if the bucket limit is not already exceeded. + if r.IPRateBuckets[ipRateKey][r.NumberBuckets-1] >= r.MaxIPRequestPerBucket { + r.IPRateBuckets[ipRateKey][r.NumberBuckets-1] += 1 + return false, ErrIPBucketLimitExceeded + } + + r.IPRateBuckets[ipRateKey][r.NumberBuckets-1] += 1 + + return true, nil +} + +func (r *WindowRateLimiter) allowRequestForURI(req *models.Request) (bool, error) { + uriRateKey := req.BaseHash + + r.rateURIMu.Lock() + defer r.rateURIMu.Unlock() + + _, ok := r.URIRateBuckets[uriRateKey] + // If the key is not present then this URI has no recent requests logged so we + // create the buckets. + if !ok { + r.URIRateBuckets[uriRateKey] = make([]int, r.NumberBuckets) + r.URIRateBuckets[uriRateKey][r.NumberBuckets-1] = 1 return true, nil } // Check how many requests there have been in this window. - if GetSumOfWindow(r.RateBuckets[rKey]) >= r.MaxRequestsPerWindow { - r.RateBuckets[rKey][r.NumberBuckets-1] += 1 - return false, ErrWindowLimitExceeded + if GetSumOfWindow(r.URIRateBuckets[uriRateKey]) >= r.MaxURIRequestsPerWindow { + r.URIRateBuckets[uriRateKey][r.NumberBuckets-1] += 1 + return false, ErrURIWindowLimitExceeded } // Check if the bucket limit is not already exceeded. - if r.RateBuckets[rKey][r.NumberBuckets-1] >= r.MaxRequestPerBucket { - r.RateBuckets[rKey][r.NumberBuckets-1] += 1 - return false, ErrBucketLimitExceeded + if r.URIRateBuckets[uriRateKey][r.NumberBuckets-1] >= r.MaxURIRequestPerBucket { + r.URIRateBuckets[uriRateKey][r.NumberBuckets-1] += 1 + return false, ErrURIBucketLimitExceeded } - r.RateBuckets[rKey][r.NumberBuckets-1] += 1 + r.URIRateBuckets[uriRateKey][r.NumberBuckets-1] += 1 return true, nil } diff --git a/pkg/backend/ratelimit/ratelimit_test.go b/pkg/backend/ratelimit/ratelimit_test.go index 22924d6..d2d2ad8 100644 --- a/pkg/backend/ratelimit/ratelimit_test.go +++ b/pkg/backend/ratelimit/ratelimit_test.go @@ -17,6 +17,7 @@ package ratelimit import ( + "fmt" "lophiid/pkg/database/models" "testing" "time" @@ -28,8 +29,11 @@ import ( func TestRateLimitOk(t *testing.T) { testRateWindow := time.Second * 5 testBucketDuration := time.Second - testMaxRequestsPerWindow := 4 - testMaxRequestPerBucket := 2 + testMaxIpRequestsPerWindow := 4 + testMaxIpRequestPerBucket := 2 + + testMaxUriRequestsPerWindow := 6 + testMaxUriRequestPerBucket := 6 req := models.Request{ HoneypotIP: "1.1.1.1", @@ -39,9 +43,9 @@ func TestRateLimitOk(t *testing.T) { } reg := prometheus.NewRegistry() rMetrics := CreateRatelimiterMetrics(reg) - r := NewWindowRateLimiter(testRateWindow, testBucketDuration, testMaxRequestsPerWindow, testMaxRequestPerBucket, rMetrics) + r := NewWindowRateLimiter(testRateWindow, testBucketDuration, testMaxIpRequestsPerWindow, testMaxIpRequestPerBucket, testMaxUriRequestsPerWindow, testMaxUriRequestPerBucket, rMetrics) - if testutil.ToFloat64(rMetrics.rateBucketsGauge) != 0 { + if testutil.ToFloat64(rMetrics.ipRateBucketsGauge) != 0 { t.Errorf("rateBucketsGauge should be 0 at the start") } @@ -61,7 +65,7 @@ func TestRateLimitOk(t *testing.T) { t.Errorf("request is allowed but it should be rejected") } - if err != ErrBucketLimitExceeded { + if err != ErrIPBucketLimitExceeded { t.Errorf("expected bucket exceeded, got unexpected error %v", err) } @@ -78,11 +82,11 @@ func TestRateLimitOk(t *testing.T) { t.Errorf("request exceeds window limit and should be rejected") } - if err != ErrWindowLimitExceeded { + if err != ErrIPWindowLimitExceeded { t.Errorf("expected ErrWindowLimitExceeded but got %v", err) } - m := testutil.ToFloat64(rMetrics.rateBucketsGauge) + m := testutil.ToFloat64(rMetrics.ipRateBucketsGauge) if m != 1 { t.Errorf("rateBucketsGauge should be 1, is %f", m) } @@ -95,7 +99,7 @@ func TestRateLimitOk(t *testing.T) { r.Tick() // Check if the RateBucket entry is indeed removed. - m = testutil.ToFloat64(rMetrics.rateBucketsGauge) + m = testutil.ToFloat64(rMetrics.ipRateBucketsGauge) if m != 0 { t.Errorf("rateBucketsGauge should be 0 after reset") } @@ -104,3 +108,230 @@ func TestRateLimitOk(t *testing.T) { t.Errorf("unexpected error %v", err) } } + +func TestAllowRequestForIP(t *testing.T) { + tests := []struct { + name string + req *models.Request + requestCount int + expectedAllow bool + expectedError error + setupRequests int // number of requests to make before the actual test + newBucket bool // whether this should create a new bucket + }{ + { + name: "first request for IP creates bucket", + req: &models.Request{ + HoneypotIP: "10.0.0.1", + SourceIP: "192.168.1.1", + Port: 8080, + }, + requestCount: 1, + expectedAllow: true, + expectedError: nil, + newBucket: true, + }, + { + name: "request within limits", + req: &models.Request{ + HoneypotIP: "10.0.0.1", + SourceIP: "192.168.1.2", + Port: 8080, + }, + requestCount: 2, + expectedAllow: true, + expectedError: nil, + }, + { + name: "bucket limit exceeded", + req: &models.Request{ + HoneypotIP: "10.0.0.1", + SourceIP: "192.168.1.3", + Port: 8080, + }, + requestCount: 3, + setupRequests: 2, // make 2 requests first to reach the limit + expectedAllow: false, + expectedError: ErrIPBucketLimitExceeded, + }, + { + name: "window limit exceeded", + req: &models.Request{ + HoneypotIP: "10.0.0.1", + SourceIP: "192.168.1.4", + Port: 8080, + }, + requestCount: 5, + setupRequests: 4, // make 4 requests first to reach the window limit + expectedAllow: false, + expectedError: ErrIPWindowLimitExceeded, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a new rate limiter for each test + r := NewWindowRateLimiter( + time.Second*5, // window + time.Second, // bucket duration + 4, // max requests per window + 2, // max requests per bucket + 10, // max URI requests per window (not used in this test) + 5, // max URI requests per bucket (not used in this test) + CreateRatelimiterMetrics(prometheus.NewRegistry()), + ) + + // Perform setup requests if needed + for i := 0; i < tt.setupRequests; i++ { + r.allowRequestForIP(tt.req) + } + + // Record initial bucket state if we're testing new bucket creation + if tt.newBucket { + ipRateKey := fmt.Sprintf("%s-%d-%s", tt.req.HoneypotIP, tt.req.Port, tt.req.SourceIP) + if _, exists := r.IPRateBuckets[ipRateKey]; exists { + t.Errorf("bucket should not exist before first request") + } + } + + // Perform the test request + allowed, err := r.allowRequestForIP(tt.req) + + // Verify results + if allowed != tt.expectedAllow { + t.Errorf("allowRequestForIP() allowed = %v, want %v", allowed, tt.expectedAllow) + } + + if err != tt.expectedError { + t.Errorf("allowRequestForIP() error = %v, want %v", err, tt.expectedError) + } + + // Verify bucket creation if applicable + if tt.newBucket { + ipRateKey := fmt.Sprintf("%s-%d-%s", tt.req.HoneypotIP, tt.req.Port, tt.req.SourceIP) + if _, exists := r.IPRateBuckets[ipRateKey]; !exists { + t.Errorf("bucket should exist after first request") + } + } + }) + } +} + +func TestAllowRequestForURI(t *testing.T) { + tests := []struct { + name string + req *models.Request + requestCount int + expectedAllow bool + expectedError error + setupRequests int // number of requests to make before the actual test + newBucket bool // whether this should create a new bucket + }{ + { + name: "first request for URI creates bucket", + req: &models.Request{ + BaseHash: "hash1", + }, + requestCount: 1, + expectedAllow: true, + expectedError: nil, + newBucket: true, + }, + { + name: "request within limits", + req: &models.Request{ + BaseHash: "hash2", + }, + requestCount: 2, + expectedAllow: true, + expectedError: nil, + }, + { + name: "bucket limit exceeded", + req: &models.Request{ + BaseHash: "hash3", + }, + requestCount: 3, + setupRequests: 5, // make 5 requests first to reach the bucket limit + expectedAllow: false, + expectedError: ErrURIBucketLimitExceeded, + }, + { + name: "window limit exceeded", + req: &models.Request{ + BaseHash: "hash4", + }, + requestCount: 7, + setupRequests: 6, // make 6 requests first to reach the window limit + expectedAllow: false, + expectedError: ErrURIWindowLimitExceeded, + }, + { + name: "different URIs don't interfere", + req: &models.Request{ + BaseHash: "hash5", + }, + requestCount: 1, + expectedAllow: true, + expectedError: nil, + newBucket: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a new rate limiter for each test + r := NewWindowRateLimiter( + time.Second*5, // window + time.Second, // bucket duration + 10, // max IP requests per window (not used in this test) + 5, // max IP requests per bucket (not used in this test) + 6, // max URI requests per window + 5, // max URI requests per bucket + CreateRatelimiterMetrics(prometheus.NewRegistry()), + ) + + // Perform setup requests if needed + for i := 0; i < tt.setupRequests; i++ { + if tt.name == "window limit exceeded" { + // For window limit test, directly set up the buckets + if i == 0 { + r.URIRateBuckets[tt.req.BaseHash] = make([]int, r.NumberBuckets) + } + // Spread requests across buckets to reach window limit + r.URIRateBuckets[tt.req.BaseHash][i%r.NumberBuckets] = 2 + } else { + r.allowRequestForURI(tt.req) + } + } + + // Record initial bucket state if we're testing new bucket creation + if tt.newBucket { + uriRateKey := tt.req.BaseHash + if _, exists := r.URIRateBuckets[uriRateKey]; exists { + t.Errorf("bucket should not exist before first request") + } + } + + // Perform the test request + allowed, err := r.allowRequestForURI(tt.req) + + // Verify results + if allowed != tt.expectedAllow { + t.Errorf("allowRequestForURI() allowed = %v, want %v", allowed, tt.expectedAllow) + } + + if err != tt.expectedError { + t.Errorf("allowRequestForURI() error = %v, want %v", err, tt.expectedError) + } + + // Verify bucket creation if applicable + if tt.newBucket { + uriRateKey := tt.req.BaseHash + if _, exists := r.URIRateBuckets[uriRateKey]; !exists { + t.Errorf("bucket should exist after first request") + } + } + }) + } +} diff --git a/pkg/util/constants/shared_constants.go b/pkg/util/constants/shared_constants.go index 47582b9..de9e014 100644 --- a/pkg/util/constants/shared_constants.go +++ b/pkg/util/constants/shared_constants.go @@ -33,8 +33,10 @@ const ( IpEventSubTypeMalwareNew = "MALWARE_NEW" IpEventSubTypeMalwareOld = "MALWARE_OLD" - IpEventSubTypeRateWindow = "RATE_WINDOW" - IpEventSubTypeRateBucket = "RATE_BUCKET" + IpEventSubTypeRateIPWindow = "IP_RATE_WINDOW" + IpEventSubTypeRateIPBucket = "IP_RATE_BUCKET" + IpEventSubTypeRateURIWindow = "URI_RATE_WINDOW" + IpEventSubTypeRateURIBucket = "URI_RATE_BUCKET" IpEventSubTypeTrafficClassScanned = "TC_SCANNED" IpEventSubTypeTrafficClassAttacked = "TC_ATTACKED"