diff --git a/cmd/root.go b/cmd/root.go index d8c94f9..b05dea9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -25,9 +25,9 @@ var ( "CCC-Taxonomy": { Strikes.SQLFeatures, Strikes.AutomatedBackups, + Strikes.MultiRegion, // Strikes.VerticalScaling, // Strikes.Replication, - // Strikes.MultiRegion, // Strikes.BackupRecovery, // Strikes.Encryption, // Strikes.RBAC, diff --git a/example-config.yml b/example-config.yml index 66597a0..436075b 100644 --- a/example-config.yml +++ b/example-config.yml @@ -11,6 +11,7 @@ raids: config: instance_identifier: unique-id-name database: test + primary_region: us-east-1 host: localhost password: password port: 3306 diff --git a/strikes/AutomatedBackups.go b/strikes/AutomatedBackups.go index 967f1f3..a4ae071 100644 --- a/strikes/AutomatedBackups.go +++ b/strikes/AutomatedBackups.go @@ -23,7 +23,7 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik Movements: make(map[string]raidengine.MovementResult), } - // Movement + // Get Configuration cfg, err := getAWSConfig() if err != nil { result.Message = err.Error() @@ -37,10 +37,10 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik return } - autmatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg) - result.Movements["CheckForDBInstanceAutomatedBackups"] = autmatedBackupsMovement - if !autmatedBackupsMovement.Passed { - result.Message = autmatedBackupsMovement.Message + automatedBackupsMovement := checkRDSAutomatedBackupMovement(cfg) + result.Movements["CheckForDBInstanceAutomatedBackups"] = automatedBackupsMovement + if !automatedBackupsMovement.Passed { + result.Message = automatedBackupsMovement.Message return } @@ -49,31 +49,6 @@ func (a *Strikes) AutomatedBackups() (strikeName string, result raidengine.Strik return } -func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) { - // check if the instance is available - result = raidengine.MovementResult{ - Description: "Check if the instance is available/exists", - Function: utils.CallerPath(0), - } - - rdsClient := rds.NewFromConfig(cfg) - identifier, _ := getDBInstanceIdentifier() - - input := &rds.DescribeDBInstancesInput{ - DBInstanceIdentifier: aws.String(identifier), - } - - instances, err := rdsClient.DescribeDBInstances(context.TODO(), input) - if err != nil { - // Handle error - result.Message = err.Error() - result.Passed = false - return - } - result.Passed = len(instances.DBInstances) > 0 - return -} - func checkRDSAutomatedBackupMovement(cfg aws.Config) (result raidengine.MovementResult) { result = raidengine.MovementResult{ @@ -82,10 +57,10 @@ func checkRDSAutomatedBackupMovement(cfg aws.Config) (result raidengine.Movement } rdsClient := rds.NewFromConfig(cfg) - identifier, _ := getDBInstanceIdentifier() + instanceIdentifier, _ := getHostDBInstanceIdentifier() input := &rds.DescribeDBInstanceAutomatedBackupsInput{ - DBInstanceIdentifier: aws.String(identifier), + DBInstanceIdentifier: aws.String(instanceIdentifier), } backups, err := rdsClient.DescribeDBInstanceAutomatedBackups(context.TODO(), input) diff --git a/strikes/MultiRegion.go b/strikes/MultiRegion.go new file mode 100644 index 0000000..ac2adba --- /dev/null +++ b/strikes/MultiRegion.go @@ -0,0 +1,98 @@ +package strikes + +import ( + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/privateerproj/privateer-sdk/raidengine" + "github.com/privateerproj/privateer-sdk/utils" +) + +func (a *Strikes) MultiRegion() (strikeName string, result raidengine.StrikeResult) { + strikeName = "MultiRegion" + result = raidengine.StrikeResult{ + Passed: false, + Description: "Check if AWS RDS instance has multi region. This strike only checks for a read replica in a seperate region", + DocsURL: "https://www.github.com/krumIO/raid-rds", + ControlID: "CCC-Taxonomy-1", + Movements: make(map[string]raidengine.MovementResult), + } + + // Get Configuration + cfg, err := getAWSConfig() + if err != nil { + result.Message = err.Error() + return + } + + rdsInstanceMovement := checkRDSInstanceMovement(cfg) + result.Movements["CheckForDBInstance"] = rdsInstanceMovement + if !rdsInstanceMovement.Passed { + result.Message = rdsInstanceMovement.Message + return + } + + multiRegionMovement := checkRDSMultiRegionMovement(cfg) + result.Movements["CheckForMultiRegionDBInstances"] = multiRegionMovement + if !multiRegionMovement.Passed { + result.Message = multiRegionMovement.Message + return + } + + result.Passed = true + result.Message = "Completed Successfully" + + return +} + +func checkRDSMultiRegionMovement(cfg aws.Config) (result raidengine.MovementResult) { + + result = raidengine.MovementResult{ + Description: "Check if the instance has multi region enabled", + Function: utils.CallerPath(0), + } + instanceIdentifier, _ := getHostDBInstanceIdentifier() + + instance, _ := getRDSInstanceFromIdentifier(cfg, instanceIdentifier) + + // get read replicas from the instance + readReplicas := instance.DBInstances[0].ReadReplicaDBInstanceIdentifiers + + if len(readReplicas) == 0 { + result.Passed = false + result.Message = "Multi Region instances not found" + return + } + + hostRDSRegion, _ := getHostRDSRegion() + + // loop over the read replicas and check if they are in a different region + for _, replica := range readReplicas { + // we are getting the instance identifier the read replicas + // get instance from the replica identifier + replicaInstance, err := getRDSInstanceFromIdentifier(cfg, replica) + + if err != nil { + result.Passed = false + result.Message = err.Error() + return + } + + if len(replicaInstance.DBInstances) == 0 { + result.Passed = false + result.Message = "Cannot access the replica instance " + replica + return + } + + // check if replica region matches the host region + az := *replicaInstance.DBInstances[0].AvailabilityZone + // db instance doesnt contain the region so we need to remove the last character from the az + if az[:len(az)-1] == hostRDSRegion { + result.Passed = false + result.Message = "Multi Region instances not found" + return + } + } + + result.Passed = true + return + +} diff --git a/strikes/MultiRegion_test.go b/strikes/MultiRegion_test.go new file mode 100644 index 0000000..34b6c67 --- /dev/null +++ b/strikes/MultiRegion_test.go @@ -0,0 +1,32 @@ +package strikes + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/spf13/viper" +) + +func TestMultiRegion(t *testing.T) { + viper.AddConfigPath("../") + viper.SetConfigName("config") + viper.SetConfigType("yaml") + err := viper.ReadInConfig() + + if err != nil { + fmt.Println("Config file not found...") + return + } + + strikes := Strikes{} + strikeName, result := strikes.MultiRegion() + + fmt.Println(strikeName) + b, err := json.MarshalIndent(result, "", " ") + if err != nil { + fmt.Println(err) + } + fmt.Print(string(b)) + fmt.Println() +} diff --git a/strikes/common.go b/strikes/common.go index 0530851..22f3520 100644 --- a/strikes/common.go +++ b/strikes/common.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/rds" hclog "github.com/hashicorp/go-hclog" "github.com/privateerproj/privateer-sdk/raidengine" "github.com/privateerproj/privateer-sdk/utils" @@ -32,17 +33,25 @@ func getDBConfig() (string, error) { return "", errors.New("database url must be set in the config file") } -func getDBInstanceIdentifier() (string, error) { +func getHostDBInstanceIdentifier() (string, error) { if viper.IsSet("raids.RDS.aws.config.instance_identifier") { return viper.GetString("raids.RDS.aws.config.instance_identifier"), nil } return "", errors.New("database instance identifier must be set in the config file") } +func getHostRDSRegion() (string, error) { + if viper.IsSet("raids.RDS.aws.config.primary_region") { + return viper.GetString("raids.RDS.aws.config.primary_region"), nil + } + return "", errors.New("database instance identifier must be set in the config file") +} + func getAWSConfig() (cfg aws.Config, err error) { if viper.IsSet("raids.RDS.aws.creds") && viper.IsSet("raids.RDS.aws.creds.aws_access_key") && - viper.IsSet("raids.RDS.aws.creds.aws_secret_key") { + viper.IsSet("raids.RDS.aws.creds.aws_secret_key") && + viper.IsSet("raids.RDS.aws.creds.aws_region") { access_key := viper.GetString("raids.RDS.aws.creds.aws_access_key") secret_key := viper.GetString("raids.RDS.aws.creds.aws_secret_key") @@ -68,3 +77,34 @@ func connectToDb() (result raidengine.MovementResult) { result.Passed = true return } + +func checkRDSInstanceMovement(cfg aws.Config) (result raidengine.MovementResult) { + // check if the instance is available + result = raidengine.MovementResult{ + Description: "Check if the instance is available/exists", + Function: utils.CallerPath(0), + } + + instanceIdentifier, _ := getHostDBInstanceIdentifier() + + instance, err := getRDSInstanceFromIdentifier(cfg, instanceIdentifier) + if err != nil { + // Handle error + result.Message = err.Error() + result.Passed = false + return + } + result.Passed = len(instance.DBInstances) > 0 + return +} + +func getRDSInstanceFromIdentifier(cfg aws.Config, identifier string) (instance *rds.DescribeDBInstancesOutput, err error) { + rdsClient := rds.NewFromConfig(cfg) + + input := &rds.DescribeDBInstancesInput{ + DBInstanceIdentifier: aws.String(identifier), + } + + instance, err = rdsClient.DescribeDBInstances(context.TODO(), input) + return +}