diff --git a/internal/service/ec2/security_group_rule.go b/internal/service/ec2/security_group_rule.go index b8b1b6f67df7..25c382142110 100644 --- a/internal/service/ec2/security_group_rule.go +++ b/internal/service/ec2/security_group_rule.go @@ -92,9 +92,8 @@ func ResourceSecurityGroupRule() *schema.Resource { }, "cidr_blocks": { - Type: schema.TypeList, + Type: schema.TypeSet, Optional: true, - ForceNew: true, Elem: &schema.Schema{ Type: schema.TypeString, ValidateFunc: verify.ValidCIDRNetworkAddress, @@ -103,9 +102,8 @@ func ResourceSecurityGroupRule() *schema.Resource { }, "ipv6_cidr_blocks": { - Type: schema.TypeList, + Type: schema.TypeSet, Optional: true, - ForceNew: true, Elem: &schema.Schema{ Type: schema.TypeString, ValidateFunc: verify.ValidCIDRNetworkAddress, @@ -349,6 +347,12 @@ func resourceSecurityGroupRuleUpdate(d *schema.ResourceData, meta interface{}) e } } + if d.HasChange("cidr_blocks") || d.HasChange("ipv6_cidr_blocks") { + if err := resourceSecurityGroupRuleCidrUpdate(conn, d); err != nil { + return err + } + } + return resourceSecurityGroupRuleRead(d, meta) } @@ -641,7 +645,7 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss } if raw, ok := d.GetOk("cidr_blocks"); ok { - list := raw.([]interface{}) + list := raw.(*schema.Set).List() perm.IpRanges = make([]*ec2.IpRange, len(list)) for i, v := range list { cidrIP, ok := v.(string) @@ -657,7 +661,7 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss } if raw, ok := d.GetOk("ipv6_cidr_blocks"); ok { - list := raw.([]interface{}) + list := raw.(*schema.Set).List() perm.Ipv6Ranges = make([]*ec2.Ipv6Range, len(list)) for i, v := range list { cidrIP, ok := v.(string) @@ -691,6 +695,63 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss return &perm, nil } +func expandCidrIPPerm(d *schema.ResourceData, cidrBlocks *schema.Set, ipv6CidrBlocks *schema.Set) *ec2.IpPermission { + var perm ec2.IpPermission + + protocol := ProtocolForValue(d.Get("protocol").(string)) + perm.IpProtocol = aws.String(protocol) + + if protocol != "-1" { + perm.FromPort = aws.Int64(int64(d.Get("from_port").(int))) + perm.ToPort = aws.Int64(int64(d.Get("to_port").(int))) + } + + description := d.Get("description").(string) + + if cidrBlocks.Len() > 0 { + list := cidrBlocks.List() + perm.IpRanges = make([]*ec2.IpRange, len(list)) + for i, v := range list { + cidrIP := v.(string) + perm.IpRanges[i] = &ec2.IpRange{CidrIp: aws.String(cidrIP)} + if description != "" { + perm.IpRanges[i].Description = aws.String(description) + } + } + } + + if ipv6CidrBlocks.Len() > 0 { + list := ipv6CidrBlocks.List() + perm.Ipv6Ranges = make([]*ec2.Ipv6Range, len(list)) + for i, v := range list { + cidrIP := v.(string) + perm.Ipv6Ranges[i] = &ec2.Ipv6Range{CidrIpv6: aws.String(cidrIP)} + if description != "" { + perm.Ipv6Ranges[i].Description = aws.String(description) + } + } + } + + return &perm +} + +// Get the sets of removed and added items in a set of +func getSetChange(d *schema.ResourceData, name string) (removed *schema.Set, added *schema.Set) { + o, n := d.GetChange(name) + if o == nil { + o = new(schema.Set) + } + if n == nil { + n = new(schema.Set) + } + old := o.(*schema.Set) + new_ := n.(*schema.Set) + removed = old.Difference(new_) + added = new_.Difference(old) + + return removed, added +} + func setFromIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup, rule *ec2.IpPermission) { isVPC := aws.StringValue(sg.VpcId) != "" @@ -698,15 +759,15 @@ func setFromIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup, rule *ec2.IpPe d.Set("to_port", rule.ToPort) d.Set("protocol", rule.IpProtocol) - var cb []string + cb := &schema.Set{F: schema.HashString} for _, c := range rule.IpRanges { - cb = append(cb, *c.CidrIp) + cb.Add(*c.CidrIp) } d.Set("cidr_blocks", cb) - var ipv6 []string + ipv6 := &schema.Set{F: schema.HashString} for _, ip := range rule.Ipv6Ranges { - ipv6 = append(ipv6, *ip.CidrIpv6) + ipv6.Add(*ip.CidrIpv6) } d.Set("ipv6_cidr_blocks", ipv6) @@ -741,16 +802,14 @@ func setFromIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup, rule *ec2.IpPe func descriptionFromIPPerm(d *schema.ResourceData, rule *ec2.IpPermission) string { // probe IpRanges - cidrIps := make(map[string]bool) + var cidrIps *schema.Set if raw, ok := d.GetOk("cidr_blocks"); ok { - for _, v := range raw.([]interface{}) { - cidrIps[v.(string)] = true - } + cidrIps = raw.(*schema.Set) } - if len(cidrIps) > 0 { + if cidrIps != nil && cidrIps.Len() > 0 { for _, c := range rule.IpRanges { - if _, ok := cidrIps[*c.CidrIp]; !ok { + if !cidrIps.Contains(*c.CidrIp) { continue } @@ -761,16 +820,14 @@ func descriptionFromIPPerm(d *schema.ResourceData, rule *ec2.IpPermission) strin } // probe Ipv6Ranges - cidrIpv6s := make(map[string]bool) + var cidrIpv6s *schema.Set if raw, ok := d.GetOk("ipv6_cidr_blocks"); ok { - for _, v := range raw.([]interface{}) { - cidrIpv6s[v.(string)] = true - } + cidrIpv6s = raw.(*schema.Set) } - if len(cidrIpv6s) > 0 { + if cidrIpv6s != nil && cidrIpv6s.Len() > 0 { for _, ip := range rule.Ipv6Ranges { - if _, ok := cidrIpv6s[*ip.CidrIpv6]; !ok { + if !cidrIpv6s.Contains(*ip.CidrIpv6) { continue } @@ -897,6 +954,63 @@ func resourceSecurityGroupRuleDescriptionUpdate(conn *ec2.EC2, d *schema.Resourc return nil } +func resourceSecurityGroupRuleCidrUpdate(conn *ec2.EC2, d *schema.ResourceData) error { + var err error + sg_id := d.Get("security_group_id").(string) + + removed, added := getSetChange(d, "cidr_blocks") + ipv6Removed, ipv6Added := getSetChange(d, "ipv6_cidr_blocks") + + removePerm := expandCidrIPPerm(d, removed, ipv6Removed) + addPerm := expandCidrIPPerm(d, added, ipv6Added) + + conns.GlobalMutexKV.Lock(sg_id) + defer conns.GlobalMutexKV.Unlock(sg_id) + + ruleType := d.Get("type").(string) + log.Printf("[DEBUG] Revoking rules (%s) from security group %s:\n%s", ruleType, sg_id, removePerm) + switch ruleType { + case "ingress": + req := &ec2.RevokeSecurityGroupIngressInput{ + GroupId: aws.String(sg_id), + IpPermissions: []*ec2.IpPermission{removePerm}, + } + _, err = conn.RevokeSecurityGroupIngress(req) + case "egress": + req := &ec2.RevokeSecurityGroupEgressInput{ + GroupId: aws.String(sg_id), + IpPermissions: []*ec2.IpPermission{removePerm}, + } + _, err = conn.RevokeSecurityGroupEgress(req) + } + if err != nil { + return fmt.Errorf("Error revoking security group %s rules: %s", sg_id, err) + } + + log.Printf("[DEBUG] Adding rules (%s) for security group %s:\n%s", ruleType, sg_id, addPerm) + switch ruleType { + case "ingress": + req := &ec2.AuthorizeSecurityGroupIngressInput{ + GroupId: aws.String(sg_id), + IpPermissions: []*ec2.IpPermission{addPerm}, + } + + _, err = conn.AuthorizeSecurityGroupIngress(req) + case "egress": + req := &ec2.AuthorizeSecurityGroupEgressInput{ + GroupId: aws.String(sg_id), + IpPermissions: []*ec2.IpPermission{addPerm}, + } + + _, err = conn.AuthorizeSecurityGroupEgress(req) + } + if err != nil { + return fmt.Errorf("Error adding security group %s rules: %s", sg_id, err) + } + + return nil +} + // validateSecurityGroupRuleImportString does minimal validation of import string without going to AWS func validateSecurityGroupRuleImportString(importStr string) ([]string, error) { // example: sg-09a093729ef9382a6_ingress_tcp_8000_8000_10.0.3.0/24 @@ -981,9 +1095,9 @@ func populateSecurityGroupRuleFromImport(d *schema.ResourceData, importParts []s d.Set("to_port", toPort) d.Set("self", false) - var cidrs []string + cidrs := schema.Set{F: schema.HashString} var prefixList []string - var ipv6cidrs []string + ipv6cidrs := schema.Set{F: schema.HashString} for _, source := range sources { if source == "self" { d.Set("self", true) @@ -992,13 +1106,13 @@ func populateSecurityGroupRuleFromImport(d *schema.ResourceData, importParts []s } else if strings.Contains(source, "pl-") { prefixList = append(prefixList, source) } else if strings.Contains(source, ":") { - ipv6cidrs = append(ipv6cidrs, source) + ipv6cidrs.Add(source) } else { - cidrs = append(cidrs, source) + cidrs.Add(source) } } - d.Set("ipv6_cidr_blocks", ipv6cidrs) - d.Set("cidr_blocks", cidrs) + d.Set("ipv6_cidr_blocks", &ipv6cidrs) + d.Set("cidr_blocks", &cidrs) d.Set("prefix_list_ids", prefixList) return nil