Skip to content

Commit d75a1b8

Browse files
authored
grpc_retry backoff overflow (#747)
* grpc_retry backoff overflow * Add bounds to exponentBase2 and use it instead of math.Exp2 * Add a few tests * Add copyright to backoff_test
1 parent ed865db commit d75a1b8

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

interceptors/retry/backoff.go

+34-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package retry
55

66
import (
77
"context"
8+
"math"
89
"math/rand"
910
"time"
1011
)
@@ -24,8 +25,15 @@ func jitterUp(duration time.Duration, jitter float64) time.Duration {
2425
return time.Duration(float64(duration) * (1 + multiplier))
2526
}
2627

27-
// exponentBase2 computes 2^(a-1) where a >= 1. If a is 0, the result is 0.
28+
// exponentBase2 computes 2^(a-1) where a >= 1. If a is 0, the result is 1.
29+
// if a is greater than 62, the result is 2^62 to avoid overflowing int64
2830
func exponentBase2(a uint) uint {
31+
if a == 0 {
32+
return 1
33+
}
34+
if a > 62 {
35+
return 1 << 62
36+
}
2937
return (1 << a) >> 1
3038
}
3139

@@ -50,6 +58,31 @@ func BackoffExponential(scalar time.Duration) BackoffFunc {
5058
// BackoffExponential does, but adds jitter.
5159
func BackoffExponentialWithJitter(scalar time.Duration, jitterFraction float64) BackoffFunc {
5260
return func(ctx context.Context, attempt uint) time.Duration {
61+
exp := exponentBase2(attempt)
62+
dur := scalar * time.Duration(exp)
63+
// Check for overflow in duration multiplication
64+
if exp != 0 && dur/scalar != time.Duration(exp) {
65+
return time.Duration(math.MaxInt64)
66+
}
5367
return jitterUp(scalar*time.Duration(exponentBase2(attempt)), jitterFraction)
5468
}
5569
}
70+
71+
func BackoffExponentialWithJitterBounded(scalar time.Duration, jitterFrac float64, maxBound time.Duration) BackoffFunc {
72+
return func(ctx context.Context, attempt uint) time.Duration {
73+
exp := exponentBase2(attempt)
74+
dur := scalar * time.Duration(exp)
75+
// Check for overflow in duration multiplication
76+
if exp != 0 && dur/scalar != time.Duration(exp) {
77+
return maxBound
78+
}
79+
// Apply random jitter between -jitterFrac and +jitterFrac
80+
jitter := 1 + jitterFrac*(rand.Float64()*2-1)
81+
jitteredDuration := time.Duration(float64(dur) * jitter)
82+
// Check for overflow in jitter multiplication
83+
if float64(dur)*jitter > float64(math.MaxInt64) {
84+
return maxBound
85+
}
86+
return min(jitteredDuration, maxBound)
87+
}
88+
}

interceptors/retry/backoff_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
package retry
4+
5+
import (
6+
"context"
7+
"testing"
8+
"time"
9+
)
10+
11+
func TestBackoffExponentialWithJitter(t *testing.T) {
12+
scalar := 100 * time.Millisecond
13+
jitterFrac := 0.10
14+
backoffFunc := BackoffExponentialWithJitter(scalar, jitterFrac)
15+
// use 64 so we are past number of attempts where exponentBase2 would overflow
16+
for i := 0; i < 64; i++ {
17+
waitFor := backoffFunc(nil, uint(i))
18+
if waitFor < 0 {
19+
t.Errorf("BackoffExponentialWithJitter(%d) = %d; want >= 0", i, waitFor)
20+
}
21+
}
22+
}
23+
24+
func TestBackoffExponentialWithJitterBounded(t *testing.T) {
25+
scalar := 100 * time.Millisecond
26+
jitterFrac := 0.10
27+
maxBound := 10 * time.Second
28+
backoff := BackoffExponentialWithJitterBounded(scalar, jitterFrac, maxBound)
29+
// use 64 so we are past number of attempts where exponentBase2 would overflow
30+
for i := 0; i < 64; i++ {
31+
waitFor := backoff(context.Background(), uint(i))
32+
if waitFor > maxBound {
33+
t.Fatalf("expected dur to be less than %v, got %v for %d", maxBound, waitFor, i)
34+
}
35+
if waitFor < 0 {
36+
t.Fatalf("expected dur to be greater than 0, got %v for %d", waitFor, i)
37+
}
38+
}
39+
}

0 commit comments

Comments
 (0)