Skip to content

Commit

Permalink
Add custom backoff strategy option (#302)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Harding <azdagron@gmail.com>
  • Loading branch information
azdagron authored Oct 4, 2024
1 parent 51299b0 commit bfb7d34
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 44 deletions.
47 changes: 34 additions & 13 deletions v2/workloadapi/backoff.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,51 @@ import (
"time"
)

// backoff defines an linear backoff policy.
type backoff struct {
InitialDelay time.Duration
MaxDelay time.Duration
// BackoffStrategy provides backoff facilities.
type BackoffStrategy interface {
// NewBackoff returns a new backoff for the strategy. The returned
// Backoff is in the same state that it would be in after a call to
// Reset().
NewBackoff() Backoff
}

// Backoff provides backoff for a workload API operation.
type Backoff interface {
// Next returns the next backoff period.
Next() time.Duration

// Reset() resets the backoff.
Reset()
}

type defaultBackoffStrategy struct{}

func (defaultBackoffStrategy) NewBackoff() Backoff {
return newLinearBackoff()
}

// linearBackoff defines an linear backoff policy.
type linearBackoff struct {
initialDelay time.Duration
maxDelay time.Duration
n int
}

func newBackoff() *backoff {
return &backoff{
InitialDelay: time.Second,
MaxDelay: 30 * time.Second,
func newLinearBackoff() *linearBackoff {
return &linearBackoff{
initialDelay: time.Second,
maxDelay: 30 * time.Second,
n: 0,
}
}

// Duration returns the next wait period for the backoff. Not goroutine-safe.
func (b *backoff) Duration() time.Duration {
func (b *linearBackoff) Next() time.Duration {
backoff := float64(b.n) + 1
d := math.Min(b.InitialDelay.Seconds()*backoff, b.MaxDelay.Seconds())
d := math.Min(b.initialDelay.Seconds()*backoff, b.maxDelay.Seconds())
b.n++
return time.Duration(d) * time.Second
}

// Reset resets the backoff's state.
func (b *backoff) Reset() {
func (b *linearBackoff) Reset() {
b.n = 0
}
23 changes: 8 additions & 15 deletions v2/workloadapi/backoff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,27 @@ import (
"github.com/stretchr/testify/require"
)

func TestBackoff(t *testing.T) {
new := func() *backoff { //nolint:all
b := newBackoff()
b.InitialDelay = time.Second
b.MaxDelay = 30 * time.Second
return b
}

testUntilMax := func(t *testing.T, b *backoff) {
func TestLinearBackoff(t *testing.T) {
testUntilMax := func(t *testing.T, b *linearBackoff) {
for i := 1; i < 30; i++ {
require.Equal(t, time.Duration(i)*time.Second, b.Duration())
require.Equal(t, time.Duration(i)*time.Second, b.Next())
}
require.Equal(t, 30*time.Second, b.Duration())
require.Equal(t, 30*time.Second, b.Duration())
require.Equal(t, 30*time.Second, b.Duration())
require.Equal(t, 30*time.Second, b.Next())
require.Equal(t, 30*time.Second, b.Next())
require.Equal(t, 30*time.Second, b.Next())
}

t.Run("test max", func(t *testing.T) {
t.Parallel()

b := new()
b := newLinearBackoff()
testUntilMax(t, b)
})

t.Run("test reset", func(t *testing.T) {
t.Parallel()

b := new()
b := newLinearBackoff()
testUntilMax(t, b)

b.Reset()
Expand Down
19 changes: 10 additions & 9 deletions v2/workloadapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (c *Client) FetchX509Bundles(ctx context.Context) (*x509bundle.Set, error)
// WatchX509Bundles watches for changes to the X.509 bundles. The watcher receives
// the updated X.509 bundles.
func (c *Client) WatchX509Bundles(ctx context.Context, watcher X509BundleWatcher) error {
backoff := newBackoff()
backoff := c.config.backoffStrategy.NewBackoff()
for {
err := c.watchX509Bundles(ctx, watcher, backoff)
watcher.OnX509BundlesWatchError(err)
Expand Down Expand Up @@ -152,7 +152,7 @@ func (c *Client) FetchX509Context(ctx context.Context) (*X509Context, error) {
// WatchX509Context watches for updates to the X.509 context. The watcher
// receives the updated X.509 context.
func (c *Client) WatchX509Context(ctx context.Context, watcher X509ContextWatcher) error {
backoff := newBackoff()
backoff := c.config.backoffStrategy.NewBackoff()
for {
err := c.watchX509Context(ctx, watcher, backoff)
watcher.OnX509ContextWatchError(err)
Expand Down Expand Up @@ -224,7 +224,7 @@ func (c *Client) FetchJWTBundles(ctx context.Context) (*jwtbundle.Set, error) {
// WatchJWTBundles watches for changes to the JWT bundles. The watcher receives
// the updated JWT bundles.
func (c *Client) WatchJWTBundles(ctx context.Context, watcher JWTBundleWatcher) error {
backoff := newBackoff()
backoff := c.config.backoffStrategy.NewBackoff()
for {
err := c.watchJWTBundles(ctx, watcher, backoff)
watcher.OnJWTBundlesWatchError(err)
Expand Down Expand Up @@ -258,7 +258,7 @@ func (c *Client) newConn(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, c.config.address, c.config.dialOptions...) //nolint:staticcheck // preserve backcompat with WithDialOptions option
}

func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backoff) error {
func (c *Client) handleWatchError(ctx context.Context, err error, backoff Backoff) error {
code := status.Code(err)
if code == codes.Canceled {
return err
Expand All @@ -270,7 +270,7 @@ func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backo
}

c.config.log.Errorf("Failed to watch the Workload API: %v", err)
retryAfter := backoff.Duration()
retryAfter := backoff.Next()
c.config.log.Debugf("Retrying watch in %s", retryAfter)
select {
case <-time.After(retryAfter):
Expand All @@ -281,7 +281,7 @@ func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backo
}
}

func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatcher, backoff *backoff) error {
func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatcher, backoff Backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

Expand All @@ -308,7 +308,7 @@ func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatche
}
}

func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher, backoff *backoff) error {
func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher, backoff Backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

Expand All @@ -335,7 +335,7 @@ func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher,
}
}

func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff *backoff) error {
func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff Backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

Expand Down Expand Up @@ -402,7 +402,8 @@ func withHeader(ctx context.Context) context.Context {

func defaultClientConfig() clientConfig {
return clientConfig{
log: logger.Null,
log: logger.Null,
backoffStrategy: defaultBackoffStrategy{},
}
}

Expand Down
48 changes: 45 additions & 3 deletions v2/workloadapi/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/x509"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -103,7 +104,10 @@ func TestFetchX509Bundles(t *testing.T) {
func TestWatchX509Bundles(t *testing.T) {
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, err := New(context.Background(), WithAddr(wl.Addr()))

backoffStrategy := &testBackoffStrategy{}

c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy))
require.NoError(t, err)
defer c.Close()

Expand Down Expand Up @@ -149,6 +153,9 @@ func TestWatchX509Bundles(t *testing.T) {
wl.Stop()
tw.WaitForUpdates(1)
assert.Len(t, tw.Errors(), 2)

// Assert that there was the expected number of backoffs.
assert.Equal(t, 2, backoffStrategy.BackedOff())
}

func TestFetchX509Context(t *testing.T) {
Expand Down Expand Up @@ -213,7 +220,10 @@ func TestWatchX509Context(t *testing.T) {
federatedCA := test.NewCA(t, federatedTD)
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, err := New(context.Background(), WithAddr(wl.Addr()))

backoffStrategy := &testBackoffStrategy{}

c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy))
require.NoError(t, err)
defer c.Close()

Expand Down Expand Up @@ -291,6 +301,9 @@ func TestWatchX509Context(t *testing.T) {

cancel()
wg.Wait()

// Assert that there was the expected number of backoffs.
assert.Equal(t, 2, backoffStrategy.BackedOff())
}

func TestFetchJWTSVID(t *testing.T) {
Expand Down Expand Up @@ -375,7 +388,10 @@ func TestFetchJWTBundles(t *testing.T) {
func TestWatchJWTBundles(t *testing.T) {
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, err := New(context.Background(), WithAddr(wl.Addr()))

backoffStrategy := &testBackoffStrategy{}

c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy))
require.NoError(t, err)
defer c.Close()

Expand Down Expand Up @@ -421,6 +437,9 @@ func TestWatchJWTBundles(t *testing.T) {
wl.Stop()
tw.WaitForUpdates(1)
assert.Len(t, tw.Errors(), 2)

// Assert that there was the expected number of backoffs.
assert.Equal(t, 2, backoffStrategy.BackedOff())
}

func TestValidateJWTSVID(t *testing.T) {
Expand Down Expand Up @@ -605,3 +624,26 @@ func (w *testWatcher) WaitForUpdates(expectedNumUpdates int) {
}
}
}

type testBackoffStrategy struct {
backedOff int32
}

func (s *testBackoffStrategy) NewBackoff() Backoff {
return testBackoff{backedOff: &s.backedOff}
}

func (s *testBackoffStrategy) BackedOff() int {
return int(atomic.LoadInt32(&s.backedOff))
}

type testBackoff struct {
backedOff *int32
}

func (b testBackoff) Next() time.Duration {
atomic.AddInt32(b.backedOff, 1)
return time.Millisecond * 200
}

func (b testBackoff) Reset() {}
17 changes: 13 additions & 4 deletions v2/workloadapi/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ func WithLogger(logger logger.Logger) ClientOption {
})
}

// WithBackoff provides a custom backoff strategy that replaces the
// default backoff strategy (linear backoff).
func WithBackoffStrategy(backoffStrategy BackoffStrategy) ClientOption {
return clientOption(func(c *clientConfig) {
c.backoffStrategy = backoffStrategy
})
}

// SourceOption are options that are shared among all option types.
type SourceOption interface {
configureX509Source(*x509SourceConfig)
Expand Down Expand Up @@ -81,10 +89,11 @@ type BundleSourceOption interface {
}

type clientConfig struct {
address string
namedPipeName string
dialOptions []grpc.DialOption
log logger.Logger
address string
namedPipeName string
dialOptions []grpc.DialOption
log logger.Logger
backoffStrategy BackoffStrategy
}

type clientOption func(*clientConfig)
Expand Down

0 comments on commit bfb7d34

Please sign in to comment.