Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V2: Implement MaxRetries #77

Merged
merged 2 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions aws_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,49 @@ package awsbase
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go/middleware"
"github.com/hashicorp/aws-sdk-go-base/internal/constants"
"github.com/hashicorp/aws-sdk-go-base/internal/endpoints"
"github.com/hashicorp/go-cleanhttp"
)

const (
// appendUserAgentEnvVar is a conventionally used environment variable
// containing additional HTTP User-Agent information.
// If present and its value is non-empty, it is directly appended to the
// User-Agent header for HTTP requests.
appendUserAgentEnvVar = "TF_APPEND_USER_AGENT"
)

func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) {
credentialsProvider, err := getCredentialsProvider(ctx, c)
if err != nil {
return aws.Config{}, err
}

var retryer aws.Retryer
retryer = retry.NewStandard()
if c.MaxRetries != 0 {
retryer = retry.AddWithMaxAttempts(retryer, c.MaxRetries)
}
retryer = &networkErrorShortcutter{
Retryer: retryer,
}

loadOptions := append(
commonLoadOptions(c),
config.WithCredentialsProvider(credentialsProvider),
config.WithRetryer(func() aws.Retryer {
return retryer
}),
)
cfg, err := config.LoadDefaultConfig(ctx, loadOptions...)
if err != nil {
Expand All @@ -51,6 +61,30 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) {
return cfg, nil
}

// networkErrorShortcutter is used to enable networking error shortcutting
type networkErrorShortcutter struct {
aws.Retryer
}

// We're misusing RetryDelay here, since this is the only function that takes the attempt count
func (r *networkErrorShortcutter) RetryDelay(attempt int, err error) (time.Duration, error) {
if attempt >= constants.MaxNetworkRetryCount {
var netOpErr *net.OpError
if errors.As(err, &netOpErr) {
// It's disappointing that we have to do string matching here, rather than being able to using `errors.Is()` or even strings exported by the Go `net` package
if strings.Contains(netOpErr.Error(), "no such host") || strings.Contains(netOpErr.Error(), "connection refused") {
log.Printf("[WARN] Disabling retries after next request due to networking issue: %s", err)
return 0, &retry.MaxAttemptsError{
Attempt: attempt,
Err: err,
}
}
}
}

return r.Retryer.RetryDelay(attempt, err)
}

func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, skipCredsValidation, skipRequestingAccountId bool) (string, string, error) {
if !skipCredsValidation {
stsClient := sts.NewFromConfig(awsConfig)
Expand Down Expand Up @@ -102,7 +136,7 @@ func commonLoadOptions(c *Config) []func(*config.LoadOptions) error {
return stack.Build.Add(customUserAgentMiddleware(c), middleware.After)
})
}
if v := os.Getenv(appendUserAgentEnvVar); v != "" {
if v := os.Getenv(constants.AppendUserAgentEnvVar); v != "" {
log.Printf("[DEBUG] Using additional User-Agent Info: %s", v)
apiOptions = append(apiOptions, awsmiddleware.AddUserAgentKey(v))
}
Expand Down
239 changes: 237 additions & 2 deletions aws_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@ import (
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"reflect"
"runtime"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/hashicorp/aws-sdk-go-base/internal/constants"
"github.com/hashicorp/aws-sdk-go-base/mockdata"
"github.com/hashicorp/aws-sdk-go-base/servicemocks"
)
Expand Down Expand Up @@ -960,7 +964,7 @@ func TestUserAgentProducts(t *testing.T) {
},
Description: "customized User-Agent TF_APPEND_USER_AGENT",
EnvironmentVariables: map[string]string{
appendUserAgentEnvVar: "Last",
constants.AppendUserAgentEnvVar: "Last",
},
ExpectedUserAgent: awsSdkGoUserAgent() + " Last",
},
Expand Down Expand Up @@ -1003,7 +1007,7 @@ func TestUserAgentProducts(t *testing.T) {
},
Description: "customized User-Agent Products and TF_APPEND_USER_AGENT",
EnvironmentVariables: map[string]string{
appendUserAgentEnvVar: "Last",
constants.AppendUserAgentEnvVar: "Last",
},
ExpectedUserAgent: "first/1.0 second/1.2.3 (+https://www.example.com/) " + awsSdkGoUserAgent() + " Last",
},
Expand Down Expand Up @@ -1217,3 +1221,234 @@ func TestGetAwsConfigWithAccountIDAndPartition(t *testing.T) {
})
}
}

type mockRetryableError struct{ b bool }

func (m mockRetryableError) RetryableError() bool { return m.b }
func (m mockRetryableError) Error() string {
return fmt.Sprintf("mock retryable %t", m.b)
}

func TestRetryHandlers(t *testing.T) {
const maxRetries = 10

testcases := map[string]struct {
NextHandler func() middleware.FinalizeHandler
ExpectResults retry.AttemptResults
Err error
}{
"stops at maxRetries for retryable errors": {
NextHandler: func() middleware.FinalizeHandler {
num := 0
reqsErrs := make([]error, maxRetries)
for i := 0; i < maxRetries; i++ {
reqsErrs[i] = mockRetryableError{b: true}
}
return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
if num >= len(reqsErrs) {
err = fmt.Errorf("more requests than expected")
} else {
err = reqsErrs[num]
num++
}
return out, metadata, err
})
},
Err: fmt.Errorf("exceeded maximum number of attempts"),
ExpectResults: func() retry.AttemptResults {
results := retry.AttemptResults{
Results: make([]retry.AttemptResult, maxRetries),
}
for i := 0; i < maxRetries-1; i++ {
results.Results[i] = retry.AttemptResult{
Err: mockRetryableError{b: true},
Retryable: true,
Retried: true,
}
}
results.Results[maxRetries-1] = retry.AttemptResult{
Err: &retry.MaxAttemptsError{Attempt: maxRetries, Err: mockRetryableError{b: true}},
Retryable: true,
}
return results
}(),
},
"stops at MaxNetworkRetryCount for 'no such host' errors": {
NextHandler: func() middleware.FinalizeHandler {
num := 0
reqsErrs := make([]error, constants.MaxNetworkRetryCount)
for i := 0; i < constants.MaxNetworkRetryCount; i++ {
reqsErrs[i] = &net.OpError{Op: "dial", Err: errors.New("no such host")}
}
return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
if num >= len(reqsErrs) {
err = fmt.Errorf("more requests than expected")
} else {
err = reqsErrs[num]
num++
}
return out, metadata, err
})
},
Err: fmt.Errorf("exceeded maximum number of attempts"),
ExpectResults: func() retry.AttemptResults {
results := retry.AttemptResults{
Results: make([]retry.AttemptResult, constants.MaxNetworkRetryCount),
}
for i := 0; i < constants.MaxNetworkRetryCount-1; i++ {
results.Results[i] = retry.AttemptResult{
Err: &net.OpError{Op: "dial", Err: errors.New("no such host")},
Retryable: true,
Retried: true,
}
}
results.Results[constants.MaxNetworkRetryCount-1] = retry.AttemptResult{
Err: &retry.MaxAttemptsError{Attempt: constants.MaxNetworkRetryCount, Err: &net.OpError{Op: "dial", Err: errors.New("no such host")}},
Retryable: true,
}
return results
}(),
},
"stops at MaxNetworkRetryCount for 'connection refused' errors": {
NextHandler: func() middleware.FinalizeHandler {
num := 0
reqsErrs := make([]error, constants.MaxNetworkRetryCount)
for i := 0; i < constants.MaxNetworkRetryCount; i++ {
reqsErrs[i] = &net.OpError{Op: "dial", Err: errors.New("connection refused")}
}
return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
if num >= len(reqsErrs) {
err = fmt.Errorf("more requests than expected")
} else {
err = reqsErrs[num]
num++
}
return out, metadata, err
})
},
Err: fmt.Errorf("exceeded maximum number of attempts"),
ExpectResults: func() retry.AttemptResults {
results := retry.AttemptResults{
Results: make([]retry.AttemptResult, constants.MaxNetworkRetryCount),
}
for i := 0; i < constants.MaxNetworkRetryCount-1; i++ {
results.Results[i] = retry.AttemptResult{
Err: &net.OpError{Op: "dial", Err: errors.New("connection refused")},
Retryable: true,
Retried: true,
}
}
results.Results[constants.MaxNetworkRetryCount-1] = retry.AttemptResult{
Err: &retry.MaxAttemptsError{Attempt: constants.MaxNetworkRetryCount, Err: &net.OpError{Op: "dial", Err: errors.New("connection refused")}},
Retryable: true,
}
return results
}(),
},
"stops at maxRetries for other network errors": {
NextHandler: func() middleware.FinalizeHandler {
num := 0
reqsErrs := make([]error, maxRetries)
for i := 0; i < maxRetries; i++ {
reqsErrs[i] = &net.OpError{Op: "dial", Err: errors.New("other error")}
}
return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
if num >= len(reqsErrs) {
err = fmt.Errorf("more requests than expected")
} else {
err = reqsErrs[num]
num++
}
return out, metadata, err
})
},
Err: fmt.Errorf("exceeded maximum number of attempts"),
ExpectResults: func() retry.AttemptResults {
results := retry.AttemptResults{
Results: make([]retry.AttemptResult, maxRetries),
}
for i := 0; i < maxRetries-1; i++ {
results.Results[i] = retry.AttemptResult{
Err: &net.OpError{Op: "dial", Err: errors.New("other error")},
Retryable: true,
Retried: true,
}
}
results.Results[maxRetries-1] = retry.AttemptResult{
Err: &retry.MaxAttemptsError{Attempt: maxRetries, Err: &net.OpError{Op: "dial", Err: errors.New("other error")}},
Retryable: true,
}
return results
}(),
},
}

for name, testcase := range testcases {
testcase := testcase

t.Run(name, func(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)

config := &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
MaxRetries: maxRetries,
SecretKey: servicemocks.MockStaticSecretKey,
SkipCredsValidation: true,
DebugLogging: true,
}
awsConfig, err := GetAwsConfig(context.Background(), config)
if err != nil {
t.Fatalf("unexpected error from GetAwsConfig(): %s", err)
}
if awsConfig.Retryer == nil {
t.Fatal("No Retryer configured on awsConfig")
}

am := retry.NewAttemptMiddleware(&withNoDelay{
Retryer: awsConfig.Retryer(),
}, func(i interface{}) interface{} {
return i
})
_, metadata, err := am.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: nil}, testcase.NextHandler())
if err != nil && testcase.Err == nil {
t.Errorf("expect no error, got %v", err)
} else if err == nil && testcase.Err != nil {
t.Errorf("expect error, got none")
} else if err != nil && testcase.Err != nil {
if !strings.Contains(err.Error(), testcase.Err.Error()) {
t.Errorf("expect %v, got %v", testcase.Err, err)
}
}

attemptResults, ok := retry.GetAttemptResults(metadata)
if !ok {
t.Fatalf("expected metadata to contain attempt results, got none")
}
if e, a := testcase.ExpectResults, attemptResults; !reflect.DeepEqual(e, a) {
t.Fatalf("expected %v, got %v", e, a)
}

for i, attempt := range attemptResults.Results {
_, ok := retry.GetAttemptResults(attempt.ResponseMetadata)
if ok {
t.Errorf("expect no attempt to include AttemptResults metadata, %v does, %#v", i, attempt)
}
}
})
}
}

type withNoDelay struct {
aws.Retryer
}

func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error) {
delay, delayErr := r.Retryer.RetryDelay(attempt, err)
if delayErr != nil {
return delay, delayErr
}

return 0 * time.Second, nil
}
Loading