Skip to content

Commit

Permalink
Add 'awsbase.ErrCodeEquals', AWS SDK for Go v2 variants of helper in …
Browse files Browse the repository at this point in the history
…v2/awsv1shim/tfawserr.
  • Loading branch information
ewbankkit committed Jun 22, 2023
1 parent 82b6e9e commit 1c3bdec
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 6 deletions.
26 changes: 20 additions & 6 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,39 @@
package awsbase

import (
"errors"

smithy "github.com/aws/smithy-go"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/errs"
)

// CannotAssumeRoleError occurs when AssumeRole cannot complete.
type CannotAssumeRoleError = config.CannotAssumeRoleError

// IsCannotAssumeRoleError returns true if the error contains the CannotAssumeRoleError type.
func IsCannotAssumeRoleError(err error) bool {
var e CannotAssumeRoleError
return errors.As(err, &e)
return errs.IsA[CannotAssumeRoleError](err)
}

// NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results.
type NoValidCredentialSourcesError = config.NoValidCredentialSourcesError

// IsNoValidCredentialSourcesError returns true if the error contains the NoValidCredentialSourcesError type.
func IsNoValidCredentialSourcesError(err error) bool {
var e NoValidCredentialSourcesError
return errors.As(err, &e)
return errs.IsA[NoValidCredentialSourcesError](err)
}

// AWS SDK for Go v2 variants of helpers in v2/awsv1shim/tfawserr.

// ErrCodeEquals returns true if the error matches all these conditions:
// - err is of type smithy.APIError
// - Error.Code() equals one of the passed codes
func ErrCodeEquals(err error, codes ...string) bool {
if apiErr, ok := errs.As[smithy.APIError](err); ok {
for _, code := range codes {
if apiErr.ErrorCode() == code {
return true
}
}
}
return false
}
180 changes: 180 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsbase

import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
smithy "github.com/aws/smithy-go"
)

func TestIsCannotAssumeRoleError(t *testing.T) {
testCases := []struct {
Name string
Err error
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level NoValidCredentialSourcesError",
Err: NoValidCredentialSourcesError{},
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
Expected: true,
},
{
Name: "Nested CannotAssumeRoleError",
Err: fmt.Errorf("test: %w", CannotAssumeRoleError{}),
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := IsCannotAssumeRoleError(testCase.Err)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}

func TestIsNoValidCredentialSourcesError(t *testing.T) {
testCases := []struct {
Name string
Err error
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
},
{
Name: "Top-level NoValidCredentialSourcesError",
Err: NoValidCredentialSourcesError{},
Expected: true,
},
{
Name: "Nested NoValidCredentialSourcesError",
Err: fmt.Errorf("test: %w", NoValidCredentialSourcesError{}),
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := IsNoValidCredentialSourcesError(testCase.Err)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}

func TestErrCodeEquals(t *testing.T) {
testCases := []struct {
Name string
Err error
Codes []string
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
},
{
Name: "Top-level smithy.GenericAPIError matching first code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError matching last code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Top-level smithy.GenericAPIError no code",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
},
{
Name: "Top-level smithy.GenericAPIError non-matching codes",
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Wrapped smithy.GenericAPIError matching first code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError matching last code",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped smithy.GenericAPIError non-matching codes",
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
{
Name: "Top-level sts ExpiredTokenException matching first code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Top-level sts ExpiredTokenException matching last code",
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching first code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"TestCode"},
Expected: true,
},
{
Name: "Wrapped sts ExpiredTokenException matching last code",
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := ErrCodeEquals(testCase.Err, testCase.Codes...)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}
21 changes: 21 additions & 0 deletions internal/errs/errs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package errs

import (
"errors"
)

// IsA indicates whether an error matches an error type.
func IsA[T error](err error) bool {
_, ok := As[T](err)
return ok
}

// As is equivalent to errors.As(), but returns the value in-line.
func As[T error](err error) (T, bool) {
var as T
ok := errors.As(err, &as)
return as, ok
}

0 comments on commit 1c3bdec

Please sign in to comment.