Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http: Replace ConcurrencyLimiter with IDLocker #2925

Merged
merged 10 commits into from
Apr 5, 2023
6 changes: 1 addition & 5 deletions app/inithttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,7 @@ func (app *App) initHTTP(ctx context.Context) error {
// add auth info to request logs
logRequestAuth,

conReqLimit{
perIntKey: 1,
perService: 2,
perUser: 3,
}.Middleware,
LimitConcurrencyByAuthSource,

wrapGzip,
}
Expand Down
42 changes: 42 additions & 0 deletions app/limitconcurrencybyauthsource.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package app

import (
"net/http"
"time"

"github.com/target/goalert/ctxlock"
"github.com/target/goalert/permission"
"github.com/target/goalert/util/errutil"
)

// LimitConcurrencyByAuthSource limits the number of concurrent requests
// per auth source. MaxHeld is 1, so only one request can be processed at a
// time per source (e.g., session key, integration key, etc).
//
// Note: This is per source/ID combo, so only multiple requests via the SAME
// integration key would get queued. Separate keys go in separate buckets.
func LimitConcurrencyByAuthSource(next http.Handler) http.Handler {
limit := ctxlock.NewIDLocker[permission.SourceInfo](ctxlock.Config{
MaxHeld: 1,
MaxWait: 100,
Timeout: 20 * time.Second,
})

return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()

src := permission.Source(ctx)
if src == nil {
// Any unknown source gets put into a single bucket.
src = &permission.SourceInfo{}
}

err := limit.Lock(ctx, *src)
if errutil.HTTPError(ctx, w, err) {
return
}
defer limit.Unlock(*src)

next.ServeHTTP(w, req)
})
}
68 changes: 0 additions & 68 deletions app/middlewarereqlimit.go

This file was deleted.

7 changes: 7 additions & 0 deletions auth/basic/identityprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/pkg/errors"
"github.com/target/goalert/auth"
"github.com/target/goalert/config"
"github.com/target/goalert/util/errutil"
"github.com/target/goalert/util/log"
"github.com/target/goalert/validation/validate"
)
Expand Down Expand Up @@ -45,6 +46,12 @@ func (p *Provider) ExtractIdentity(route *auth.RouteInfo, w http.ResponseWriter,
}
ctx = log.WithField(ctx, "username", username)

err = p.lim.Lock(ctx, username)
if errutil.HTTPError(ctx, w, err) {
return nil, err
}
defer p.lim.Unlock(username)

_, err = p.b.Validate(ctx, username, password)
if err != nil {
log.Debug(ctx, errors.Wrap(err, "basic login"))
Expand Down
9 changes: 8 additions & 1 deletion auth/basic/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@ package basic

import (
"context"

"github.com/target/goalert/ctxlock"
)

// Provider implements the auth.IdentityProvider interface.
type Provider struct {
b *Store

lim *ctxlock.IDLocker[string]
}

// NewProvider creates a new Provider with the associated config.
func NewProvider(ctx context.Context, store *Store) (*Provider, error) {
return &Provider{b: store}, nil
return &Provider{
b: store,
lim: ctxlock.NewIDLocker[string](ctxlock.Config{MaxHeld: 1}),
}, nil
}
63 changes: 63 additions & 0 deletions ctxlock/idlocker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ctxlock_test

import (
"context"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -212,3 +213,65 @@ func TestIDLocker_Unlock_Abandoned(t *testing.T) {
l.Unlock("foo") // original lock
assert.Panics(t, func() { l.Unlock("foo") }, "unlocking an empty queue should panic")
}

func BenchmarkIDLocker_Sequential(b *testing.B) {
l := ctxlock.NewIDLocker[struct{}](ctxlock.Config{})
ctx := context.Background()
for i := 0; i < b.N; i++ {
err := l.Lock(ctx, struct{}{})
if err != nil {
b.Fatal(err)
}
l.Unlock(struct{}{})
}
}

func BenchmarkIDLocker_Sequential_Cardinality(b *testing.B) {
l := ctxlock.NewIDLocker[int64](ctxlock.Config{})
ctx := context.Background()
var n int64
for i := 0; i < b.N; i++ {
err := l.Lock(ctx, n)
if err != nil {
b.Fatal(err)
}
n++
if n > 100 {
l.Unlock(n - 100)
}
}
}

func BenchmarkIDLocker_Concurrent(b *testing.B) {
l := ctxlock.NewIDLocker[struct{}](ctxlock.Config{MaxWait: -1})
ctx := context.Background()

b.SetParallelism(1000)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := l.Lock(ctx, struct{}{})
require.NoError(b, err)
l.Unlock(struct{}{})
}
})
}

func BenchmarkIDLocker_Concurrent_Cardinality(b *testing.B) {
l := ctxlock.NewIDLocker[int64](ctxlock.Config{MaxWait: 1})
ctx := context.Background()

b.SetParallelism(1000)
var n int64
b.RunParallel(func(pb *testing.PB) {
id := atomic.AddInt64(&n, 1)
ch := make(chan error, 1)
for pb.Next() {
err := l.Lock(ctx, id)
require.NoError(b, err)
go func() { ch <- l.Lock(ctx, id) }()
l.Unlock(id)
require.NoError(b, <-ch)
l.Unlock(id)
}
})
}
61 changes: 33 additions & 28 deletions util/errutil/httperror.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@ package errutil
import (
"context"
"database/sql"
"errors"
"net/http"

"github.com/pkg/errors"
"github.com/target/goalert/ctxlock"
"github.com/target/goalert/permission"
"github.com/target/goalert/util/log"
"github.com/target/goalert/util/sqlutil"
"github.com/target/goalert/validation"
)

func isCtxCause(err error) bool {
func isCancel(err error) bool {
if errors.Is(err, context.Canceled) {
return true
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
if errors.Is(err, sql.ErrTxDone) {
return true
}
Expand Down Expand Up @@ -51,35 +49,42 @@ func HTTPError(ctx context.Context, w http.ResponseWriter, err error) bool {
}

err = MapDBError(err)
if permission.IsUnauthorized(err) {
log.Debug(ctx, err)
switch {
case errors.Is(err, ctxlock.ErrQueueFull), errors.Is(err, ctxlock.ErrTimeout):
// Either the queue is full or the lock timed out. Either way
// we are waiting on concurrent requests for this source, so
// send them back with a 429 because we are rate limiting them
// due to being at/beyond capacity.
//
// Because of the way the lock works, we can guarantee that
// we will process one request at a time (per source), but we
// may have to wait for a previous request to finish before we
// can start processing the next one.
//
// This means only concurrent requests (per process) have the
// possibility to be rate limited, and not sequential requests,
// even in the worst case scenario.
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
case isCancel(err):
// Client disconnected, send 400 back so logs reflect that this
// was a client-side problem.
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case permission.IsUnauthorized(err):
http.Error(w, unwrapAll(err).Error(), http.StatusUnauthorized)
return true
}
if permission.IsPermissionError(err) {
log.Debug(ctx, err)
case permission.IsPermissionError(err):
http.Error(w, unwrapAll(err).Error(), http.StatusForbidden)
return true
}
if validation.IsClientError(err) {
log.Debug(ctx, err)
case validation.IsClientError(err):
http.Error(w, unwrapAll(err).Error(), http.StatusBadRequest)
return true
}
if IsLimitError(err) {
log.Debug(ctx, err)
case IsLimitError(err):
http.Error(w, unwrapAll(err).Error(), http.StatusConflict)
return true
}

if ctx.Err() != nil && isCtxCause(err) {
// context timed out or was canceled
log.Debug(ctx, err)
case errors.Is(err, context.DeadlineExceeded):
// Timeout
http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout)
return true
default:
// For all other unexpected errors, log the error and send a 500.
log.Log(ctx, err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

log.Log(ctx, err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return true
}