Skip to content

Commit

Permalink
feat: Add inbound rate limiting
Browse files Browse the repository at this point in the history
Closes: #59
Signed-off-by: Michael Gasch <mgasch@vmware.com>
  • Loading branch information
Michael Gasch committed Oct 11, 2021
1 parent 74757a6 commit 60db3c3
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 0 deletions.
1 change: 1 addition & 0 deletions v2/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
go.uber.org/multierr v1.1.0 // indirect
go.uber.org/zap v1.10.0
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/yaml.v2 v2.3.0 // indirect
)
2 changes: 2 additions & 0 deletions v2/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs=
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
10 changes: 10 additions & 0 deletions v2/protocol/http/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,13 @@ func WithIsRetriableFunc(isRetriable IsRetriable) Option {
return nil
}
}

func WithRateLimiter(rl RateLimiter) Option {
return func(p *Protocol) error {
if p == nil {
return fmt.Errorf("http OPTIONS handler func can not set nil protocol")
}
p.limiter = rl
return nil
}
}
18 changes: 18 additions & 0 deletions v2/protocol/http/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -86,6 +87,7 @@ type Protocol struct {
server *http.Server
handlerRegistered bool
middleware []Middleware
limiter RateLimiter

isRetriableFunc IsRetriable
}
Expand Down Expand Up @@ -115,6 +117,10 @@ func New(opts ...Option) (*Protocol, error) {
p.isRetriableFunc = defaultIsRetriableFunc
}

if p.limiter == nil {
p.limiter = noOpLimiter{}
}

return p, nil
}

Expand Down Expand Up @@ -277,6 +283,18 @@ func (p *Protocol) Respond(ctx context.Context) (binding.Message, protocol.Respo
// ServeHTTP implements http.Handler.
// Blocks until ResponseFn is invoked.
func (p *Protocol) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// always apply limiter first
ok, reset, err := p.limiter.Take(context.TODO(), req)
if err != nil {
p.incoming <- msgErr{msg: nil, err: fmt.Errorf("acquire rate limit token: %v", err)}
rw.WriteHeader(http.StatusInternalServerError)
}

if !ok {
rw.Header().Add("Retry-After", strconv.Itoa(int(reset)))
http.Error(rw, "limit exceeded", 429)
}

// Filter the GET style methods:
switch req.Method {
case http.MethodOptions:
Expand Down
34 changes: 34 additions & 0 deletions v2/protocol/http/protocol_rate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
Copyright 2021 The CloudEvents Authors
SPDX-License-Identifier: Apache-2.0
*/

package http

import (
"context"
"net/http"
)

type RateLimiter interface {
// Take attempts to take one token from the rate limiter for the specified
// request. It returns ok when this operation was successful. In case ok is
// false, reset will indicate the time in seconds when it is safe to perform
// another attempt. An error is returned when this operation failed, e.g. due to
// a backend error.
Take(ctx context.Context, r *http.Request) (ok bool, reset uint64, err error)
// Close terminates rate limiter and cleans up any data structures or
// connections that may remain open. After a store is stopped, Take() should
// always return zero values.
Close(ctx context.Context) error
}

type noOpLimiter struct{}

func (n noOpLimiter) Take(ctx context.Context, r *http.Request) (bool, uint64, error) {
return true, 0, nil
}

func (n noOpLimiter) Close(ctx context.Context) error {
return nil
}
84 changes: 84 additions & 0 deletions v2/protocol/http/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

Expand All @@ -19,6 +20,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)

func TestNew(t *testing.T) {
Expand Down Expand Up @@ -262,6 +264,88 @@ func ReceiveTest(t *testing.T, p *Protocol, ctx context.Context, rec *httptest.R
}
}

func TestServeHTTP_ReceiveWithLimiter(t *testing.T) {
testCases := map[string]struct {
limiter RateLimiter
delay time.Duration // client send
wantCodes []int // status codes
}{
// limiter disabled
"no limit, 5 requests, no delay, 200,200,200,200,200": {
limiter: nil,
delay: 0,
wantCodes: []int{200, 200, 200, 200, 200},
},
// reject all
"0rps limit, 5 requests, no delay, 429,429,429,429": {
limiter: newRateLimiterTest(0),
delay: time.Millisecond * 500,
wantCodes: []int{429, 429, 429, 429},
},
"10rps limit, 5 requests, no delay, 200,200,200,200,200": {
limiter: newRateLimiterTest(10),
delay: 0,
wantCodes: []int{200, 200, 200, 200, 200},
},
"1rps limit, 5 requests, 100ms delay, 200,429,429,429,429": {
limiter: newRateLimiterTest(1),
delay: time.Millisecond * 100,
wantCodes: []int{200, 429, 429, 429, 429},
},
"2rps limit, 4 requests, 0.5s delay, 200,200,200,200": {
limiter: newRateLimiterTest(2),
delay: time.Millisecond * 500,
wantCodes: []int{200, 200, 200, 200},
},
}

for n, tc := range testCases {
t.Run(n, func(t *testing.T) {
p, err := New(WithRateLimiter(tc.limiter))
require.NoError(t, err, "create protocol")

for i := range tc.wantCodes {
time.Sleep(tc.delay)

rw := httptest.NewRecorder()
req := httptest.NewRequest("POST", "http://unittest", nil)

go p.ServeHTTP(rw, req)
_, _ = p.Receive(context.Background())
res := rw.Result()
require.Equal(t, tc.wantCodes[i], res.StatusCode)

if res.StatusCode == 429 {
require.Equal(t, res.Header.Get("Retry-After"), strconv.Itoa(2))
}
}
})
}
}

type rateLimiterTest struct {
limiter *rate.Limiter
}

func newRateLimiterTest(rps float64) RateLimiter {
rl := rateLimiterTest{
limiter: rate.NewLimiter(rate.Limit(rps), int(rps)),
}

return &rl
}

func (rl *rateLimiterTest) Take(_ context.Context, _ *http.Request) (bool, uint64, error) {
if !rl.limiter.Allow() {
return false, 2, nil
}
return true, 0, nil
}

func (rl *rateLimiterTest) Close(_ context.Context) error {
return nil
}

type roundTripperTest struct {
statusCodes []int
requestCount int
Expand Down

0 comments on commit 60db3c3

Please sign in to comment.