diff --git a/internal/backend/remote-state/s3/backend_complete_test.go b/internal/backend/remote-state/s3/backend_complete_test.go index ab131b597371..1846ac40fd66 100644 --- a/internal/backend/remote-state/s3/backend_complete_test.go +++ b/internal/backend/remote-state/s3/backend_complete_test.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + configtesting "github.com/hashicorp/aws-sdk-go-base/v2/configtesting" "github.com/hashicorp/aws-sdk-go-base/v2/mockdata" "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" "github.com/hashicorp/terraform/internal/configs/hcl2shim" @@ -2010,222 +2011,133 @@ web_identity_token_file = no-such-file } } -func TestBackendConfig_Authentication_SSO(t *testing.T) { - const ssoSessionName = "test-sso-session" - - testCases := map[string]struct { - config map[string]any - SharedConfigurationFile string - SetSharedConfigurationFile bool - ExpectedCredentialsValue aws.Credentials - ExpectedDiags tfdiags.Diagnostics - MockStsEndpoints []*servicemocks.MockEndpoint - }{ - "shared configuration file": { - config: map[string]any{}, - SharedConfigurationFile: fmt.Sprintf(` -[default] -sso_session = %s -sso_account_id = 123456789012 -sso_role_name = testRole -region = us-east-1 - -[sso-session test-sso-session] -sso_region = us-east-1 -sso_start_url = https://d-123456789a.awsapps.com/start -sso_registration_scopes = sso:account:access -`, ssoSessionName), - SetSharedConfigurationFile: true, - ExpectedCredentialsValue: mockdata.MockSsoCredentials, - MockStsEndpoints: []*servicemocks.MockEndpoint{ - servicemocks.MockStsGetCallerIdentityValidEndpoint, - }, - }, - } - - for name, tc := range testCases { - tc := tc - - t.Run(name, func(t *testing.T) { - servicemocks.InitSessionTestEnv(t) +var _ configtesting.TestDriver = &testDriver{} - ctx := context.TODO() - - // Populate required fields - tc.config["region"] = "us-east-1" - tc.config["bucket"] = "bucket" - tc.config["key"] = "key" - - err := servicemocks.SsoTestSetup(t, ssoSessionName) - if err != nil { - t.Fatalf("setup: %s", err) - } +type testDriver struct { + mode configtesting.TestMode +} - endpoints := map[string]any{} +func (t *testDriver) Init(mode configtesting.TestMode) { + t.mode = mode +} - closeSso, ssoEndpoint := servicemocks.SsoCredentialsApiMock() - defer closeSso() - endpoints["sso"] = ssoEndpoint +func (t testDriver) TestCase() configtesting.TestCaseDriver { + return &testCaseDriver{ + mode: t.mode, + } +} - ts := servicemocks.MockAwsApiServer("STS", tc.MockStsEndpoints) - defer ts.Close() - endpoints["sts"] = ts.URL +var _ configtesting.TestCaseDriver = &testCaseDriver{} - tempdir, err := os.MkdirTemp("", "temp") - if err != nil { - t.Fatalf("error creating temp dir: %s", err) - } - defer os.Remove(tempdir) - t.Setenv("TMPDIR", tempdir) +type testCaseDriver struct { + mode configtesting.TestMode + config configurer +} - if tc.SharedConfigurationFile != "" { - file, err := os.CreateTemp("", "aws-sdk-go-base-shared-configuration-file") +func (d *testCaseDriver) Configuration() configtesting.Configurer { + return d.configuration() +} - if err != nil { - t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) - } +func (d *testCaseDriver) configuration() *configurer { + if d.config == nil { + d.config = make(configurer, 0) + } + return &d.config +} - defer os.Remove(file.Name()) +func (d *testCaseDriver) Setup(t *testing.T) { + ts := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ + servicemocks.MockStsGetCallerIdentityValidEndpoint, + }) + t.Cleanup(func() { + ts.Close() + }) + d.config.AddEndpoint("sts", ts.URL) +} - err = os.WriteFile(file.Name(), []byte(tc.SharedConfigurationFile), 0600) +func (d testCaseDriver) Apply(ctx context.Context, t *testing.T) (context.Context, configtesting.Thing) { + t.Helper() - if err != nil { - t.Fatalf("unexpected error writing shared configuration file: %s", err) - } + // Populate required fields + d.config.SetRegion("us-east-1") + d.config.setBucket("bucket") + d.config.setKey("key") + if d.mode == configtesting.TestModeLocal { + d.config.SetSkipCredsValidation(true) + d.config.SetSkipRequestingAccountId(true) + } - tc.config["shared_config_files"] = []any{file.Name()} - } + b, diags := configureBackend(t, map[string]any(d.config)) - tc.config["skip_credentials_validation"] = true + var expected tfdiags.Diagnostics - tc.config["endpoints"] = endpoints + if diff := cmp.Diff(diags, expected, cmp.Comparer(diagnosticComparer)); diff != "" { + t.Errorf("unexpected diagnostics difference: %s", diff) + } - b, diags := configureBackend(t, tc.config) + return ctx, thing(b.awsConfig) +} - if diff := cmp.Diff(diags, tc.ExpectedDiags, cmp.Comparer(diagnosticComparer)); diff != "" { - t.Errorf("unexpected diagnostics difference: %s", diff) - } - if diags.HasErrors() { - return - } +var _ configtesting.Configurer = &configurer{} - credentials, err := b.awsConfig.Credentials.Retrieve(ctx) - if err != nil { - t.Fatalf("Error when requesting credentials: %s", err) - } +type configurer map[string]any - if diff := cmp.Diff(credentials, tc.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" { - t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) - } - }) +func (c configurer) AddEndpoint(k, v string) { + if endpoints, ok := c["endpoints"]; ok { + m := endpoints.(map[string]any) + m[k] = v + } else { + c["endpoints"] = map[string]any{ + k: v, + } } } -func TestBackendConfig_Authentication_LegacySSO(t *testing.T) { - const ssoStartUrl = "https://d-123456789a.awsapps.com/start" - - testCases := map[string]struct { - config map[string]any - SharedConfigurationFile string - SetSharedConfigurationFile bool - ExpectedCredentialsValue aws.Credentials - ExpectedDiags tfdiags.Diagnostics - MockStsEndpoints []*servicemocks.MockEndpoint - }{ - "shared configuration file": { - config: map[string]any{}, - SharedConfigurationFile: fmt.Sprintf(` -[default] -sso_start_url = %s -sso_region = us-east-1 -sso_account_id = 123456789012 -sso_role_name = testRole -region = us-east-1 -`, ssoStartUrl), - SetSharedConfigurationFile: true, - ExpectedCredentialsValue: mockdata.MockSsoCredentials, - MockStsEndpoints: []*servicemocks.MockEndpoint{ - servicemocks.MockStsGetCallerIdentityValidEndpoint, - }, - }, +func (c configurer) AddSharedConfigFile(f string) { + x := c["shared_config_files"] + if x == nil { + c["shared_config_files"] = []any{f} + } else { + files := x.([]any) + files = append(files, f) + c["shared_config_files"] = files } +} - for name, tc := range testCases { - tc := tc - - t.Run(name, func(t *testing.T) { - servicemocks.InitSessionTestEnv(t) - - ctx := context.TODO() - - // Populate required fields - tc.config["region"] = "us-east-1" - tc.config["bucket"] = "bucket" - tc.config["key"] = "key" - - err := servicemocks.SsoTestSetup(t, ssoStartUrl) - if err != nil { - t.Fatalf("setup: %s", err) - } - - endpoints := map[string]any{} - - closeSso, ssoEndpoint := servicemocks.SsoCredentialsApiMock() - defer closeSso() - endpoints["sso"] = ssoEndpoint - - ts := servicemocks.MockAwsApiServer("STS", tc.MockStsEndpoints) - defer ts.Close() - endpoints["sts"] = ts.URL - - tempdir, err := os.MkdirTemp("", "temp") - if err != nil { - t.Fatalf("error creating temp dir: %s", err) - } - defer os.Remove(tempdir) - t.Setenv("TMPDIR", tempdir) - - if tc.SharedConfigurationFile != "" { - file, err := os.CreateTemp("", "aws-sdk-go-base-shared-configuration-file") - - if err != nil { - t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) - } - - defer os.Remove(file.Name()) +func (c configurer) setBucket(s string) { + c["bucket"] = s +} - err = os.WriteFile(file.Name(), []byte(tc.SharedConfigurationFile), 0600) +func (c configurer) setKey(s string) { + c["key"] = s +} - if err != nil { - t.Fatalf("unexpected error writing shared configuration file: %s", err) - } +func (c configurer) SetRegion(s string) { + c["region"] = s +} - tc.config["shared_config_files"] = []any{file.Name()} - } +func (c configurer) SetSkipCredsValidation(b bool) { + c["skip_credentials_validation"] = b +} - tc.config["skip_credentials_validation"] = true +func (c configurer) SetSkipRequestingAccountId(b bool) { + c["skip_requesting_account_id"] = b +} - tc.config["endpoints"] = endpoints +var _ configtesting.Thing = thing{} - b, diags := configureBackend(t, tc.config) +type thing aws.Config - if diff := cmp.Diff(diags, tc.ExpectedDiags, cmp.Comparer(diagnosticComparer)); diff != "" { - t.Errorf("unexpected diagnostics difference: %s", diff) - } - if diags.HasErrors() { - return - } +func (t thing) GetCredentials() aws.CredentialsProvider { + return t.Credentials +} - credentials, err := b.awsConfig.Credentials.Retrieve(ctx) - if err != nil { - t.Fatalf("Error when requesting credentials: %s", err) - } +func TestBackendConfig_Authentication_SSO(t *testing.T) { + configtesting.SSO(t, &testDriver{}) +} - if diff := cmp.Diff(credentials, tc.ExpectedCredentialsValue, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" { - t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) - } - }) - } +func TestBackendConfig_Authentication_LegacySSO(t *testing.T) { + configtesting.LegacySSO(t, &testDriver{}) } func TestBackendConfig_Region(t *testing.T) {