Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

provider: Convert and enforce most data sources to AWS Go SDK pointer conversion functions during conditionals #17718

Merged
merged 1 commit into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .semgrep.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,41 @@ rules:
- pattern-not: '*$LHS2 = *$RHS'
severity: WARNING

- id: prefer-aws-go-sdk-pointer-conversion-conditional
languages: [go]
message: Prefer AWS Go SDK pointer conversion functions for dereferencing during conditionals, e.g. aws.StringValue()
paths:
exclude:
- aws/cloudfront_distribution_configuration_structure.go
- aws/cloudfront_distribution_configuration_structure_test.go
- aws/config.go
- aws/data_source_aws_route*
- aws/ecs_task_definition_equivalency.go
- aws/opsworks_layers.go
- aws/resource*
- aws/structure.go
- aws/internal/generators/
- aws/internal/keyvaluetags/
- aws/internal/naming/
- awsproviderlint/vendor/
include:
- aws/
patterns:
- pattern-either:
- pattern: '$LHS == *$RHS'
- pattern: '$LHS != *$RHS'
- pattern: '$LHS > *$RHS'
- pattern: '$LHS < *$RHS'
- pattern: '$LHS >= *$RHS'
- pattern: '$LHS <= *$RHS'
- pattern: '*$LHS == $RHS'
- pattern: '*$LHS != $RHS'
- pattern: '*$LHS > $RHS'
- pattern: '*$LHS < $RHS'
- pattern: '*$LHS >= $RHS'
- pattern: '*$LHS <= $RHS'
severity: WARNING

- id: aws-go-sdk-pointer-conversion-ResourceData-SetId
fix: d.SetId(aws.StringValue($VALUE))
languages: [go]
Expand Down
12 changes: 6 additions & 6 deletions aws/data_source_aws_acm_certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func dataSourceAwsAcmCertificateRead(d *schema.ResourceData, meta interface{}) e
log.Printf("[DEBUG] Reading ACM Certificate: %s", params)
err := conn.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool {
for _, cert := range page.CertificateSummaryList {
if *cert.DomainName == target {
if aws.StringValue(cert.DomainName) == target {
arns = append(arns, cert.CertificateArn)
}
}
Expand Down Expand Up @@ -125,7 +125,7 @@ func dataSourceAwsAcmCertificateRead(d *schema.ResourceData, meta interface{}) e

if filterTypesOk {
for _, certType := range typesStrings {
if *certificate.Type == *certType {
if aws.StringValue(certificate.Type) == aws.StringValue(certType) {
// We do not have a candidate certificate
if matchedCertificate == nil {
matchedCertificate = certificate
Expand Down Expand Up @@ -185,18 +185,18 @@ func dataSourceAwsAcmCertificateRead(d *schema.ResourceData, meta interface{}) e
}

func mostRecentAcmCertificate(i, j *acm.CertificateDetail) (*acm.CertificateDetail, error) {
if *i.Status != *j.Status {
if aws.StringValue(i.Status) != aws.StringValue(j.Status) {
return nil, fmt.Errorf("most_recent filtering on different ACM certificate statues is not supported")
}
// Cover IMPORTED and ISSUED AMAZON_ISSUED certificates
if *i.Status == acm.CertificateStatusIssued {
if (*i.NotBefore).After(*j.NotBefore) {
if aws.StringValue(i.Status) == acm.CertificateStatusIssued {
if aws.TimeValue(i.NotBefore).After(aws.TimeValue(j.NotBefore)) {
return i, nil
}
return j, nil
}
// Cover non-ISSUED AMAZON_ISSUED certificates
if (*i.CreatedAt).After(*j.CreatedAt) {
if aws.TimeValue(i.CreatedAt).After(aws.TimeValue(j.CreatedAt)) {
return i, nil
}
return j, nil
Expand Down
7 changes: 4 additions & 3 deletions aws/data_source_aws_ami_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ func dataSourceAwsAmiIdsRead(d *schema.ResourceData, meta interface{}) error {
// Check for a very rare case where the response would include no
// image name. No name means nothing to attempt a match against,
// therefore we are skipping such image.
if image.Name == nil || *image.Name == "" {
name := aws.StringValue(image.Name)
if name == "" {
log.Printf("[WARN] Unable to find AMI name to match against "+
"for image ID %q owned by %q, nothing to do.",
*image.ImageId, *image.OwnerId)
aws.StringValue(image.ImageId), aws.StringValue(image.OwnerId))
continue
}
if r.MatchString(*image.Name) {
if r.MatchString(name) {
filteredImages = append(filteredImages, image)
}
}
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_api_gateway_rest_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func dataSourceAwsApiGatewayRestApiRead(d *schema.ResourceData, meta interface{}

err = conn.GetResourcesPages(resourceParams, func(page *apigateway.GetResourcesOutput, lastPage bool) bool {
for _, item := range page.Items {
if *item.Path == "/" {
if aws.StringValue(item.Path) == "/" {
d.Set("root_resource_id", item.Id)
return false
}
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_cloudfront_cache_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func dataSourceAwsCloudFrontCachePolicyFindByName(d *schema.ResourceData, conn *
}

for _, policySummary := range resp.CachePolicyList.Items {
if *policySummary.CachePolicy.CachePolicyConfig.Name == d.Get("name").(string) {
if aws.StringValue(policySummary.CachePolicy.CachePolicyConfig.Name) == d.Get("name").(string) {
cachePolicy = policySummary.CachePolicy
break
}
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_cloudfront_origin_request_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func dataSourceAwsCloudFrontOriginRequestPolicyFindByName(d *schema.ResourceData
}

for _, policySummary := range resp.OriginRequestPolicyList.Items {
if *policySummary.OriginRequestPolicy.OriginRequestPolicyConfig.Name == d.Get("name").(string) {
if aws.StringValue(policySummary.OriginRequestPolicy.OriginRequestPolicyConfig.Name) == d.Get("name").(string) {
originRequestPolicy = policySummary.OriginRequestPolicy
break
}
Expand Down
5 changes: 3 additions & 2 deletions aws/data_source_aws_ebs_default_kms_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
Expand Down Expand Up @@ -44,8 +45,8 @@ func testAccCheckDataSourceAwsEBSDefaultKmsKey(n string) resource.TestCheckFunc

attr := rs.Primary.Attributes["key_arn"]

if attr != *actual.KmsKeyId {
return fmt.Errorf("EBS default KMS key is not the expected value (%v)", actual.KmsKeyId)
if attr != aws.StringValue(actual.KmsKeyId) {
return fmt.Errorf("EBS default KMS key is not the expected value (%s)", aws.StringValue(actual.KmsKeyId))
}

return nil
Expand Down
5 changes: 3 additions & 2 deletions aws/data_source_aws_ebs_encryption_by_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strconv"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
Expand Down Expand Up @@ -45,8 +46,8 @@ func testAccCheckDataSourceAwsEBSEncryptionByDefault(n string) resource.TestChec

attr, _ := strconv.ParseBool(rs.Primary.Attributes["enabled"])

if attr != *actual.EbsEncryptionByDefault {
return fmt.Errorf("EBS encryption by default is not in expected state (%t)", *actual.EbsEncryptionByDefault)
if attr != aws.BoolValue(actual.EbsEncryptionByDefault) {
return fmt.Errorf("EBS encryption by default is not in expected state (%t)", aws.BoolValue(actual.EbsEncryptionByDefault))
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_elasticache_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func dataSourceAwsElastiCacheClusterRead(d *schema.ResourceData, meta interface{
d.Set("availability_zone", cluster.PreferredAvailabilityZone)

if cluster.NotificationConfiguration != nil {
if *cluster.NotificationConfiguration.TopicStatus == "active" {
if aws.StringValue(cluster.NotificationConfiguration.TopicStatus) == "active" {
d.Set("notification_topic_arn", cluster.NotificationConfiguration.TopicArn)
}
}
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_iam_server_certificate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestResourceSortByExpirationDate(t *testing.T) {
},
}
sort.Sort(certificateByExpiration(certs))
if *certs[0].ServerCertificateName != "latest" {
if aws.StringValue(certs[0].ServerCertificateName) != "latest" {
t.Fatalf("Expected first item to be %q, but was %q", "latest", *certs[0].ServerCertificateName)
}
}
Expand Down
10 changes: 5 additions & 5 deletions aws/data_source_aws_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func dataSourceAwsInstanceRead(d *schema.ResourceData, meta interface{}) error {
// loop through reservations, and remove terminated instances, populate instance slice
for _, res := range resp.Reservations {
for _, instance := range res.Instances {
if instance.State != nil && *instance.State.Name != "terminated" {
if instance.State != nil && aws.StringValue(instance.State.Name) != ec2.InstanceStateNameTerminated {
filteredInstances = append(filteredInstances, instance)
}
}
Expand All @@ -408,13 +408,13 @@ func dataSourceAwsInstanceRead(d *schema.ResourceData, meta interface{}) error {
instance = filteredInstances[0]
}

log.Printf("[DEBUG] aws_instance - Single Instance ID found: %s", *instance.InstanceId)
log.Printf("[DEBUG] aws_instance - Single Instance ID found: %s", aws.StringValue(instance.InstanceId))
if err := instanceDescriptionAttributes(d, instance, conn, ignoreTagsConfig); err != nil {
return err
}

if d.Get("get_password_data").(bool) {
passwordData, err := getAwsEc2InstancePasswordData(*instance.InstanceId, conn)
passwordData, err := getAwsEc2InstancePasswordData(aws.StringValue(instance.InstanceId), conn)
if err != nil {
return err
}
Expand Down Expand Up @@ -475,7 +475,7 @@ func instanceDescriptionAttributes(d *schema.ResourceData, instance *ec2.Instanc
// iterate through network interfaces, and set subnet, network_interface, public_addr
if len(instance.NetworkInterfaces) > 0 {
for _, ni := range instance.NetworkInterfaces {
if *ni.Attachment.DeviceIndex == 0 {
if aws.Int64Value(ni.Attachment.DeviceIndex) == 0 {
d.Set("subnet_id", ni.SubnetId)
d.Set("network_interface_id", ni.NetworkInterfaceId)
d.Set("associate_public_ip_address", ni.Association != nil)
Expand All @@ -497,7 +497,7 @@ func instanceDescriptionAttributes(d *schema.ResourceData, instance *ec2.Instanc
}

d.Set("ebs_optimized", instance.EbsOptimized)
if instance.SubnetId != nil && *instance.SubnetId != "" {
if aws.StringValue(instance.SubnetId) != "" {
d.Set("source_dest_check", instance.SourceDestCheck)
}

Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_kms_alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func dataSourceAwsKmsAliasRead(d *schema.ResourceData, meta interface{}) error {
log.Printf("[DEBUG] Reading KMS Alias: %s", params)
err := conn.ListAliasesPages(params, func(page *kms.ListAliasesOutput, lastPage bool) bool {
for _, entity := range page.Aliases {
if *entity.AliasName == target {
if aws.StringValue(entity.AliasName) == target {
alias = entity
return false
}
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_lb_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func dataSourceAwsLbListenerRead(d *schema.ResourceData, meta interface{}) error
return fmt.Errorf("no listener exists for load balancer: %s", lbArn)
}
for _, listener := range resp.Listeners {
if *listener.Port == int64(port.(int)) {
if aws.Int64Value(listener.Port) == int64(port.(int)) {
//log.Printf("[DEBUG] get listener arn for %s:%s: %s", lbArn, port, *listener.Port)
d.SetId(aws.StringValue(listener.ListenerArn))
return resourceAwsLbListenerRead(d, meta)
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_nat_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func dataSourceAwsNatGatewayRead(d *schema.ResourceData, meta interface{}) error
}

for _, address := range ngw.NatGatewayAddresses {
if *address.AllocationId != "" {
if aws.StringValue(address.AllocationId) != "" {
d.Set("allocation_id", address.AllocationId)
d.Set("network_interface_id", address.NetworkInterfaceId)
d.Set("private_ip", address.PrivateIp)
Expand Down
4 changes: 2 additions & 2 deletions aws/data_source_aws_sfn_state_machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func dataSourceAwsSfnStateMachineRead(d *schema.ResourceData, meta interface{})

err := conn.ListStateMachinesPages(params, func(page *sfn.ListStateMachinesOutput, lastPage bool) bool {
for _, sm := range page.StateMachines {
if *sm.Name == target {
arns = append(arns, *sm.StateMachineArn)
if aws.StringValue(sm.Name) == target {
arns = append(arns, aws.StringValue(sm.StateMachineArn))
}
}
return true
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_ssm_patch_baseline.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func dataAwsSsmPatchBaselineRead(d *schema.ResourceData, meta interface{}) error
var filteredBaselines []*ssm.PatchBaselineIdentity
if v, ok := d.GetOk("operating_system"); ok {
for _, baseline := range resp.BaselineIdentities {
if v.(string) == *baseline.OperatingSystem {
if v.(string) == aws.StringValue(baseline.OperatingSystem) {
filteredBaselines = append(filteredBaselines, baseline)
}
}
Expand Down
2 changes: 1 addition & 1 deletion aws/data_source_aws_subnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func dataSourceAwsSubnetRead(d *schema.ResourceData, meta interface{}) error {
d.Set("default_for_az", subnet.DefaultForAz)

for _, a := range subnet.Ipv6CidrBlockAssociationSet {
if *a.Ipv6CidrBlockState.State == "associated" { //we can only ever have 1 IPv6 block associated at once
if a.Ipv6CidrBlockState != nil && aws.StringValue(a.Ipv6CidrBlockState.State) == ec2.VpcCidrBlockStateCodeAssociated { //we can only ever have 1 IPv6 block associated at once
d.Set("ipv6_cidr_block_association_id", a.AssociationId)
d.Set("ipv6_cidr_block", a.Ipv6CidrBlock)
}
Expand Down
4 changes: 2 additions & 2 deletions aws/data_source_aws_workspaces_image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ func testAccCheckWorkspacesImageExists(n string, image *workspaces.WorkspaceImag
if err != nil {
return fmt.Errorf("Failed describe workspaces images: %w", err)
}
if len(resp.Images) == 0 {
if resp == nil || len(resp.Images) == 0 || resp.Images[0] == nil {
return fmt.Errorf("Workspace image %s was not found", rs.Primary.ID)
}
if *resp.Images[0].ImageId != rs.Primary.ID {
if aws.StringValue(resp.Images[0].ImageId) != rs.Primary.ID {
return fmt.Errorf("Workspace image ID mismatch - existing: %q, state: %q", *resp.Images[0].ImageId, rs.Primary.ID)
}

Expand Down