From 7c8592a7e05370b21cc02cc2edaee6b2355fd83a Mon Sep 17 00:00:00 2001 From: Benjamin Chibuzor-Orie Date: Fri, 15 Jan 2021 21:53:15 +0100 Subject: [PATCH] adds middleware for rate limiting (#1724) * adds middleware for rate limiting * added comment for InMemoryStore ShouldAllow * removed redundant mutex declaration * fixed lint issues * removed sleep from tests * improved coverage * refactor: renames Identifiers, includes default SourceFunc * Added last seen stats for visitor * uses http Constants for improved readdability adds default error handler * used other handler apart from default handler to mark custom error handler for rate limiting * split tests into separate blocks added an error pair to IdentifierExtractor Includes deny handler for explicitly denying requests * adds comments for exported members Extractor and ErrorHandler * makes cleanup implementation inhouse * Avoid race for cleanup due to non-atomic access to store.expiresIn * Use a dedicated producer for rate testing * tidy commit * refactors tests, implicitly tests lastSeen property on visitor switches NewRateLimiterMemoryStore constructor to Referential Functions style (Advised by @pafuent) * switches to mock of time module for time based tests tests are now fully deterministic * improved coverage * replaces Rob Pike referential options with more conventional struct configs makes cleanup asynchronous * blocks racy access to lastCleanup * Add benchmark tests for rate limiter * Add rate limiter with sharded memory store * Racy access to store.lastCleanup eliminated Merges in shiny sharded map implementation by @lammel * Remove RateLimiterShradedMemoryStore for now * Make fields for RateLimiterStoreConfig public for external configuration * Improve docs for RateLimiter usage * Fix ErrorHandler vs. DenyHandler usage for rate limiter * Simplify NewRateLimiterMemoryStore * improved coverage * updated errorHandler and denyHandler to use echo.HTTPError * Improve wording for error and comments * Remove duplicate lastSeen marking for Allow * Improve wording for comments * Add disclaimer on perf characteristics of memory store * changes Allow signature on rate limiter to return err too Co-authored-by: Roland Lammel --- .gitignore | 1 + go.mod | 1 + go.sum | 2 + middleware/rate_limiter.go | 268 ++++++++++++++++++ middleware/rate_limiter_test.go | 462 ++++++++++++++++++++++++++++++++ 5 files changed, 734 insertions(+) create mode 100644 middleware/rate_limiter.go create mode 100644 middleware/rate_limiter_test.go diff --git a/.gitignore b/.gitignore index dd74acca4..dbadf3bd0 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ vendor .idea *.iml *.out +.vscode diff --git a/go.mod b/go.mod index 74c6a9abe..877117075 100644 --- a/go.mod +++ b/go.mod @@ -12,4 +12,5 @@ require ( golang.org/x/net v0.0.0-20200822124328-c89045814202 golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect golang.org/x/text v0.3.3 // indirect + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 58c80c831..54ba24e67 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go new file mode 100644 index 000000000..7d1abfcb9 --- /dev/null +++ b/middleware/rate_limiter.go @@ -0,0 +1,268 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v4" + "golang.org/x/time/rate" +) + +type ( + // RateLimiterStore is the interface to be implemented by custom stores. + RateLimiterStore interface { + // Stores for the rate limiter have to implement the Allow method + Allow(identifier string) (bool, error) + } +) + +type ( + // RateLimiterConfig defines the configuration for the rate limiter + RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error + } + // Extractor is used to extract data from echo.Context + Extractor func(context echo.Context) (string, error) +) + +// errors +var ( + // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded + ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + // ErrExtractorError denotes an error raised when extractor function is unsuccessful + ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") +) + +// DefaultRateLimiterConfig defines default values for RateLimiterConfig +var DefaultRateLimiterConfig = RateLimiterConfig{ + Skipper: DefaultSkipper, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return &echo.HTTPError{ + Code: ErrExtractorError.Code, + Message: ErrExtractorError.Message, + Internal: err, + } + }, + DenyHandler: func(context echo.Context, identifier string, err error) error { + return &echo.HTTPError{ + Code: ErrRateLimitExceeded.Code, + Message: ErrRateLimitExceeded.Message, + Internal: err, + } + }, +} + +/* +RateLimiter returns a rate limiting middleware + + e := echo.New() + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, RateLimiter(limiterStore)) +*/ +func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { + config := DefaultRateLimiterConfig + config.Store = store + + return RateLimiterWithConfig(config) +} + +/* +RateLimiterWithConfig returns a rate limiting middleware + + e := echo.New() + + config := middleware.RateLimiterConfig{ + Skipper: DefaultSkipper, + Store: middleware.NewRateLimiterMemoryStore( + middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} + ) + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return context.JSON(http.StatusTooManyRequests, nil) + }, + DenyHandler: func(context echo.Context, identifier string) error { + return context.JSON(http.StatusForbidden, nil) + }, + } + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, middleware.RateLimiterWithConfig(config)) +*/ +func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultRateLimiterConfig.Skipper + } + if config.IdentifierExtractor == nil { + config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor + } + if config.ErrorHandler == nil { + config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler + } + if config.DenyHandler == nil { + config.DenyHandler = DefaultRateLimiterConfig.DenyHandler + } + if config.Store == nil { + panic("Store configuration must be provided") + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + if config.BeforeFunc != nil { + config.BeforeFunc(c) + } + + identifier, err := config.IdentifierExtractor(c) + if err != nil { + c.Error(config.ErrorHandler(c, err)) + return nil + } + + if allow, err := config.Store.Allow(identifier); !allow { + c.Error(config.DenyHandler(c, identifier, err)) + return nil + } + return next(c) + } + } +} + +type ( + // RateLimiterMemoryStore is the built-in store implementation for RateLimiter + RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit + burst int + expiresIn time.Duration + lastCleanup time.Time + } + // Visitor signifies a unique user's limiter details + Visitor struct { + *rate.Limiter + lastSeen time.Time + } +) + +/* +NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with +the provided rate (as req/s). Burst and ExpiresIn will be set to default values. + +Example (with 20 requests/sec): + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + +*/ +func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { + return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: rate, + }) +} + +/* +NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore +with the provided configuration. Rate must be provided. Burst will be set to the value of +the configured rate if not provided or set to 0. + +The build-in memory store is usually capable for modest loads. For higher loads other +store implementations should be considered. + +Characteristics: +* Concurrency above 100 parallel requests may causes measurable lock contention +* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map +* A high number of requests from a single IP address may cause lock contention + +Example: + + limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes}, + ) +*/ +func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { + store = &RateLimiterMemoryStore{} + + store.rate = config.Rate + store.burst = config.Burst + store.expiresIn = config.ExpiresIn + if config.ExpiresIn == 0 { + store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn + } + if config.Burst == 0 { + store.burst = int(config.Rate) + } + store.visitors = make(map[string]*Visitor) + store.lastCleanup = now() + return +} + +// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore +type RateLimiterMemoryStoreConfig struct { + Rate rate.Limit // Rate of requests allowed to pass as req/s + Burst int // Burst additionally allows a number of requests to pass when rate limit is reached + ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up +} + +// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore +var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ + ExpiresIn: 3 * time.Minute, +} + +// Allow implements RateLimiterStore.Allow +func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { + store.mutex.Lock() + limiter, exists := store.visitors[identifier] + if !exists { + limiter = new(Visitor) + limiter.Limiter = rate.NewLimiter(store.rate, store.burst) + store.visitors[identifier] = limiter + } + limiter.lastSeen = now() + if now().Sub(store.lastCleanup) > store.expiresIn { + store.cleanupStaleVisitors() + } + store.mutex.Unlock() + return limiter.AllowN(now(), 1), nil +} + +/* +cleanupStaleVisitors helps manage the size of the visitors map by removing stale records +of users who haven't visited again after the configured expiry time has elapsed +*/ +func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { + for id, visitor := range store.visitors { + if now().Sub(visitor.lastSeen) > store.expiresIn { + delete(store.visitors, id) + } + } + store.lastCleanup = now() +} + +/* +actual time method which is mocked in test file +*/ +var now = func() time.Time { + return time.Now() +} diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go new file mode 100644 index 000000000..2e57bf175 --- /dev/null +++ b/middleware/rate_limiter_test.go @@ -0,0 +1,462 @@ +package middleware + +import ( + "errors" + "fmt" + "math/rand" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/random" + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" +) + +func TestRateLimiter(t *testing.T) { + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + mw := RateLimiter(inMemoryStore) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + assert.Equal(t, tc.code, rec.Code) + } +} + +func TestRateLimiter_panicBehaviour(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + assert.Panics(t, func() { + RateLimiter(nil) + }) + + assert.NotPanics(t, func() { + RateLimiter(inMemoryStore) + }) +} + +func TestRateLimiterWithConfig(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw := RateLimiterWithConfig(RateLimiterConfig{ + IdentifierExtractor: func(c echo.Context) (string, error) { + id := c.Request().Header.Get(echo.HeaderXRealIP) + if id == "" { + return "", errors.New("invalid identifier") + } + return id, nil + }, + DenyHandler: func(ctx echo.Context, identifier string, err error) error { + return ctx.JSON(http.StatusForbidden, nil) + }, + ErrorHandler: func(ctx echo.Context, err error) error { + return ctx.JSON(http.StatusBadRequest, nil) + }, + Store: inMemoryStore, + }) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusForbidden}, + {"", http.StatusBadRequest}, + {"127.0.0.1", http.StatusForbidden}, + {"127.0.0.1", http.StatusForbidden}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } +} + +func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw := RateLimiterWithConfig(RateLimiterConfig{ + IdentifierExtractor: func(c echo.Context) (string, error) { + id := c.Request().Header.Get(echo.HeaderXRealIP) + if id == "" { + return "", errors.New("invalid identifier") + } + return id, nil + }, + Store: inMemoryStore, + }) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"", http.StatusForbidden}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } +} + +func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { + { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw := RateLimiterWithConfig(RateLimiterConfig{ + Store: inMemoryStore, + }) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } + } +} + +func TestRateLimiterWithConfig_skipper(t *testing.T) { + e := echo.New() + + var beforeFuncRan bool + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + var inMemoryStore = NewRateLimiterMemoryStore(5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw := RateLimiterWithConfig(RateLimiterConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + BeforeFunc: func(c echo.Context) { + beforeFuncRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }) + + _ = mw(handler)(c) + + assert.Equal(t, false, beforeFuncRan) +} + +func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { + e := echo.New() + + var beforeFuncRan bool + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + var inMemoryStore = NewRateLimiterMemoryStore(5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw := RateLimiterWithConfig(RateLimiterConfig{ + Skipper: func(c echo.Context) bool { + return false + }, + BeforeFunc: func(c echo.Context) { + beforeFuncRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }) + + _ = mw(handler)(c) + + assert.Equal(t, true, beforeFuncRan) +} + +func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + var beforeRan bool + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw := RateLimiterWithConfig(RateLimiterConfig{ + BeforeFunc: func(c echo.Context) { + beforeRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }) + + _ = mw(handler)(c) + + assert.Equal(t, true, beforeRan) +} + +func TestRateLimiterMemoryStore_Allow(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second}) + testCases := []struct { + id string + allowed bool + }{ + {"127.0.0.1", true}, // 0 ms + {"127.0.0.1", true}, // 220 ms burst #2 + {"127.0.0.1", true}, // 440 ms burst #3 + {"127.0.0.1", false}, // 660 ms block + {"127.0.0.1", false}, // 880 ms block + {"127.0.0.1", true}, // 1100 ms next second #1 + {"127.0.0.2", true}, // 1320 ms allow other ip + {"127.0.0.1", false}, // 1540 ms no burst + {"127.0.0.1", false}, // 1760 ms no burst + {"127.0.0.1", false}, // 1980 ms no burst + {"127.0.0.1", true}, // 2200 ms no burst + {"127.0.0.1", false}, // 2420 ms no burst + {"127.0.0.1", false}, // 2640 ms no burst + {"127.0.0.1", false}, // 2860 ms no burst + {"127.0.0.1", true}, // 3080 ms no burst + {"127.0.0.1", false}, // 3300 ms no burst + {"127.0.0.1", false}, // 3520 ms no burst + {"127.0.0.1", false}, // 3740 ms no burst + {"127.0.0.1", false}, // 3960 ms no burst + {"127.0.0.1", true}, // 4180 ms no burst + {"127.0.0.1", false}, // 4400 ms no burst + {"127.0.0.1", false}, // 4620 ms no burst + {"127.0.0.1", false}, // 4840 ms no burst + {"127.0.0.1", true}, // 5060 ms no burst + } + + for i, tc := range testCases { + t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) + now = func() time.Time { + return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) + } + allowed, _ := inMemoryStore.Allow(tc.id) + assert.Equal(t, tc.allowed, allowed) + } +} + +func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + now = func() time.Time { + return time.Now() + } + fmt.Println(now()) + inMemoryStore.visitors = map[string]*Visitor{ + "A": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now(), + }, + "B": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now().Add(-1 * time.Minute), + }, + "C": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now().Add(-5 * time.Minute), + }, + "D": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now().Add(-10 * time.Minute), + }, + } + + inMemoryStore.Allow("D") + inMemoryStore.cleanupStaleVisitors() + + var exists bool + + _, exists = inMemoryStore.visitors["A"] + assert.Equal(t, true, exists) + + _, exists = inMemoryStore.visitors["B"] + assert.Equal(t, true, exists) + + _, exists = inMemoryStore.visitors["C"] + assert.Equal(t, false, exists) + + _, exists = inMemoryStore.visitors["D"] + assert.Equal(t, true, exists) +} + +func TestNewRateLimiterMemoryStore(t *testing.T) { + testCases := []struct { + rate rate.Limit + burst int + expiresIn time.Duration + expectedExpiresIn time.Duration + }{ + {1, 3, 5 * time.Second, 5 * time.Second}, + {2, 4, 0, 3 * time.Minute}, + {1, 5, 10 * time.Minute, 10 * time.Minute}, + {3, 7, 0, 3 * time.Minute}, + } + + for _, tc := range testCases { + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn}) + assert.Equal(t, tc.rate, store.rate) + assert.Equal(t, tc.burst, store.burst) + assert.Equal(t, tc.expectedExpiresIn, store.expiresIn) + } +} + +func generateAddressList(count int) []string { + addrs := make([]string, count) + for i := 0; i < count; i++ { + addrs[i] = random.String(15) + } + return addrs +} + +func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) { + for i := 0; i < b.N; i++ { + store.Allow(addrs[rand.Intn(max)]) + } + wg.Done() +} + +func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) { + addrs := generateAddressList(max) + wg := &sync.WaitGroup{} + for i := 0; i < parallel; i++ { + wg.Add(1) + go run(wg, store, addrs, max, b) + } + wg.Wait() +} + +const ( + testExpiresIn = 1000 * time.Millisecond +) + +func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 1000, b) +} + +func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 10000, b) +} + +func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 100000, b) +} + +func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 100, 10000, b) +}