Skip to content

Commit

Permalink
add necessary plumbing to implement per server ip based rate limiting (
Browse files Browse the repository at this point in the history
  • Loading branch information
dhiaayachi authored May 23, 2023
1 parent 304d641 commit f526dfd
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 81 deletions.
45 changes: 26 additions & 19 deletions agent/consul/rate/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
"github.com/hashicorp/consul/agent/metadata"
"net"
"reflect"
"sync/atomic"
Expand Down Expand Up @@ -153,14 +154,14 @@ type RequestLimitsHandler interface {
Allow(op Operation) error
UpdateConfig(cfg HandlerConfig)
UpdateIPConfig(cfg IPLimitConfig)
Register(leaderStatusProvider LeaderStatusProvider)
Register(serversStatusProvider ServersStatusProvider)
}

// Handler enforces rate limits for incoming RPCs.
type Handler struct {
globalCfg *atomic.Pointer[HandlerConfig]
ipCfg *atomic.Pointer[IPLimitConfig]
leaderStatusProvider LeaderStatusProvider
globalCfg *atomic.Pointer[HandlerConfig]
ipCfg *atomic.Pointer[IPLimitConfig]
serversStatusProvider ServersStatusProvider

limiter multilimiter.RateLimiter

Expand All @@ -186,13 +187,14 @@ type HandlerConfig struct {
GlobalLimitConfig GlobalLimitConfig
}

//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go
type LeaderStatusProvider interface {
//go:generate mockery --name ServersStatusProvider --inpackage --filename mock_ServersStatusProvider_test.go
type ServersStatusProvider interface {
// IsLeader is used to determine whether the operation is being performed
// against the cluster leader, such that if it can _only_ be performed by
// the leader (e.g. write operations) we don't tell clients to retry against
// a different server.
IsLeader() bool
IsServer(addr string) bool
}

func isInfRate(cfg multilimiter.LimiterConfig) bool {
Expand Down Expand Up @@ -237,19 +239,19 @@ func (h *Handler) Run(ctx context.Context) {
// because of an exhausted rate-limit.
func (h *Handler) Allow(op Operation) error {

if h.leaderStatusProvider == nil {
h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter")
if h.serversStatusProvider == nil {
h.logger.Error("serversStatusProvider required to be set via Register(). bailing on rate limiter")
return nil
// TODO: panic and make sure to use the server's recovery handler
// panic("leaderStatusProvider required to be set via Register(..)")
// panic("serversStatusProvider required to be set via Register(..)")
}

cfg := h.globalCfg.Load()
if cfg.GlobalLimitConfig.Mode == ModeDisabled {
return nil
}

allow, throttledLimits := h.allowAllLimits(h.limits(op))
allow, throttledLimits := h.allowAllLimits(h.limits(op), h.serversStatusProvider.IsServer(string(metadata.GetIP(op.SourceAddr))))

if !allow {
for _, l := range throttledLimits {
Expand Down Expand Up @@ -277,7 +279,7 @@ func (h *Handler) Allow(op Operation) error {
})

if enforced {
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
if h.serversStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
return ErrRetryLater
}
return ErrRetryElsewhere
Expand Down Expand Up @@ -305,17 +307,18 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) {

}

func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) {
h.leaderStatusProvider = leaderStatusProvider
func (h *Handler) Register(serversStatusProvider ServersStatusProvider) {
h.serversStatusProvider = serversStatusProvider
}

type limit struct {
mode Mode
ent multilimiter.LimitedEntity
desc string
mode Mode
ent multilimiter.LimitedEntity
desc string
applyOnServer bool
}

func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
func (h *Handler) allowAllLimits(limits []limit, isServer bool) (bool, []limit) {
allow := true
throttledLimits := make([]limit, 0)

Expand All @@ -324,6 +327,10 @@ func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
continue
}

if isServer && !l.applyOnServer {
continue
}

if !h.limiter.Allow(l.ent) {
throttledLimits = append(throttledLimits, l)
allow = false
Expand Down Expand Up @@ -358,7 +365,7 @@ func (h *Handler) globalLimit(op Operation) *limit {
}
cfg := h.globalCfg.Load()

lim := &limit{mode: cfg.GlobalLimitConfig.Mode}
lim := &limit{mode: cfg.GlobalLimitConfig.Mode, applyOnServer: true}
switch op.Type {
case OperationTypeRead:
lim.desc = "global/read"
Expand Down Expand Up @@ -409,4 +416,4 @@ func (nullRequestLimitsHandler) Run(_ context.Context) {}

func (nullRequestLimitsHandler) UpdateConfig(_ HandlerConfig) {}

func (nullRequestLimitsHandler) Register(_ LeaderStatusProvider) {}
func (nullRequestLimitsHandler) Register(_ ServersStatusProvider) {}
27 changes: 7 additions & 20 deletions agent/consul/rate/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,6 @@ import (
"github.com/hashicorp/consul/agent/consul/multilimiter"
)

//
// Revisit test when handler.go:189 TODO implemented
//
// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) {
// defer func() {
// err := recover()
// if err == nil {
// t.Fatal("Run should panic")
// }
// }()

// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger())
// handler.Allow(Operation{})
// // intentionally skip handler.Register(...)
// }

func TestHandler(t *testing.T) {
var (
rpcName = "Foo.Bar"
Expand All @@ -50,6 +34,7 @@ func TestHandler(t *testing.T) {
globalMode Mode
checks []limitCheck
isLeader bool
isServer bool
expectErr error
expectLog bool
expectMetric bool
Expand Down Expand Up @@ -230,8 +215,9 @@ func TestHandler(t *testing.T) {
limiter.On("Allow", mock.Anything).Return(c.allow)
}

leaderStatusProvider := NewMockLeaderStatusProvider(t)
leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
serversStatusProvider := NewMockServersStatusProvider(t)
serversStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
serversStatusProvider.On("IsServer", mock.Anything).Return(tc.isServer).Maybe()

var output bytes.Buffer
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
Expand All @@ -252,7 +238,7 @@ func TestHandler(t *testing.T) {
limiter,
logger,
)
handler.Register(leaderStatusProvider)
handler.Register(serversStatusProvider)

require.Equal(t, tc.expectErr, handler.Allow(tc.op))

Expand Down Expand Up @@ -426,8 +412,9 @@ func TestAllow(t *testing.T) {
}
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
logger := hclog.NewNullLogger()
delegate := NewMockLeaderStatusProvider(t)
delegate := NewMockServersStatusProvider(t)
delegate.On("IsLeader").Return(true).Maybe()
delegate.On("IsServer", mock.Anything).Return(false).Maybe()
handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger)
handler.Register(delegate)
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
Expand Down
39 changes: 0 additions & 39 deletions agent/consul/rate/mock_LeaderStatusProvider_test.go

This file was deleted.

6 changes: 3 additions & 3 deletions agent/consul/rate/mock_RequestLimitsHandler.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions agent/consul/rate/mock_ServersStatusProvider_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions agent/consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,20 @@ func (s *Server) IsLeader() bool {
return s.raft.State() == raft.Leader
}

// IsServer checks if this addr is of a server
func (s *Server) IsServer(addr string) bool {
for _, s := range s.raft.GetConfiguration().Configuration().Servers {
a, err := net.ResolveTCPAddr("tcp", string(s.Address))
if err != nil {
continue
}
if string(metadata.GetIP(a)) == addr {
return true
}
}
return false
}

// LeaderLastContact returns the time of last contact by a leader.
// This only makes sense if we are currently a follower.
func (s *Server) LeaderLastContact() time.Time {
Expand Down
10 changes: 10 additions & 0 deletions agent/metadata/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,13 @@ func AddFeatureFlags(tags map[string]string, flags ...string) {
tags[featureFlagPrefix+flag] = "1"
}
}

func GetIP(addr net.Addr) []byte {
switch a := addr.(type) {
case *net.UDPAddr:
return []byte(a.IP.String())
case *net.TCPAddr:
return []byte(a.IP.String())
}
return []byte{}
}

0 comments on commit f526dfd

Please sign in to comment.