Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add awsbase.ErrCodeEquals, AWS SDK for Go v2 variant of helper in v2/awsv1shim/tfawserr #524

Merged
merged 4 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package awsbase

import (
smithy "github.com/aws/smithy-go"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/errs"
)
Expand All @@ -24,19 +23,3 @@ type NoValidCredentialSourcesError = config.NoValidCredentialSourcesError
func IsNoValidCredentialSourcesError(err error) bool {
return errs.IsA[NoValidCredentialSourcesError](err)
}

// AWS SDK for Go v2 variants of helpers in v2/awsv1shim/tfawserr.

// ErrCodeEquals returns true if the error matches all these conditions:
// - err is of type smithy.APIError
// - Error.Code() equals one of the passed codes
func ErrCodeEquals(err error, codes ...string) bool {
if apiErr, ok := errs.As[smithy.APIError](err); ok {
for _, code := range codes {
if apiErr.ErrorCode() == code {
return true
}
}
}
return false
}
95 changes: 0 additions & 95 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ package awsbase
import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
smithy "github.com/aws/smithy-go"
)

func TestIsCannotAssumeRoleError(t *testing.T) {
Expand Down Expand Up @@ -87,94 +83,3 @@ func TestIsNoValidCredentialSourcesError(t *testing.T) {
})
}
}

func TestErrCodeEquals(t *testing.T) {
testCases := []struct {
Name string
Err error
Codes []string
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
},
{
Name: "Top-level smithy.GenericAPIError matching first code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError matching last code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError no code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
},
{
Name: "Top-level smithy.GenericAPIError non-matching codes",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Wrapped smithy.GenericAPIError matching first code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError matching last code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError non-matching codes",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Top-level sts ExpiredTokenException matching first code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level sts ExpiredTokenException matching last code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching first code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching last code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := ErrCodeEquals(testCase.Err, testCase.Codes...)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}
23 changes: 23 additions & 0 deletions tfawserr/awserr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package tfawserr

import (
smithy "github.com/aws/smithy-go"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/errs"
)

// ErrCodeEquals returns true if the error matches all these conditions:
// - err is of type smithy.APIError
// - Error.Code() equals one of the passed codes
func ErrCodeEquals(err error, codes ...string) bool {
if apiErr, ok := errs.As[smithy.APIError](err); ok {
for _, code := range codes {
if apiErr.ErrorCode() == code {
return true
}
}
}
return false
}
105 changes: 105 additions & 0 deletions tfawserr/awserr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package tfawserr

import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/aws/aws-sdk-go/aws"
Copy link
Contributor Author

@ewbankkit ewbankkit Jun 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be github.com/aws/aws-sdk-go-v2/aws to prevent the go.mod diff.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy and paste :(

smithy "github.com/aws/smithy-go"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
)

func TestErrCodeEquals(t *testing.T) {
testCases := []struct {
Name string
Err error
Codes []string
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level CannotAssumeRoleError",
Err: awsbase.CannotAssumeRoleError{},
},
{
Name: "Top-level smithy.GenericAPIError matching first code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError matching last code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError no code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
},
{
Name: "Top-level smithy.GenericAPIError non-matching codes",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Wrapped smithy.GenericAPIError matching first code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError matching last code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError non-matching codes",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Top-level sts ExpiredTokenException matching first code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level sts ExpiredTokenException matching last code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching first code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching last code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := ErrCodeEquals(testCase.Err, testCase.Codes...)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}