diff --git a/internal/terraform/aws/default.go b/internal/terraform/aws/default.go new file mode 100644 index 0000000..377da8f --- /dev/null +++ b/internal/terraform/aws/default.go @@ -0,0 +1,65 @@ +package aws + +import ( + "os" + + "github.com/carboniferio/carbonifer/internal/terraform/tfrefs" + "github.com/carboniferio/carbonifer/internal/utils" + tfjson "github.com/hashicorp/terraform-json" + log "github.com/sirupsen/logrus" + + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/session" +) + +func GetDefaults(awsConfig *tfjson.ProviderConfig, tfPlan *tfjson.Plan, terraformRefs *tfrefs.References) { + log.Debugf("Reading provider config %v", awsConfig.Name) + + region := getDefaultRegion(awsConfig, tfPlan) + if region != nil { + terraformRefs.ProviderConfigs["region"] = region.(string) + } +} + +func getDefaultRegion(awsConfig *tfjson.ProviderConfig, tfPlan *tfjson.Plan) interface{} { + var region interface{} + regionExpr := awsConfig.Expressions["region"] + if regionExpr != nil { + var err error + region, err = utils.GetValueOfExpression(regionExpr, tfPlan) + if err != nil { + log.Fatalf("Error getting region from provider config %v", err) + } + } + if region == nil { + if os.Getenv("AWS_DEFAULT_REGION") != "" { + region = os.Getenv("AWS_DEFAULT_REGION") + } + } + if region == nil { + if os.Getenv("AWS_REGION") != "" { + region = os.Getenv("AWS_REGION") + } + } + + // Check AWS Config file + if region == nil { + sess, err := session.NewSession() + if err != nil { + log.Fatalf("Error getting region from AWS config file %v", err) + } + if *sess.Config.Region != "" { + region = *sess.Config.Region + } + } + + // Check EC2 Instance Metadata + if region == nil { + sess := session.Must(session.NewSession()) + svc := ec2metadata.New(sess) + if svc.Available() { + region, _ = svc.Region() + } + } + return region +} diff --git a/internal/terraform/aws/default_test.go b/internal/terraform/aws/default_test.go new file mode 100644 index 0000000..66006f1 --- /dev/null +++ b/internal/terraform/aws/default_test.go @@ -0,0 +1,118 @@ +package aws + +import ( + "io" + "os" + "path/filepath" + "testing" + + tfjson "github.com/hashicorp/terraform-json" + "github.com/stretchr/testify/assert" +) + +func Test_getDefaultRegion_providerConstant(t *testing.T) { + awsConfigs := &tfjson.ProviderConfig{ + Name: "aws", + Expressions: map[string]*tfjson.Expression{ + "region": { + ExpressionData: &tfjson.ExpressionData{ + ConstantValue: "test1", + }, + }, + }, + } + + tfPlan := &tfjson.Plan{} + + region := getDefaultRegion(awsConfigs, tfPlan) + assert.Equal(t, "test1", region) + +} + +func Test_getDefaultRegion_providerVariable(t *testing.T) { + awsConfigs := &tfjson.ProviderConfig{ + Name: "aws", + Expressions: map[string]*tfjson.Expression{ + "region": { + ExpressionData: &tfjson.ExpressionData{ + References: []string{"var.region"}, + }, + }, + }, + } + + tfPlan := &tfjson.Plan{ + Variables: map[string]*tfjson.PlanVariable{ + "region": { + Value: "test2", + }, + }, + } + + region := getDefaultRegion(awsConfigs, tfPlan) + assert.Equal(t, "test2", region) + +} + +func Test_getDefaultRegion_EnvVar(t *testing.T) { + awsConfigs := &tfjson.ProviderConfig{ + Name: "aws", + Expressions: map[string]*tfjson.Expression{}, + } + + tfPlan := &tfjson.Plan{} + + t.Setenv("AWS_REGION", "test3") + + region := getDefaultRegion(awsConfigs, tfPlan) + assert.Equal(t, "test3", region) + +} + +func Test_getDefaultRegion_EnvDefaultVar(t *testing.T) { + awsConfigs := &tfjson.ProviderConfig{ + Name: "aws", + Expressions: map[string]*tfjson.Expression{}, + } + + tfPlan := &tfjson.Plan{} + + t.Setenv("AWS_DEFAULT_REGION", "test4") + + region := getDefaultRegion(awsConfigs, tfPlan) + assert.Equal(t, "test4", region) + +} + +func Test_getDefaultRegion_AWSConfigFile(t *testing.T) { + // Create a temporary directory + tmpDir := t.TempDir() + + // Create AWS config file + awsConfigFile := filepath.Join(tmpDir, "config") + f, err := os.Create(awsConfigFile) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + _, err = io.WriteString(f, "[default]\nregion = region_from_config_file\n") + if err != nil { + t.Fatal(err) + } + + // Set the AWS_SDK_LOAD_CONFIG environment variable + t.Setenv("AWS_SDK_LOAD_CONFIG", "1") + t.Setenv("AWS_CONFIG_FILE", awsConfigFile) + + awsConfigs := &tfjson.ProviderConfig{ + Name: "aws", + Expressions: map[string]*tfjson.Expression{}, + } + + tfPlan := &tfjson.Plan{} + + region := getDefaultRegion(awsConfigs, tfPlan) + assert.Equal(t, "region_from_config_file", region) + +} diff --git a/internal/terraform/resources.go b/internal/terraform/resources.go index 376ee24..00aa32d 100644 --- a/internal/terraform/resources.go +++ b/internal/terraform/resources.go @@ -2,7 +2,6 @@ package terraform import ( "encoding/json" - "os" "strings" "github.com/carboniferio/carbonifer/internal/resources" @@ -66,20 +65,7 @@ func GetResources(tfPlan *tfjson.Plan) (map[string]resources.Resource, error) { // Get default values for provider, resConfig := range tfPlan.Config.ProviderConfigs { if provider == "aws" { - log.Debugf("Reading provider config %v", resConfig.Name) - // TODO #58 Improve way we get default regions (env var, profile...) - var region interface{} - regionExpr := resConfig.Expressions["region"] - if regionExpr != nil { - region = regionExpr.ConstantValue - } else { - if os.Getenv("AWS_REGION") != "" { - region = os.Getenv("AWS_REGION") - } - } - if region != nil { - terraformRefs.ProviderConfigs["region"] = region.(string) - } + aws.GetDefaults(resConfig, tfPlan, &terraformRefs) } } diff --git a/internal/utils/expressions.go b/internal/utils/expressions.go new file mode 100644 index 0000000..de07dfb --- /dev/null +++ b/internal/utils/expressions.go @@ -0,0 +1,28 @@ +package utils + +import ( + "errors" + "fmt" + "strings" + + tfjson "github.com/hashicorp/terraform-json" +) + +func GetValueOfExpression(expression *tfjson.Expression, tfPlan *tfjson.Plan) (interface{}, error) { + if fmt.Sprintf("%T", expression.ConstantValue) != "*tfjson.unknownConstantValue" && expression.ConstantValue != nil { + // It's a known value, return it as is + return expression.ConstantValue, nil + } + + // Constant value is not set or unknown, look up references + for _, reference := range expression.References { + ref := strings.TrimPrefix(reference, "var.") + if val, ok := tfPlan.Variables[ref]; ok { + return val.Value, nil + } + } + + // No variables were found + return nil, errors.New("no value found for expression") + +}