Skip to content

Commit

Permalink
Moves SSO testing to shared config tests in aws-sdk-go-base
Browse files Browse the repository at this point in the history
  • Loading branch information
gdavison committed Nov 7, 2023
1 parent de1a5d3 commit 6263d90
Showing 1 changed file with 97 additions and 185 deletions.
282 changes: 97 additions & 185 deletions internal/backend/remote-state/s3/backend_complete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
configtesting "github.com/hashicorp/aws-sdk-go-base/v2/configtesting"
"github.com/hashicorp/aws-sdk-go-base/v2/mockdata"
"github.com/hashicorp/aws-sdk-go-base/v2/servicemocks"
"github.com/hashicorp/terraform/internal/configs/hcl2shim"
Expand Down Expand Up @@ -2010,222 +2011,133 @@ web_identity_token_file = no-such-file
}
}

func TestBackendConfig_Authentication_SSO(t *testing.T) {
const ssoSessionName = "test-sso-session"

testCases := map[string]struct {
config map[string]any
SharedConfigurationFile string
SetSharedConfigurationFile bool
ExpectedCredentialsValue aws.Credentials
ExpectedDiags tfdiags.Diagnostics
MockStsEndpoints []*servicemocks.MockEndpoint
}{
"shared configuration file": {
config: map[string]any{},
SharedConfigurationFile: fmt.Sprintf(`
[default]
sso_session = %s
sso_account_id = 123456789012
sso_role_name = testRole
region = us-east-1
[sso-session test-sso-session]
sso_region = us-east-1
sso_start_url = https://d-123456789a.awsapps.com/start
sso_registration_scopes = sso:account:access
`, ssoSessionName),
SetSharedConfigurationFile: true,
ExpectedCredentialsValue: mockdata.MockSsoCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsGetCallerIdentityValidEndpoint,
},
},
}

for name, tc := range testCases {
tc := tc

t.Run(name, func(t *testing.T) {
servicemocks.InitSessionTestEnv(t)
var _ configtesting.TestDriver = &testDriver{}

ctx := context.TODO()

// Populate required fields
tc.config["region"] = "us-east-1"
tc.config["bucket"] = "bucket"
tc.config["key"] = "key"

err := servicemocks.SsoTestSetup(t, ssoSessionName)
if err != nil {
t.Fatalf("setup: %s", err)
}
type testDriver struct {
mode configtesting.TestMode
}

endpoints := map[string]any{}
func (t *testDriver) Init(mode configtesting.TestMode) {
t.mode = mode
}

closeSso, ssoEndpoint := servicemocks.SsoCredentialsApiMock()
defer closeSso()
endpoints["sso"] = ssoEndpoint
func (t testDriver) TestCase() configtesting.TestCaseDriver {
return &testCaseDriver{
mode: t.mode,
}
}

ts := servicemocks.MockAwsApiServer("STS", tc.MockStsEndpoints)
defer ts.Close()
endpoints["sts"] = ts.URL
var _ configtesting.TestCaseDriver = &testCaseDriver{}

tempdir, err := os.MkdirTemp("", "temp")
if err != nil {
t.Fatalf("error creating temp dir: %s", err)
}
defer os.Remove(tempdir)
t.Setenv("TMPDIR", tempdir)
type testCaseDriver struct {
mode configtesting.TestMode
config configurer
}

if tc.SharedConfigurationFile != "" {
file, err := os.CreateTemp("", "aws-sdk-go-base-shared-configuration-file")
func (d *testCaseDriver) Configuration() configtesting.Configurer {
return d.configuration()
}

if err != nil {
t.Fatalf("unexpected error creating temporary shared configuration file: %s", err)
}
func (d *testCaseDriver) configuration() *configurer {
if d.config == nil {
d.config = make(configurer, 0)
}
return &d.config
}

defer os.Remove(file.Name())
func (d *testCaseDriver) Setup(t *testing.T) {
ts := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{
servicemocks.MockStsGetCallerIdentityValidEndpoint,
})
t.Cleanup(func() {
ts.Close()
})
d.config.AddEndpoint("sts", ts.URL)
}

err = os.WriteFile(file.Name(), []byte(tc.SharedConfigurationFile), 0600)
func (d testCaseDriver) Apply(ctx context.Context, t *testing.T) (context.Context, configtesting.Thing) {
t.Helper()

if err != nil {
t.Fatalf("unexpected error writing shared configuration file: %s", err)
}
// Populate required fields
d.config.SetRegion("us-east-1")
d.config.setBucket("bucket")
d.config.setKey("key")
if d.mode == configtesting.TestModeLocal {
d.config.SetSkipCredsValidation(true)
d.config.SetSkipRequestingAccountId(true)
}

tc.config["shared_config_files"] = []any{file.Name()}
}
b, diags := configureBackend(t, map[string]any(d.config))

tc.config["skip_credentials_validation"] = true
var expected tfdiags.Diagnostics

tc.config["endpoints"] = endpoints
if diff := cmp.Diff(diags, expected, cmp.Comparer(diagnosticComparer)); diff != "" {
t.Errorf("unexpected diagnostics difference: %s", diff)
}

b, diags := configureBackend(t, tc.config)
return ctx, thing(b.awsConfig)
}

if diff := cmp.Diff(diags, tc.ExpectedDiags, cmp.Comparer(diagnosticComparer)); diff != "" {
t.Errorf("unexpected diagnostics difference: %s", diff)
}
if diags.HasErrors() {
return
}
var _ configtesting.Configurer = &configurer{}

credentials, err := b.awsConfig.Credentials.Retrieve(ctx)
if err != nil {
t.Fatalf("Error when requesting credentials: %s", err)
}
type configurer map[string]any

if diff := cmp.Diff(credentials, tc.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" {
t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff)
}
})
func (c configurer) AddEndpoint(k, v string) {
if endpoints, ok := c["endpoints"]; ok {
m := endpoints.(map[string]any)
m[k] = v
} else {
c["endpoints"] = map[string]any{
k: v,
}
}
}

func TestBackendConfig_Authentication_LegacySSO(t *testing.T) {
const ssoStartUrl = "https://d-123456789a.awsapps.com/start"

testCases := map[string]struct {
config map[string]any
SharedConfigurationFile string
SetSharedConfigurationFile bool
ExpectedCredentialsValue aws.Credentials
ExpectedDiags tfdiags.Diagnostics
MockStsEndpoints []*servicemocks.MockEndpoint
}{
"shared configuration file": {
config: map[string]any{},
SharedConfigurationFile: fmt.Sprintf(`
[default]
sso_start_url = %s
sso_region = us-east-1
sso_account_id = 123456789012
sso_role_name = testRole
region = us-east-1
`, ssoStartUrl),
SetSharedConfigurationFile: true,
ExpectedCredentialsValue: mockdata.MockSsoCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsGetCallerIdentityValidEndpoint,
},
},
func (c configurer) AddSharedConfigFile(f string) {
x := c["shared_config_files"]
if x == nil {
c["shared_config_files"] = []any{f}
} else {
files := x.([]any)
files = append(files, f)
c["shared_config_files"] = files
}
}

for name, tc := range testCases {
tc := tc

t.Run(name, func(t *testing.T) {
servicemocks.InitSessionTestEnv(t)

ctx := context.TODO()

// Populate required fields
tc.config["region"] = "us-east-1"
tc.config["bucket"] = "bucket"
tc.config["key"] = "key"

err := servicemocks.SsoTestSetup(t, ssoStartUrl)
if err != nil {
t.Fatalf("setup: %s", err)
}

endpoints := map[string]any{}

closeSso, ssoEndpoint := servicemocks.SsoCredentialsApiMock()
defer closeSso()
endpoints["sso"] = ssoEndpoint

ts := servicemocks.MockAwsApiServer("STS", tc.MockStsEndpoints)
defer ts.Close()
endpoints["sts"] = ts.URL

tempdir, err := os.MkdirTemp("", "temp")
if err != nil {
t.Fatalf("error creating temp dir: %s", err)
}
defer os.Remove(tempdir)
t.Setenv("TMPDIR", tempdir)

if tc.SharedConfigurationFile != "" {
file, err := os.CreateTemp("", "aws-sdk-go-base-shared-configuration-file")

if err != nil {
t.Fatalf("unexpected error creating temporary shared configuration file: %s", err)
}

defer os.Remove(file.Name())
func (c configurer) setBucket(s string) {
c["bucket"] = s
}

err = os.WriteFile(file.Name(), []byte(tc.SharedConfigurationFile), 0600)
func (c configurer) setKey(s string) {
c["key"] = s
}

if err != nil {
t.Fatalf("unexpected error writing shared configuration file: %s", err)
}
func (c configurer) SetRegion(s string) {
c["region"] = s
}

tc.config["shared_config_files"] = []any{file.Name()}
}
func (c configurer) SetSkipCredsValidation(b bool) {
c["skip_credentials_validation"] = b
}

tc.config["skip_credentials_validation"] = true
func (c configurer) SetSkipRequestingAccountId(b bool) {
c["skip_requesting_account_id"] = b
}

tc.config["endpoints"] = endpoints
var _ configtesting.Thing = thing{}

b, diags := configureBackend(t, tc.config)
type thing aws.Config

if diff := cmp.Diff(diags, tc.ExpectedDiags, cmp.Comparer(diagnosticComparer)); diff != "" {
t.Errorf("unexpected diagnostics difference: %s", diff)
}
if diags.HasErrors() {
return
}
func (t thing) GetCredentials() aws.CredentialsProvider {
return t.Credentials
}

credentials, err := b.awsConfig.Credentials.Retrieve(ctx)
if err != nil {
t.Fatalf("Error when requesting credentials: %s", err)
}
func TestBackendConfig_Authentication_SSO(t *testing.T) {
configtesting.SSO(t, &testDriver{})
}

if diff := cmp.Diff(credentials, tc.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" {
t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff)
}
})
}
func TestBackendConfig_Authentication_LegacySSO(t *testing.T) {
configtesting.LegacySSO(t, &testDriver{})
}

func TestBackendConfig_Region(t *testing.T) {
Expand Down

0 comments on commit 6263d90

Please sign in to comment.