From 140898c01ec3eef91c2d04e92d363757008a3d1a Mon Sep 17 00:00:00 2001 From: James Kwon Date: Thu, 27 Jul 2023 16:49:20 -0400 Subject: [PATCH] Refactor db subnet groups --- aws/aws.go | 2 +- aws/rds_subnet_group.go | 39 +++++------- aws/rds_subnet_group_test.go | 108 ++++++++++++++++++++++------------ aws/rds_subnet_group_types.go | 12 ++-- 4 files changed, 91 insertions(+), 70 deletions(-) diff --git a/aws/aws.go b/aws/aws.go index 3da6b14b..210ff5af 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -941,7 +941,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // Note: the `DescribeDBSubnetGroups` API response does not contain any information // about when the subnet group was created, so we cannot apply the `excludeAfter` filter - subnetGroups, err := getAllRdsDbSubnetGroups(cloudNukeSession, configObj) + subnetGroups, err := dbSubnetGroups.getAll(configObj) if err != nil { ge := report.GeneralError{ Error: err, diff --git a/aws/rds_subnet_group.go b/aws/rds_subnet_group.go index 8a566603..05c4aa96 100644 --- a/aws/rds_subnet_group.go +++ b/aws/rds_subnet_group.go @@ -7,7 +7,6 @@ import ( "github.com/aws/aws-sdk-go/aws" awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/rds" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" @@ -16,10 +15,10 @@ import ( commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" ) -func waitUntilRdsDbSubnetGroupDeleted(svc *rds.RDS, name *string) error { +func (dsg DBSubnetGroups) waitUntilRdsDbSubnetGroupDeleted(name *string) error { // wait up to 15 minutes for i := 0; i < 90; i++ { - _, err := svc.DescribeDBSubnetGroups(&rds.DescribeDBSubnetGroupsInput{DBSubnetGroupName: name}) + _, err := dsg.Client.DescribeDBSubnetGroups(&rds.DescribeDBSubnetGroupsInput{DBSubnetGroupName: name}) if err != nil { if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == rds.ErrCodeDBSubnetGroupNotFoundFault { return nil @@ -35,23 +34,15 @@ func waitUntilRdsDbSubnetGroupDeleted(svc *rds.RDS, name *string) error { return RdsDeleteError{name: *name} } -func shouldIncludeDbSubnetGroup(subnetGroup *rds.DBSubnetGroup, configObj config.Config) bool { - return config.ShouldInclude( - aws.StringValue(subnetGroup.DBSubnetGroupName), - configObj.DBSubnetGroups.IncludeRule.NamesRegExp, - configObj.DBSubnetGroups.ExcludeRule.NamesRegExp, - ) -} - -func getAllRdsDbSubnetGroups(session *session.Session, configObj config.Config) ([]*string, error) { - svc := rds.New(session) - +func (dsg DBSubnetGroups) getAll(configObj config.Config) ([]*string, error) { var names []*string - err := svc.DescribeDBSubnetGroupsPages( + err := dsg.Client.DescribeDBSubnetGroupsPages( &rds.DescribeDBSubnetGroupsInput{}, func(page *rds.DescribeDBSubnetGroupsOutput, lastPage bool) bool { for _, subnetGroup := range page.DBSubnetGroups { - if shouldIncludeDbSubnetGroup(subnetGroup, configObj) { + if configObj.DBSubnetGroups.ShouldInclude(config.ResourceValue{ + Name: subnetGroup.DBSubnetGroupName, + }) { names = append(names, subnetGroup.DBSubnetGroupName) } } @@ -65,19 +56,17 @@ func getAllRdsDbSubnetGroups(session *session.Session, configObj config.Config) return names, nil } -func nukeAllRdsDbSubnetGroups(session *session.Session, names []*string) error { - svc := rds.New(session) - +func (dsg DBSubnetGroups) nukeAll(names []*string) error { if len(names) == 0 { - logging.Logger.Debugf("No DB Subnet groups in region %s", *session.Config.Region) + logging.Logger.Debugf("No DB Subnet groups in region %s", dsg.Region) return nil } - logging.Logger.Debugf("Deleting all DB Subnet groups in region %s", *session.Config.Region) + logging.Logger.Debugf("Deleting all DB Subnet groups in region %s", dsg.Region) deletedNames := []*string{} for _, name := range names { - _, err := svc.DeleteDBSubnetGroup(&rds.DeleteDBSubnetGroupInput{ + _, err := dsg.Client.DeleteDBSubnetGroup(&rds.DeleteDBSubnetGroupInput{ DBSubnetGroupName: name, }) @@ -94,7 +83,7 @@ func nukeAllRdsDbSubnetGroups(session *session.Session, names []*string) error { telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking RDS DB subnet group", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": dsg.Region, }) } else { deletedNames = append(deletedNames, name) @@ -105,7 +94,7 @@ func nukeAllRdsDbSubnetGroups(session *session.Session, names []*string) error { if len(deletedNames) > 0 { for _, name := range deletedNames { - err := waitUntilRdsDbSubnetGroupDeleted(svc, name) + err := dsg.waitUntilRdsDbSubnetGroupDeleted(name) if err != nil { logging.Logger.Errorf("[Failed] %s", err) return errors.WithStackTrace(err) @@ -113,6 +102,6 @@ func nukeAllRdsDbSubnetGroups(session *session.Session, names []*string) error { } } - logging.Logger.Debugf("[OK] %d RDS DB subnet group(s) nuked in %s", len(deletedNames), *session.Config.Region) + logging.Logger.Debugf("[OK] %d RDS DB subnet group(s) nuked in %s", len(deletedNames), dsg.Region) return nil } diff --git a/aws/rds_subnet_group_test.go b/aws/rds_subnet_group_test.go index f7d52ef1..8804b0ba 100644 --- a/aws/rds_subnet_group_test.go +++ b/aws/rds_subnet_group_test.go @@ -1,59 +1,91 @@ package aws import ( - "fmt" - "strings" - "testing" - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/telemetry" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/gruntwork-io/go-commons/errors" - "github.com/gruntwork-io/terratest/modules/aws" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "regexp" + "testing" ) -func createTestRDSSubnetGroup(t *testing.T, session *session.Session, name string) { - t.Logf("Creating RDS subnet group in region %s", awsgo.StringValue(session.Config.Region)) - - defaultVpc := aws.GetDefaultVpc(t, *session.Config.Region) - defaultAzSubnets := aws.GetDefaultSubnetIDsForVpc(t, *defaultVpc) - var subnetIds []*string - for _, subnet := range defaultAzSubnets { - subnetIds = append(subnetIds, awsgo.String(subnet)) - } +type mockedDBSubnetGroups struct { + rdsiface.RDSAPI + DescribeDBSubnetGroupsOutput rds.DescribeDBSubnetGroupsOutput + DeleteDBSubnetGroupOutput rds.DeleteDBSubnetGroupOutput +} - svc := rds.New(session) - _, err := svc.CreateDBSubnetGroup(&rds.CreateDBSubnetGroupInput{ - DBSubnetGroupName: awsgo.String(name), - DBSubnetGroupDescription: awsgo.String(fmt.Sprintf("Test DB subnet for %s", t.Name())), - SubnetIds: subnetIds, - }) +func (m mockedDBSubnetGroups) DescribeDBSubnetGroups(*rds.DescribeDBSubnetGroupsInput) (*rds.DescribeDBSubnetGroupsOutput, error) { + return &m.DescribeDBSubnetGroupsOutput, nil +} - require.NoError(t, err) +func (m mockedDBSubnetGroups) DeleteDBSubnetGroup(*rds.DeleteDBSubnetGroupInput) (*rds.DeleteDBSubnetGroupOutput, error) { + return &m.DeleteDBSubnetGroupOutput, nil } -func TestNukeRDSSubnetGroup(t *testing.T) { +func TestDBSubnetGroups_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, errors.WithStackTrace(err)) + testName1 := "test-db-subnet-group1" + testName2 := "test-db-subnet-group2" + dsg := DBSubnetGroups{ + Client: mockedDBSubnetGroups{ + DescribeDBSubnetGroupsOutput: rds.DescribeDBSubnetGroupsOutput{ + DBSubnetGroups: []*rds.DBSubnetGroup{ + { + DBSubnetGroupName: awsgo.String(testName1), + }, + { + DBSubnetGroupName: awsgo.String(testName2), + }, + }, + }, + }, + } + + tests := map[string]struct { + configObj config.ResourceType + expected []string + }{ + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testName1, testName2}, + }, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testName2}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := dsg.getAll(config.Config{ + DBSubnetGroups: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, awsgo.StringValueSlice(names)) + }) + } - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) +} - subnetGroupName := "cloud-nuke-test-" + util.UniqueID() - createTestRDSSubnetGroup(t, session, subnetGroupName) +func TestDBSubnetGroups_NukeAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() + + dsg := DBSubnetGroups{ + Client: mockedDBSubnetGroups{ + DeleteDBSubnetGroupOutput: rds.DeleteDBSubnetGroupOutput{}, + }, + } - defer func() { - nukeAllRdsDbSubnetGroups(session, []*string{&subnetGroupName}) - subnetGroupNames, _ := getAllRdsDbSubnetGroups(session, config.Config{}) - assert.NotContains(t, awsgo.StringValueSlice(subnetGroupNames), strings.ToLower(subnetGroupName)) - }() + err := dsg.nukeAll([]*string{awsgo.String("test")}) + require.NoError(t, err) } diff --git a/aws/rds_subnet_group_types.go b/aws/rds_subnet_group_types.go index bc47e8b0..4692f84e 100644 --- a/aws/rds_subnet_group_types.go +++ b/aws/rds_subnet_group_types.go @@ -13,23 +13,23 @@ type DBSubnetGroups struct { InstanceNames []string } -func (instance DBSubnetGroups) ResourceName() string { +func (dsg DBSubnetGroups) ResourceName() string { return "rds-subnet-group" } // ResourceIdentifiers - The instance names of the rds db instances -func (instance DBSubnetGroups) ResourceIdentifiers() []string { - return instance.InstanceNames +func (dsg DBSubnetGroups) ResourceIdentifiers() []string { + return dsg.InstanceNames } -func (instance DBSubnetGroups) MaxBatchSize() int { +func (dsg DBSubnetGroups) MaxBatchSize() int { // Tentative batch size to ensure AWS doesn't throttle return 49 } // Nuke - nuke 'em all!!! -func (instance DBSubnetGroups) Nuke(session *session.Session, identifiers []string) error { - if err := nukeAllRdsDbSubnetGroups(session, awsgo.StringSlice(identifiers)); err != nil { +func (dsg DBSubnetGroups) Nuke(session *session.Session, identifiers []string) error { + if err := dsg.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) }