diff --git a/helper/awsutil/error.go b/helper/awsutil/error.go index f0b7f763bb5f..67e0dfd361e3 100644 --- a/helper/awsutil/error.go +++ b/helper/awsutil/error.go @@ -2,6 +2,7 @@ package awsutil import ( awsRequest "github.com/aws/aws-sdk-go/aws/request" + multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/logical" ) @@ -16,3 +17,15 @@ func CheckAWSError(err error) error { } return nil } + +// AppendLogicalError checks if the given error is a known AWS error we modify, +// and if so then returns a go-multierror, appending the original and the +// logical error. +// If the error is not an AWS error, or not an error we wish to modify, then +// return the original error. +func AppendLogicalError(err error) error { + if awserr := CheckAWSError(err); awserr != nil { + err = multierror.Append(err, awserr) + } + return err +} diff --git a/helper/awsutil/error_test.go b/helper/awsutil/error_test.go index a8f1dff2f33f..0c2945f87606 100644 --- a/helper/awsutil/error_test.go +++ b/helper/awsutil/error_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws/awserr" + multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/logical" ) @@ -50,3 +51,44 @@ func Test_CheckAWSError(t *testing.T) { }) } } + +func Test_AppendLogicalError(t *testing.T) { + awsErr := awserr.New("Throttling", "", nil) + testCases := []struct { + Name string + Err error + Expected error + }{ + { + Name: "Something not checked", + Err: fmt.Errorf("something"), + Expected: fmt.Errorf("something"), + }, + { + Name: "Upstream throttle error", + Err: awsErr, + Expected: multierror.Append(awsErr, logical.ErrUpstreamRateLimited), + }, + { + Name: "Nil", + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + err := AppendLogicalError(tc.Err) + if err == nil && tc.Expected != nil { + t.Fatalf("expected non-nil error (%#v), got nil", tc.Expected) + } + if err != nil && tc.Expected == nil { + t.Fatalf("expected nil error, got (%#v)", err) + } + if err == nil && tc.Expected == nil { + return + } + if err.Error() != tc.Expected.Error() { + t.Fatalf("expected error (%#v), got (%#v)", tc.Expected.Error(), err.Error()) + } + }) + } +}