From 35c5af47e83de1c8540f8e5b90b6f13321eb4259 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 15 Sep 2021 10:47:19 -0700 Subject: [PATCH] Adds retry shortcutting for V2 --- aws_config.go | 41 +++++++ aws_config_test.go | 234 ++++++++++++++++++++++++++++++++++++++ awsv1shim/session_test.go | 13 ++- 3 files changed, 282 insertions(+), 6 deletions(-) diff --git a/aws_config.go b/aws_config.go index e0ad1543..875dbf68 100644 --- a/aws_config.go +++ b/aws_config.go @@ -3,13 +3,18 @@ package awsbase import ( "context" "crypto/tls" + "errors" "fmt" "log" + "net" "net/http" "os" + "strings" + "time" "github.com/aws/aws-sdk-go-v2/aws" awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/iam" @@ -26,9 +31,21 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) { return aws.Config{}, err } + var retryer aws.Retryer + retryer = retry.NewStandard() + if c.MaxRetries != 0 { + retryer = retry.AddWithMaxAttempts(retryer, c.MaxRetries) + } + retryer = &networkErrorShortcutter{ + Retryer: retryer, + } + loadOptions := append( commonLoadOptions(c), config.WithCredentialsProvider(credentialsProvider), + config.WithRetryer(func() aws.Retryer { + return retryer + }), ) cfg, err := config.LoadDefaultConfig(ctx, loadOptions...) if err != nil { @@ -44,6 +61,30 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) { return cfg, nil } +// networkErrorShortcutter is used to enable networking error shortcutting +type networkErrorShortcutter struct { + aws.Retryer +} + +// We're misusing RetryDelay here, since this is the only function that takes the attempt count +func (r *networkErrorShortcutter) RetryDelay(attempt int, err error) (time.Duration, error) { + if attempt >= constants.MaxNetworkRetryCount { + var netOpErr *net.OpError + if errors.As(err, &netOpErr) { + // It's disappointing that we have to do string matching here, rather than being able to using `errors.Is()` or even strings exported by the Go `net` package + if strings.Contains(netOpErr.Error(), "no such host") || strings.Contains(netOpErr.Error(), "connection refused") { + log.Printf("[WARN] Disabling retries after next request due to networking issue: %s", err) + return 0, &retry.MaxAttemptsError{ + Attempt: attempt, + Err: err, + } + } + } + } + + return r.Retryer.RetryDelay(attempt, err) +} + func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, skipCredsValidation, skipRequestingAccountId bool) (string, string, error) { if !skipCredsValidation { stsClient := sts.NewFromConfig(awsConfig) diff --git a/aws_config_test.go b/aws_config_test.go index b55a4cf6..65a67cbb 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -5,13 +5,16 @@ import ( "errors" "fmt" "io/ioutil" + "net" "os" "reflect" "runtime" "strings" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go/middleware" @@ -1218,3 +1221,234 @@ func TestGetAwsConfigWithAccountIDAndPartition(t *testing.T) { }) } } + +type mockRetryableError struct{ b bool } + +func (m mockRetryableError) RetryableError() bool { return m.b } +func (m mockRetryableError) Error() string { + return fmt.Sprintf("mock retryable %t", m.b) +} + +func TestRetryHandlers(t *testing.T) { + const maxRetries = 10 + + testcases := map[string]struct { + NextHandler func() middleware.FinalizeHandler + ExpectResults retry.AttemptResults + Err error + }{ + "stops at maxRetries for retryable errors": { + NextHandler: func() middleware.FinalizeHandler { + num := 0 + reqsErrs := make([]error, maxRetries) + for i := 0; i < maxRetries; i++ { + reqsErrs[i] = mockRetryableError{b: true} + } + return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { + if num >= len(reqsErrs) { + err = fmt.Errorf("more requests than expected") + } else { + err = reqsErrs[num] + num++ + } + return out, metadata, err + }) + }, + Err: fmt.Errorf("exceeded maximum number of attempts"), + ExpectResults: func() retry.AttemptResults { + results := retry.AttemptResults{ + Results: make([]retry.AttemptResult, maxRetries), + } + for i := 0; i < maxRetries-1; i++ { + results.Results[i] = retry.AttemptResult{ + Err: mockRetryableError{b: true}, + Retryable: true, + Retried: true, + } + } + results.Results[maxRetries-1] = retry.AttemptResult{ + Err: &retry.MaxAttemptsError{Attempt: maxRetries, Err: mockRetryableError{b: true}}, + Retryable: true, + } + return results + }(), + }, + "stops at MaxNetworkRetryCount for 'no such host' errors": { + NextHandler: func() middleware.FinalizeHandler { + num := 0 + reqsErrs := make([]error, constants.MaxNetworkRetryCount) + for i := 0; i < constants.MaxNetworkRetryCount; i++ { + reqsErrs[i] = &net.OpError{Op: "dial", Err: errors.New("no such host")} + } + return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { + if num >= len(reqsErrs) { + err = fmt.Errorf("more requests than expected") + } else { + err = reqsErrs[num] + num++ + } + return out, metadata, err + }) + }, + Err: fmt.Errorf("exceeded maximum number of attempts"), + ExpectResults: func() retry.AttemptResults { + results := retry.AttemptResults{ + Results: make([]retry.AttemptResult, constants.MaxNetworkRetryCount), + } + for i := 0; i < constants.MaxNetworkRetryCount-1; i++ { + results.Results[i] = retry.AttemptResult{ + Err: &net.OpError{Op: "dial", Err: errors.New("no such host")}, + Retryable: true, + Retried: true, + } + } + results.Results[constants.MaxNetworkRetryCount-1] = retry.AttemptResult{ + Err: &retry.MaxAttemptsError{Attempt: constants.MaxNetworkRetryCount, Err: &net.OpError{Op: "dial", Err: errors.New("no such host")}}, + Retryable: true, + } + return results + }(), + }, + "stops at MaxNetworkRetryCount for 'connection refused' errors": { + NextHandler: func() middleware.FinalizeHandler { + num := 0 + reqsErrs := make([]error, constants.MaxNetworkRetryCount) + for i := 0; i < constants.MaxNetworkRetryCount; i++ { + reqsErrs[i] = &net.OpError{Op: "dial", Err: errors.New("connection refused")} + } + return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { + if num >= len(reqsErrs) { + err = fmt.Errorf("more requests than expected") + } else { + err = reqsErrs[num] + num++ + } + return out, metadata, err + }) + }, + Err: fmt.Errorf("exceeded maximum number of attempts"), + ExpectResults: func() retry.AttemptResults { + results := retry.AttemptResults{ + Results: make([]retry.AttemptResult, constants.MaxNetworkRetryCount), + } + for i := 0; i < constants.MaxNetworkRetryCount-1; i++ { + results.Results[i] = retry.AttemptResult{ + Err: &net.OpError{Op: "dial", Err: errors.New("connection refused")}, + Retryable: true, + Retried: true, + } + } + results.Results[constants.MaxNetworkRetryCount-1] = retry.AttemptResult{ + Err: &retry.MaxAttemptsError{Attempt: constants.MaxNetworkRetryCount, Err: &net.OpError{Op: "dial", Err: errors.New("connection refused")}}, + Retryable: true, + } + return results + }(), + }, + "stops at maxRetries for other network errors": { + NextHandler: func() middleware.FinalizeHandler { + num := 0 + reqsErrs := make([]error, maxRetries) + for i := 0; i < maxRetries; i++ { + reqsErrs[i] = &net.OpError{Op: "dial", Err: errors.New("other error")} + } + return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { + if num >= len(reqsErrs) { + err = fmt.Errorf("more requests than expected") + } else { + err = reqsErrs[num] + num++ + } + return out, metadata, err + }) + }, + Err: fmt.Errorf("exceeded maximum number of attempts"), + ExpectResults: func() retry.AttemptResults { + results := retry.AttemptResults{ + Results: make([]retry.AttemptResult, maxRetries), + } + for i := 0; i < maxRetries-1; i++ { + results.Results[i] = retry.AttemptResult{ + Err: &net.OpError{Op: "dial", Err: errors.New("other error")}, + Retryable: true, + Retried: true, + } + } + results.Results[maxRetries-1] = retry.AttemptResult{ + Err: &retry.MaxAttemptsError{Attempt: maxRetries, Err: &net.OpError{Op: "dial", Err: errors.New("other error")}}, + Retryable: true, + } + return results + }(), + }, + } + + for name, testcase := range testcases { + testcase := testcase + + t.Run(name, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + config := &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + MaxRetries: maxRetries, + SecretKey: servicemocks.MockStaticSecretKey, + SkipCredsValidation: true, + DebugLogging: true, + } + awsConfig, err := GetAwsConfig(context.Background(), config) + if err != nil { + t.Fatalf("unexpected error from GetAwsConfig(): %s", err) + } + if awsConfig.Retryer == nil { + t.Fatal("No Retryer configured on awsConfig") + } + + am := retry.NewAttemptMiddleware(&withNoDelay{ + Retryer: awsConfig.Retryer(), + }, func(i interface{}) interface{} { + return i + }) + _, metadata, err := am.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: nil}, testcase.NextHandler()) + if err != nil && testcase.Err == nil { + t.Errorf("expect no error, got %v", err) + } else if err == nil && testcase.Err != nil { + t.Errorf("expect error, got none") + } else if err != nil && testcase.Err != nil { + if !strings.Contains(err.Error(), testcase.Err.Error()) { + t.Errorf("expect %v, got %v", testcase.Err, err) + } + } + + attemptResults, ok := retry.GetAttemptResults(metadata) + if !ok { + t.Fatalf("expected metadata to contain attempt results, got none") + } + if e, a := testcase.ExpectResults, attemptResults; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + + for i, attempt := range attemptResults.Results { + _, ok := retry.GetAttemptResults(attempt.ResponseMetadata) + if ok { + t.Errorf("expect no attempt to include AttemptResults metadata, %v does, %#v", i, attempt) + } + } + }) + } +} + +type withNoDelay struct { + aws.Retryer +} + +func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error) { + delay, delayErr := r.Retryer.RetryDelay(attempt, err) + if delayErr != nil { + return delay, delayErr + } + + return 0 * time.Second, nil +} diff --git a/awsv1shim/session_test.go b/awsv1shim/session_test.go index f2a29845..be185505 100644 --- a/awsv1shim/session_test.go +++ b/awsv1shim/session_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/ioutil" + "net" "os" "runtime" "testing" @@ -1149,42 +1150,42 @@ func TestSessionRetryHandlers(t *testing.T) { { Description: "send request no such host failed under MaxNetworkRetryCount", RetryCount: constants.MaxNetworkRetryCount - 1, - Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("no such host")), + Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("no such host")}), ExpectedRetryableValue: true, ExpectRetryToBeAttempted: true, }, { Description: "send request no such host failed over MaxNetworkRetryCount", RetryCount: constants.MaxNetworkRetryCount, - Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("no such host")), + Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("no such host")}), ExpectedRetryableValue: false, ExpectRetryToBeAttempted: false, }, { Description: "send request connection refused failed under MaxNetworkRetryCount", RetryCount: constants.MaxNetworkRetryCount - 1, - Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("connection refused")), + Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("connection refused")}), ExpectedRetryableValue: true, ExpectRetryToBeAttempted: true, }, { Description: "send request connection refused failed over MaxNetworkRetryCount", RetryCount: constants.MaxNetworkRetryCount, - Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("connection refused")), + Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("connection refused")}), ExpectedRetryableValue: false, ExpectRetryToBeAttempted: false, }, { Description: "send request other error failed under MaxNetworkRetryCount", RetryCount: constants.MaxNetworkRetryCount - 1, - Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("other error")), + Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("other error")}), ExpectedRetryableValue: true, ExpectRetryToBeAttempted: true, }, { Description: "send request other error failed over MaxNetworkRetryCount", RetryCount: constants.MaxNetworkRetryCount, - Error: awserr.New(request.ErrCodeRequestError, "send request failed", errors.New("other error")), + Error: awserr.New(request.ErrCodeRequestError, "send request failed", &net.OpError{Op: "dial", Err: errors.New("other error")}), ExpectedRetryableValue: true, ExpectRetryToBeAttempted: true, },