diff --git a/rand/rand.go b/rand/rand.go new file mode 100644 index 000000000..4dc176d7d --- /dev/null +++ b/rand/rand.go @@ -0,0 +1,31 @@ +package rand + +import ( + "crypto/rand" + "fmt" + "io" + "math/big" +) + +func init() { + Reader = rand.Reader +} + +// Reader provides a random reader that can reset during testing. +var Reader io.Reader + +// Int63n returns a int64 between zero and value of max, read from an io.Reader source. +func Int63n(reader io.Reader, max int64) (int64, error) { + bi, err := rand.Int(reader, big.NewInt(max)) + if err != nil { + return 0, fmt.Errorf("failed to read random value, %w", err) + } + + return bi.Int64(), nil +} + +// CryptoRandInt63n returns a random int64 between zero and value of max +//obtained from the crypto rand source. +func CryptoRandInt63n(max int64) (int64, error) { + return Int63n(Reader, max) +} diff --git a/time/time.go b/time/time.go index 277adb136..0a3fbe48c 100644 --- a/time/time.go +++ b/time/time.go @@ -71,11 +71,3 @@ func SleepWithContext(ctx context.Context, dur time.Duration) error { return nil } - -// DurationMin compares two time.Duration input values and returns the minimum time.Duration value -func DurationMin(a, b time.Duration) time.Duration { - if a < b { - return a - } - return b -} diff --git a/waiter/logger.go b/waiter/logger.go new file mode 100644 index 000000000..4853ace01 --- /dev/null +++ b/waiter/logger.go @@ -0,0 +1,35 @@ +package waiter + +import ( + "context" + "fmt" + "github.com/awslabs/smithy-go/logging" + "github.com/awslabs/smithy-go/middleware" +) + +// Logger is the Logger middleware used by the waiter to log an attempt +type Logger struct { + // Attempt is the current attempt to be logged + Attempt int64 +} + +// ID representing the Logger middleware +func (*Logger) ID() string { + return "WaiterLogger" +} + +// HandleInitialize performs handling of request in initialize stack step +func (m *Logger) HandleInitialize(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, +) { + logger := middleware.GetLogger(ctx) + + logger.Logf(logging.Debug, fmt.Sprintf("attempting waiter request, attempt count: %d", m.Attempt)) + + return next.HandleInitialize(ctx, in) +} + +// AddLogger is helper util to add waiter logger after `SetLogger` middleware in +func (m Logger) AddLogger(stack *middleware.Stack) error { + return stack.Initialize.Insert(&m, "SetLogger", middleware.After) +} diff --git a/waiter/waiter.go b/waiter/waiter.go index a61bb1d2b..05b71cd30 100644 --- a/waiter/waiter.go +++ b/waiter/waiter.go @@ -2,52 +2,38 @@ package waiter import ( "fmt" + "github.com/awslabs/smithy-go/rand" "math" - "math/rand" "time" - - smithytime "github.com/awslabs/smithy-go/time" ) // ComputeDelay computes delay between waiter attempts. The function takes in a current attempt count, -// minimum delay, maximum delay, and remaining wait time for waiter as input. +// minimum delay, maximum delay, and remaining wait time for waiter as input. The inputs minDelay and maxDelay +// must always be greater than 0, along with minDelay lesser than or equal to maxDelay. // // Returns the computed delay and if next attempt count is possible within the given input time constraints. // Note that the zeroth attempt results in no delay. -func ComputeDelay(attempt int64, minDelay, maxDelay, remainingTime time.Duration) (delay time.Duration, done bool, err error) { - // validation - if minDelay > maxDelay { - return 0, true, fmt.Errorf("maximum delay must be greater than minimum delay") - } - +func ComputeDelay(attempt int64, minDelay, maxDelay, remainingTime time.Duration) (delay time.Duration, err error) { // zeroth attempt, no delay if attempt <= 0 { - return 0, true, nil + return 0, nil } // remainingTime is zero or less, no delay if remainingTime <= 0 { - return 0, true, nil + return 0, nil } - // as we use log, ensure min delay and maxdelay are atleast 1 ns - if minDelay < 1 { - minDelay = 1 + // validate min delay is greater than 0 + if minDelay == 0 { + return 0, fmt.Errorf("minDelay must be greater than zero when computing Delay") } - // if max delay is less than 1 ns, return 0 as delay - if maxDelay < 1 { - return 0, true, nil + // validate max delay is greater than 0 + if maxDelay == 0 { + return 0, fmt.Errorf("maxDelay must be greater than zero when computing Delay") } - // check if this is the last attempt possible and compute delay accordingly - defer func() { - if remainingTime-delay <= minDelay { - delay = remainingTime - minDelay - done = true - } - }() - // Get attempt ceiling to prevent integer overflow. attemptCeiling := (math.Log(float64(maxDelay/minDelay)) / math.Log(2)) + 1 @@ -55,16 +41,26 @@ func ComputeDelay(attempt int64, minDelay, maxDelay, remainingTime time.Duration delay = maxDelay } else { // Compute exponential delay based on attempt. - // [0.0, 1.0) * 2 ^ attempt-1 ri := 1 << uint64(attempt-1) // compute delay - delay = smithytime.DurationMin(maxDelay, minDelay*time.Duration(ri)) + delay = minDelay * time.Duration(ri) } if delay != minDelay { // randomize to get jitter between min delay and delay value - delay = time.Duration(rand.Int63n(int64(delay-minDelay))) + minDelay + // [0.0, 1.0) * [minDelay, delay] + d, err := rand.CryptoRandInt63n(int64(delay - minDelay)) + if err != nil { + return 0, fmt.Errorf("error computing retry jitter, %w", err) + } + + delay = time.Duration(d) + minDelay + } + + // check if this is the last attempt possible and compute delay accordingly + if remainingTime-delay <= minDelay { + delay = remainingTime - minDelay } - return delay, done, nil + return delay, nil } diff --git a/waiter/waiter_test.go b/waiter/waiter_test.go index ed35b8f2c..949c50dfd 100644 --- a/waiter/waiter_test.go +++ b/waiter/waiter_test.go @@ -25,19 +25,18 @@ func TestComputeDelay(t *testing.T) { expectedMinAttempts: 8, }, "zero minDelay": { - totalAttempts: 3, - minDelay: 0, - maxDelay: 120 * time.Second, - maxWaitTime: 300 * time.Second, - expectedMaxDelays: []time.Duration{1, 1, 1}, - expectedMinAttempts: 3, + totalAttempts: 3, + minDelay: 0, + maxDelay: 120 * time.Second, + maxWaitTime: 300 * time.Second, + expectedError: "minDelay must be greater than zero", }, "zero maxDelay": { totalAttempts: 3, minDelay: 10 * time.Second, maxDelay: 0, maxWaitTime: 300 * time.Second, - expectedError: "maximum delay must be greater than minimum delay", + expectedError: "maxDelay must be greater than zero", }, "zero remaining time": { totalAttempts: 3, @@ -47,6 +46,14 @@ func TestComputeDelay(t *testing.T) { expectedMaxDelays: []time.Duration{0}, expectedMinAttempts: 1, }, + "max wait time is less than min delay": { + totalAttempts: 3, + minDelay: 10 * time.Second, + maxDelay: 20 * time.Second, + maxWaitTime: 5 * time.Second, + expectedMaxDelays: []time.Duration{0}, + expectedMinAttempts: 1, + }, "large minDelay": { totalAttempts: 80, minDelay: 150 * time.Minute, @@ -88,6 +95,10 @@ func TestComputeDelay(t *testing.T) { if e, a := expectedDelay*time.Second, delays[i]; e < a { t.Fatalf("attempt %d : expected delay to be less than %v, got %v", i+1, e, a) } + + if e, a := c.minDelay, delays[i]; e > a && c.maxWaitTime > c.minDelay { + t.Fatalf("attempt %d : expected delay to be more than %v, got %v", i+1, e, a) + } } }) } @@ -105,17 +116,17 @@ func mockwait(maxAttempts int64, minDelay, maxDelay, maxWaitTime time.Duration) break } - delay, done, err := ComputeDelay(attempt, minDelay, maxDelay, remainingTime) + delay, err := ComputeDelay(attempt, minDelay, maxDelay, remainingTime) if err != nil { return delays, err } delays = append(delays, delay) - if done { - break - } remainingTime -= delay + if remainingTime < minDelay { + break + } } return delays, nil