Skip to content

Commit

Permalink
Adds retry shortcutting for V2
Browse files Browse the repository at this point in the history
  • Loading branch information
gdavison committed Sep 16, 2021
1 parent e55054b commit 1fc56c5
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 6 deletions.
41 changes: 41 additions & 0 deletions aws_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@ package awsbase
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/iam"
Expand All @@ -26,9 +31,21 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) {
return aws.Config{}, err
}

var retryer aws.Retryer
retryer = retry.NewStandard()
if c.MaxRetries != 0 {
retryer = retry.AddWithMaxAttempts(retryer, c.MaxRetries)
}
retryer = &networkErrorShortcutter{
Retryer: retryer,
}

loadOptions := append(
commonLoadOptions(c),
config.WithCredentialsProvider(credentialsProvider),
config.WithRetryer(func() aws.Retryer {
return retryer
}),
)
cfg, err := config.LoadDefaultConfig(ctx, loadOptions...)
if err != nil {
Expand All @@ -44,6 +61,30 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) {
return cfg, nil
}

// networkErrorShortcutter decorates an aws.Retryer so that retrying can be cut
// short when a request keeps failing with specific networking errors.
type networkErrorShortcutter struct {
	aws.Retryer
}

// RetryDelay stops retrying early for persistent networking failures.
//
// RetryDelay is deliberately misused as the hook for this check because it is
// the only function that takes the attempt count. When attempt has reached
// constants.MaxNetworkRetryCount and err wraps a *net.OpError whose message
// contains "no such host" or "connection refused", a *retry.MaxAttemptsError
// wrapping err is returned instead of a delay. In every other case the call is
// forwarded unchanged to the embedded Retryer.
func (r *networkErrorShortcutter) RetryDelay(attempt int, err error) (time.Duration, error) {
	if attempt < constants.MaxNetworkRetryCount {
		return r.Retryer.RetryDelay(attempt, err)
	}

	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		return r.Retryer.RetryDelay(attempt, err)
	}

	// It's disappointing that string matching is needed here, rather than being
	// able to use `errors.Is()` or even strings exported by the Go `net` package.
	msg := opErr.Error()
	if !strings.Contains(msg, "no such host") && !strings.Contains(msg, "connection refused") {
		return r.Retryer.RetryDelay(attempt, err)
	}

	log.Printf("[WARN] Disabling retries after next request due to networking issue: %s", err)
	return 0, &retry.MaxAttemptsError{
		Attempt: attempt,
		Err:     err,
	}
}

func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, skipCredsValidation, skipRequestingAccountId bool) (string, string, error) {
if !skipCredsValidation {
stsClient := sts.NewFromConfig(awsConfig)
Expand Down
234 changes: 234 additions & 0 deletions aws_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import (
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"reflect"
"runtime"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go/middleware"
Expand Down Expand Up @@ -1223,3 +1226,234 @@ func TestGetAwsConfigWithAccountIDAndPartition(t *testing.T) {
})
}
}

// mockRetryableError is a test error whose retryability is fixed by the b
// field; RetryableError reports that value back to the caller.
type mockRetryableError struct{ b bool }

// RetryableError reports whether this mock error should be treated as retryable.
func (e mockRetryableError) RetryableError() bool {
	return e.b
}

// Error renders the mock as "mock retryable true" or "mock retryable false".
func (e mockRetryableError) Error() string {
	return fmt.Sprint("mock retryable ", e.b)
}

// retryTestFailingHandler returns a finalize handler that fails its first
// `failures` invocations with errors produced by newErr (all created up front)
// and fails any further invocation with "more requests than expected".
func retryTestFailingHandler(failures int, newErr func() error) middleware.FinalizeHandler {
	queued := make([]error, failures)
	for i := range queued {
		queued[i] = newErr()
	}

	served := 0
	return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
		if served >= len(queued) {
			return out, metadata, fmt.Errorf("more requests than expected")
		}
		err = queued[served]
		served++
		return out, metadata, err
	})
}

// retryTestExhaustedResults builds the attempt results expected when every one
// of `attempts` attempts fails with newErr(): all but the last are marked as
// retryable and retried, and the last carries a *retry.MaxAttemptsError
// wrapping the underlying error.
func retryTestExhaustedResults(attempts int, newErr func() error) retry.AttemptResults {
	expected := retry.AttemptResults{
		Results: make([]retry.AttemptResult, attempts),
	}
	for i := 0; i < attempts-1; i++ {
		expected.Results[i] = retry.AttemptResult{
			Err:       newErr(),
			Retryable: true,
			Retried:   true,
		}
	}
	expected.Results[attempts-1] = retry.AttemptResult{
		Err:       &retry.MaxAttemptsError{Attempt: attempts, Err: newErr()},
		Retryable: true,
	}
	return expected
}

// TestRetryHandlers checks how many attempts the Retryer configured by
// GetAwsConfig allows before giving up: maxRetries for generic retryable errors
// and for unrecognized network errors, but only constants.MaxNetworkRetryCount
// for "no such host" and "connection refused" network errors.
func TestRetryHandlers(t *testing.T) {
	const maxRetries = 10

	type retryTestCase struct {
		NextHandler   func() middleware.FinalizeHandler
		ExpectResults retry.AttemptResults
		Err           error
	}

	// exhausts describes a scenario where every attempt fails with newErr()
	// and retrying is expected to stop after exactly `attempts` attempts.
	exhausts := func(attempts int, newErr func() error) retryTestCase {
		return retryTestCase{
			NextHandler: func() middleware.FinalizeHandler {
				return retryTestFailingHandler(attempts, newErr)
			},
			ExpectResults: retryTestExhaustedResults(attempts, newErr),
			Err:           fmt.Errorf("exceeded maximum number of attempts"),
		}
	}

	// dialError yields a factory for dial-time *net.OpError values wrapping msg.
	dialError := func(msg string) func() error {
		return func() error {
			return &net.OpError{Op: "dial", Err: errors.New(msg)}
		}
	}
	retryableError := func() error {
		return mockRetryableError{b: true}
	}

	testcases := map[string]retryTestCase{
		"stops at maxRetries for retryable errors":                      exhausts(maxRetries, retryableError),
		"stops at MaxNetworkRetryCount for 'no such host' errors":       exhausts(constants.MaxNetworkRetryCount, dialError("no such host")),
		"stops at MaxNetworkRetryCount for 'connection refused' errors": exhausts(constants.MaxNetworkRetryCount, dialError("connection refused")),
		"stops at maxRetries for other network errors":                  exhausts(maxRetries, dialError("other error")),
	}

	for name, testcase := range testcases {
		testcase := testcase

		t.Run(name, func(t *testing.T) {
			oldEnv := awsmocks.InitSessionTestEnv()
			defer awsmocks.PopEnv(oldEnv)

			// Named cfg rather than config so the imported config package is not shadowed.
			cfg := &Config{
				AccessKey:           awsmocks.MockStaticAccessKey,
				Region:              "us-east-1",
				MaxRetries:          maxRetries,
				SecretKey:           awsmocks.MockStaticSecretKey,
				SkipCredsValidation: true,
				DebugLogging:        true,
			}
			awsConfig, err := GetAwsConfig(context.Background(), cfg)
			if err != nil {
				t.Fatalf("unexpected error from GetAwsConfig(): %s", err)
			}
			if awsConfig.Retryer == nil {
				t.Fatal("No Retryer configured on awsConfig")
			}

			// Wrap the configured retryer so the test does not sleep between attempts.
			am := retry.NewAttemptMiddleware(&withNoDelay{
				Retryer: awsConfig.Retryer(),
			}, func(i interface{}) interface{} {
				return i
			})
			_, metadata, err := am.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: nil}, testcase.NextHandler())
			switch {
			case err != nil && testcase.Err == nil:
				t.Errorf("expect no error, got %v", err)
			case err == nil && testcase.Err != nil:
				t.Errorf("expect error, got none")
			case err != nil && testcase.Err != nil:
				if !strings.Contains(err.Error(), testcase.Err.Error()) {
					t.Errorf("expect %v, got %v", testcase.Err, err)
				}
			}

			attemptResults, ok := retry.GetAttemptResults(metadata)
			if !ok {
				t.Fatalf("expected metadata to contain attempt results, got none")
			}
			if e, a := testcase.ExpectResults, attemptResults; !reflect.DeepEqual(e, a) {
				t.Fatalf("expected %v, got %v", e, a)
			}

			for i, attempt := range attemptResults.Results {
				_, ok := retry.GetAttemptResults(attempt.ResponseMetadata)
				if ok {
					t.Errorf("expect no attempt to include AttemptResults metadata, %v does, %#v", i, attempt)
				}
			}
		})
	}
}

// withNoDelay wraps an aws.Retryer and removes the wait between attempts while
// still honoring the wrapped retryer's decision to stop retrying.
type withNoDelay struct {
	aws.Retryer
}

// RetryDelay consults the wrapped Retryer first. If that returns an error, its
// delay and error are passed through untouched; otherwise the computed delay is
// discarded and a zero delay is returned.
func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error) {
	if wait, waitErr := r.Retryer.RetryDelay(attempt, err); waitErr != nil {
		return wait, waitErr
	}
	return 0, nil
}
13 changes: 7 additions & 6 deletions awsv1shim/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"runtime"
"testing"
Expand Down Expand Up @@ -1148,42 +1149,42 @@ func TestSessionRetryHandlers(t *testing.T) {
{
Description: "send request no such host failed under MaxNetworkRetryCount",
RetryCount: constants.MaxNetworkRetryCount - 1,
Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("no such host")),
Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("no such host")}),
ExpectedRetryableValue: true,
ExpectRetryToBeAttempted: true,
},
{
Description: "send request no such host failed over MaxNetworkRetryCount",
RetryCount: constants.MaxNetworkRetryCount,
Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("no such host")),
Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("no such host")}),
ExpectedRetryableValue: false,
ExpectRetryToBeAttempted: false,
},
{
Description: "send request connection refused failed under MaxNetworkRetryCount",
RetryCount: constants.MaxNetworkRetryCount - 1,
Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("connection refused")),
Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("connection refused")}),
ExpectedRetryableValue: true,
ExpectRetryToBeAttempted: true,
},
{
Description: "send request connection refused failed over MaxNetworkRetryCount",
RetryCount: constants.MaxNetworkRetryCount,
Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("connection refused")),
Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("connection refused")}),
ExpectedRetryableValue: false,
ExpectRetryToBeAttempted: false,
},
{
Description: "send request other error failed under MaxNetworkRetryCount",
RetryCount: constants.MaxNetworkRetryCount - 1,
Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("other error")),
Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("other error")}),
ExpectedRetryableValue: true,
ExpectRetryToBeAttempted: true,
},
{
Description: "send request other error failed over MaxNetworkRetryCount",
RetryCount: constants.MaxNetworkRetryCount,
Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("other error")),
Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("other error")}),
ExpectedRetryableValue: true,
ExpectRetryToBeAttempted: true,
},
Expand Down

0 comments on commit 1fc56c5

Please sign in to comment.