Skip to content

Commit

Permalink
Allow update instead for recreate for aws_security_group_rule.
Browse files Browse the repository at this point in the history
With this change, changes it cidr_blocks and ipv6_cidr_blocks will only
remove/add the cidr ranges that were removed/added in config, rather
than destroying the entire resource and recreating it.

It also changes the type of those attributes to sets to make the diffs
more readable.
  • Loading branch information
tmccombs committed Oct 20, 2021
1 parent 1b0103e commit a5f2c9d
Showing 1 changed file with 142 additions and 28 deletions.
170 changes: 142 additions & 28 deletions internal/service/ec2/security_group_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ func ResourceSecurityGroupRule() *schema.Resource {
},

"cidr_blocks": {
Type: schema.TypeList,
Type: schema.TypeSet,
Optional: true,
ForceNew: true,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: verify.ValidCIDRNetworkAddress,
Expand All @@ -103,9 +102,8 @@ func ResourceSecurityGroupRule() *schema.Resource {
},

"ipv6_cidr_blocks": {
Type: schema.TypeList,
Type: schema.TypeSet,
Optional: true,
ForceNew: true,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: verify.ValidCIDRNetworkAddress,
Expand Down Expand Up @@ -349,6 +347,12 @@ func resourceSecurityGroupRuleUpdate(d *schema.ResourceData, meta interface{}) e
}
}

if d.HasChange("cidr_blocks") || d.HasChange("ipv6_cidr_blocks") {
if err := resourceSecurityGroupRuleCidrUpdate(conn, d); err != nil {
return err
}
}

return resourceSecurityGroupRuleRead(d, meta)
}

Expand Down Expand Up @@ -641,7 +645,7 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss
}

if raw, ok := d.GetOk("cidr_blocks"); ok {
list := raw.([]interface{})
list := raw.(*schema.Set).List()
perm.IpRanges = make([]*ec2.IpRange, len(list))
for i, v := range list {
cidrIP, ok := v.(string)
Expand All @@ -657,7 +661,7 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss
}

if raw, ok := d.GetOk("ipv6_cidr_blocks"); ok {
list := raw.([]interface{})
list := raw.(*schema.Set).List()
perm.Ipv6Ranges = make([]*ec2.Ipv6Range, len(list))
for i, v := range list {
cidrIP, ok := v.(string)
Expand Down Expand Up @@ -691,22 +695,79 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss
return &perm, nil
}

func expandCidrIPPerm(d *schema.ResourceData, cidrBlocks *schema.Set, ipv6CidrBlocks *schema.Set) *ec2.IpPermission {
var perm ec2.IpPermission

protocol := ProtocolForValue(d.Get("protocol").(string))
perm.IpProtocol = aws.String(protocol)

if protocol != "-1" {
perm.FromPort = aws.Int64(int64(d.Get("from_port").(int)))
perm.ToPort = aws.Int64(int64(d.Get("to_port").(int)))
}

description := d.Get("description").(string)

if cidrBlocks.Len() > 0 {
list := cidrBlocks.List()
perm.IpRanges = make([]*ec2.IpRange, len(list))
for i, v := range list {
cidrIP := v.(string)
perm.IpRanges[i] = &ec2.IpRange{CidrIp: aws.String(cidrIP)}
if description != "" {
perm.IpRanges[i].Description = aws.String(description)
}
}
}

if ipv6CidrBlocks.Len() > 0 {
list := ipv6CidrBlocks.List()
perm.Ipv6Ranges = make([]*ec2.Ipv6Range, len(list))
for i, v := range list {
cidrIP := v.(string)
perm.Ipv6Ranges[i] = &ec2.Ipv6Range{CidrIpv6: aws.String(cidrIP)}
if description != "" {
perm.Ipv6Ranges[i].Description = aws.String(description)
}
}
}

return &perm
}

// Get the sets of removed and added items in a set of
func getSetChange(d *schema.ResourceData, name string) (removed *schema.Set, added *schema.Set) {
o, n := d.GetChange(name)
if o == nil {
o = new(schema.Set)
}
if n == nil {
n = new(schema.Set)
}
old := o.(*schema.Set)
new_ := n.(*schema.Set)
removed = old.Difference(new_)
added = new_.Difference(old)

return removed, added
}

func setFromIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup, rule *ec2.IpPermission) {
isVPC := aws.StringValue(sg.VpcId) != ""

d.Set("from_port", rule.FromPort)
d.Set("to_port", rule.ToPort)
d.Set("protocol", rule.IpProtocol)

var cb []string
cb := &schema.Set{F: schema.HashString}
for _, c := range rule.IpRanges {
cb = append(cb, *c.CidrIp)
cb.Add(*c.CidrIp)
}
d.Set("cidr_blocks", cb)

var ipv6 []string
ipv6 := &schema.Set{F: schema.HashString}
for _, ip := range rule.Ipv6Ranges {
ipv6 = append(ipv6, *ip.CidrIpv6)
ipv6.Add(*ip.CidrIpv6)
}
d.Set("ipv6_cidr_blocks", ipv6)

Expand Down Expand Up @@ -741,16 +802,14 @@ func setFromIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup, rule *ec2.IpPe

func descriptionFromIPPerm(d *schema.ResourceData, rule *ec2.IpPermission) string {
// probe IpRanges
cidrIps := make(map[string]bool)
var cidrIps *schema.Set
if raw, ok := d.GetOk("cidr_blocks"); ok {
for _, v := range raw.([]interface{}) {
cidrIps[v.(string)] = true
}
cidrIps = raw.(*schema.Set)
}

if len(cidrIps) > 0 {
if cidrIps != nil && cidrIps.Len() > 0 {
for _, c := range rule.IpRanges {
if _, ok := cidrIps[*c.CidrIp]; !ok {
if !cidrIps.Contains(*c.CidrIp) {
continue
}

Expand All @@ -761,16 +820,14 @@ func descriptionFromIPPerm(d *schema.ResourceData, rule *ec2.IpPermission) strin
}

// probe Ipv6Ranges
cidrIpv6s := make(map[string]bool)
var cidrIpv6s *schema.Set
if raw, ok := d.GetOk("ipv6_cidr_blocks"); ok {
for _, v := range raw.([]interface{}) {
cidrIpv6s[v.(string)] = true
}
cidrIpv6s = raw.(*schema.Set)
}

if len(cidrIpv6s) > 0 {
if cidrIpv6s != nil && cidrIpv6s.Len() > 0 {
for _, ip := range rule.Ipv6Ranges {
if _, ok := cidrIpv6s[*ip.CidrIpv6]; !ok {
if !cidrIpv6s.Contains(*ip.CidrIpv6) {
continue
}

Expand Down Expand Up @@ -897,6 +954,63 @@ func resourceSecurityGroupRuleDescriptionUpdate(conn *ec2.EC2, d *schema.Resourc
return nil
}

func resourceSecurityGroupRuleCidrUpdate(conn *ec2.EC2, d *schema.ResourceData) error {
var err error
sg_id := d.Get("security_group_id").(string)

removed, added := getSetChange(d, "cidr_blocks")
ipv6Removed, ipv6Added := getSetChange(d, "ipv6_cidr_blocks")

removePerm := expandCidrIPPerm(d, removed, ipv6Removed)
addPerm := expandCidrIPPerm(d, added, ipv6Added)

conns.GlobalMutexKV.Lock(sg_id)
defer conns.GlobalMutexKV.Unlock(sg_id)

ruleType := d.Get("type").(string)
log.Printf("[DEBUG] Revoking rules (%s) from security group %s:\n%s", ruleType, sg_id, removePerm)
switch ruleType {
case "ingress":
req := &ec2.RevokeSecurityGroupIngressInput{
GroupId: aws.String(sg_id),
IpPermissions: []*ec2.IpPermission{removePerm},
}
_, err = conn.RevokeSecurityGroupIngress(req)
case "egress":
req := &ec2.RevokeSecurityGroupEgressInput{
GroupId: aws.String(sg_id),
IpPermissions: []*ec2.IpPermission{removePerm},
}
_, err = conn.RevokeSecurityGroupEgress(req)
}
if err != nil {
return fmt.Errorf("Error revoking security group %s rules: %s", sg_id, err)
}

log.Printf("[DEBUG] Adding rules (%s) for security group %s:\n%s", ruleType, sg_id, addPerm)
switch ruleType {
case "ingress":
req := &ec2.AuthorizeSecurityGroupIngressInput{
GroupId: aws.String(sg_id),
IpPermissions: []*ec2.IpPermission{addPerm},
}

_, err = conn.AuthorizeSecurityGroupIngress(req)
case "egress":
req := &ec2.AuthorizeSecurityGroupEgressInput{
GroupId: aws.String(sg_id),
IpPermissions: []*ec2.IpPermission{addPerm},
}

_, err = conn.AuthorizeSecurityGroupEgress(req)
}
if err != nil {
return fmt.Errorf("Error adding security group %s rules: %s", sg_id, err)
}

return nil
}

// validateSecurityGroupRuleImportString does minimal validation of import string without going to AWS
func validateSecurityGroupRuleImportString(importStr string) ([]string, error) {
// example: sg-09a093729ef9382a6_ingress_tcp_8000_8000_10.0.3.0/24
Expand Down Expand Up @@ -981,9 +1095,9 @@ func populateSecurityGroupRuleFromImport(d *schema.ResourceData, importParts []s
d.Set("to_port", toPort)

d.Set("self", false)
var cidrs []string
cidrs := schema.Set{F: schema.HashString}
var prefixList []string
var ipv6cidrs []string
ipv6cidrs := schema.Set{F: schema.HashString}
for _, source := range sources {
if source == "self" {
d.Set("self", true)
Expand All @@ -992,13 +1106,13 @@ func populateSecurityGroupRuleFromImport(d *schema.ResourceData, importParts []s
} else if strings.Contains(source, "pl-") {
prefixList = append(prefixList, source)
} else if strings.Contains(source, ":") {
ipv6cidrs = append(ipv6cidrs, source)
ipv6cidrs.Add(source)
} else {
cidrs = append(cidrs, source)
cidrs.Add(source)
}
}
d.Set("ipv6_cidr_blocks", ipv6cidrs)
d.Set("cidr_blocks", cidrs)
d.Set("ipv6_cidr_blocks", &ipv6cidrs)
d.Set("cidr_blocks", &cidrs)
d.Set("prefix_list_ids", prefixList)

return nil
Expand Down

0 comments on commit a5f2c9d

Please sign in to comment.