Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resource/aws_security_group_rule: Prevent crash when reading rules from groups containing an ALL/-1 protocol rule #6419

Merged
merged 1 commit into from
Nov 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions aws/resource_aws_security_group_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
Type: schema.TypeInt,
Required: true,
ForceNew: true,
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
protocol := protocolForValue(d.Get("protocol").(string))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the future it'd be good for protocolForValue to return a *string so we can easily detect this is invalid?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protocolForValue() returns the lowercased input string if it cannot find a match. Since we're specifically checking for a known value (should probably be a constant 🤔 ), I don't think it matters too much in this context since we do not have the opportunity to throw an error (other than panic()) within a DiffSuppressFunc. Definitely open to chatting about this further if I'm missing something (🥁). 😄

if protocol == "-1" && old == "0" && new == "65535" {
return true
}
return false
},
},

"protocol": {
Expand Down Expand Up @@ -430,21 +437,24 @@ func (b ByGroupPair) Less(i, j int) bool {
func findRuleMatch(p *ec2.IpPermission, rules []*ec2.IpPermission, isVPC bool) *ec2.IpPermission {
var rule *ec2.IpPermission
for _, r := range rules {
if r.ToPort != nil && *p.ToPort != *r.ToPort {
if p.ToPort != nil && r.ToPort != nil && *p.ToPort != *r.ToPort {
continue
}

if r.FromPort != nil && *p.FromPort != *r.FromPort {
if p.FromPort != nil && r.FromPort != nil && *p.FromPort != *r.FromPort {
continue
}

if r.IpProtocol != nil && *p.IpProtocol != *r.IpProtocol {
if p.IpProtocol != nil && r.IpProtocol != nil && *p.IpProtocol != *r.IpProtocol {
continue
}

remaining := len(p.IpRanges)
for _, ip := range p.IpRanges {
for _, rip := range r.IpRanges {
if ip.CidrIp == nil || rip.CidrIp == nil {
continue
}
if *ip.CidrIp == *rip.CidrIp {
remaining--
}
Expand All @@ -458,6 +468,9 @@ func findRuleMatch(p *ec2.IpPermission, rules []*ec2.IpPermission, isVPC bool) *
remaining = len(p.Ipv6Ranges)
for _, ipv6 := range p.Ipv6Ranges {
for _, ipv6ip := range r.Ipv6Ranges {
if ipv6.CidrIpv6 == nil || ipv6ip.CidrIpv6 == nil {
continue
}
if *ipv6.CidrIpv6 == *ipv6ip.CidrIpv6 {
remaining--
}
Expand All @@ -471,6 +484,9 @@ func findRuleMatch(p *ec2.IpPermission, rules []*ec2.IpPermission, isVPC bool) *
remaining = len(p.PrefixListIds)
for _, pl := range p.PrefixListIds {
for _, rpl := range r.PrefixListIds {
if pl.PrefixListId == nil || rpl.PrefixListId == nil {
continue
}
if *pl.PrefixListId == *rpl.PrefixListId {
remaining--
}
Expand All @@ -485,10 +501,16 @@ func findRuleMatch(p *ec2.IpPermission, rules []*ec2.IpPermission, isVPC bool) *
for _, ip := range p.UserIdGroupPairs {
for _, rip := range r.UserIdGroupPairs {
if isVPC {
if ip.GroupId == nil || rip.GroupId == nil {
continue
}
if *ip.GroupId == *rip.GroupId {
remaining--
}
} else {
if ip.GroupName == nil || rip.GroupName == nil {
continue
}
if *ip.GroupName == *rip.GroupName {
remaining--
}
Expand Down
175 changes: 172 additions & 3 deletions aws/resource_aws_security_group_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,111 @@ func TestAccAWSSecurityGroupRule_Description_AllPorts(t *testing.T) {
})
}

// Reference: https://github.com/terraform-providers/terraform-provider-aws/issues/6416
func TestAccAWSSecurityGroupRule_Description_AllPorts_ToPort65535(t *testing.T) {
var group ec2.SecurityGroup
rName := acctest.RandomWithPrefix("tf-acc-test")
securityGroupResourceName := "aws_security_group.test"
resourceName := "aws_security_group_rule.test"

rule1 := ec2.IpPermission{
IpProtocol: aws.String("-1"),
IpRanges: []*ec2.IpRange{
{CidrIp: aws.String("0.0.0.0/0"), Description: aws.String("description1")},
},
}

rule2 := ec2.IpPermission{
IpProtocol: aws.String("-1"),
IpRanges: []*ec2.IpRange{
{CidrIp: aws.String("0.0.0.0/0"), Description: aws.String("description2")},
},
}

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
CheckDestroy: testAccCheckAWSSecurityGroupRuleDestroy,
Steps: []resource.TestStep{
{
Config: testAccAWSSecurityGroupRuleConfigDescriptionAllPortsToPort65535(rName, "description1"),
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSSecurityGroupRuleExists(securityGroupResourceName, &group),
testAccCheckAWSSecurityGroupRuleAttributes(resourceName, &group, &rule1, "ingress"),
resource.TestCheckResourceAttr(resourceName, "description", "description1"),
resource.TestCheckResourceAttr(resourceName, "from_port", "0"),
resource.TestCheckResourceAttr(resourceName, "protocol", "-1"),
resource.TestCheckResourceAttr(resourceName, "to_port", "65535"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateIdFunc: testAccAWSSecurityGroupRuleImportStateIdFunc(resourceName),
ImportStateVerify: true,
},
{
Config: testAccAWSSecurityGroupRuleConfigDescriptionAllPorts(rName, "description2"),
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSSecurityGroupRuleExists(securityGroupResourceName, &group),
testAccCheckAWSSecurityGroupRuleAttributes(resourceName, &group, &rule2, "ingress"),
resource.TestCheckResourceAttr(resourceName, "description", "description2"),
resource.TestCheckResourceAttr(resourceName, "from_port", "0"),
resource.TestCheckResourceAttr(resourceName, "protocol", "-1"),
resource.TestCheckResourceAttr(resourceName, "to_port", "0"),
),
},
},
})
}

// Reference: https://github.com/terraform-providers/terraform-provider-aws/issues/6416
func TestAccAWSSecurityGroupRule_MultipleRuleSearching_AllProtocolCrash(t *testing.T) {
var group ec2.SecurityGroup
rName := acctest.RandomWithPrefix("tf-acc-test")
securityGroupResourceName := "aws_security_group.test"
resourceName1 := "aws_security_group_rule.test1"
resourceName2 := "aws_security_group_rule.test2"

rule1 := ec2.IpPermission{
IpProtocol: aws.String("-1"),
IpRanges: []*ec2.IpRange{
{CidrIp: aws.String("10.0.0.0/8")},
},
}

rule2 := ec2.IpPermission{
FromPort: aws.Int64(443),
ToPort: aws.Int64(443),
IpProtocol: aws.String("tcp"),
IpRanges: []*ec2.IpRange{
{CidrIp: aws.String("172.168.0.0/16")},
},
}

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
CheckDestroy: testAccCheckAWSSecurityGroupRuleDestroy,
Steps: []resource.TestStep{
{
Config: testAccAWSSecurityGroupRuleConfigMultipleRuleSearchingAllProtocolCrash(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSSecurityGroupRuleExists(securityGroupResourceName, &group),
testAccCheckAWSSecurityGroupRuleAttributes(resourceName1, &group, &rule1, "ingress"),
testAccCheckAWSSecurityGroupRuleAttributes(resourceName2, &group, &rule2, "ingress"),
resource.TestCheckResourceAttr(resourceName1, "from_port", "0"),
resource.TestCheckResourceAttr(resourceName1, "protocol", "-1"),
resource.TestCheckResourceAttr(resourceName1, "to_port", "65535"),
resource.TestCheckResourceAttr(resourceName2, "from_port", "443"),
resource.TestCheckResourceAttr(resourceName2, "protocol", "tcp"),
resource.TestCheckResourceAttr(resourceName2, "to_port", "443"),
),
},
},
})
}

func TestAccAWSSecurityGroupRule_MultiDescription(t *testing.T) {
var group ec2.SecurityGroup
var nat ec2.SecurityGroup
Expand Down Expand Up @@ -1114,21 +1219,24 @@ func testAccCheckAWSSecurityGroupRuleAttributes(n string, group *ec2.SecurityGro
}

for _, r := range rules {
if r.ToPort != nil && *p.ToPort != *r.ToPort {
if p.ToPort != nil && r.ToPort != nil && *p.ToPort != *r.ToPort {
continue
}

if r.FromPort != nil && *p.FromPort != *r.FromPort {
if p.FromPort != nil && r.FromPort != nil && *p.FromPort != *r.FromPort {
continue
}

if r.IpProtocol != nil && *p.IpProtocol != *r.IpProtocol {
if p.IpProtocol != nil && r.IpProtocol != nil && *p.IpProtocol != *r.IpProtocol {
continue
}

remaining := len(p.IpRanges)
for _, ip := range p.IpRanges {
for _, rip := range r.IpRanges {
if ip.CidrIp == nil || rip.CidrIp == nil {
continue
}
if *ip.CidrIp == *rip.CidrIp {
remaining--
}
Expand All @@ -1142,6 +1250,9 @@ func testAccCheckAWSSecurityGroupRuleAttributes(n string, group *ec2.SecurityGro
remaining = len(p.Ipv6Ranges)
for _, ip := range p.Ipv6Ranges {
for _, rip := range r.Ipv6Ranges {
if ip.CidrIpv6 == nil || rip.CidrIpv6 == nil {
continue
}
if *ip.CidrIpv6 == *rip.CidrIpv6 {
remaining--
}
Expand All @@ -1155,6 +1266,9 @@ func testAccCheckAWSSecurityGroupRuleAttributes(n string, group *ec2.SecurityGro
remaining = len(p.UserIdGroupPairs)
for _, ip := range p.UserIdGroupPairs {
for _, rip := range r.UserIdGroupPairs {
if ip.GroupId == nil || rip.GroupId == nil {
continue
}
if *ip.GroupId == *rip.GroupId {
remaining--
}
Expand All @@ -1168,6 +1282,9 @@ func testAccCheckAWSSecurityGroupRuleAttributes(n string, group *ec2.SecurityGro
remaining = len(p.PrefixListIds)
for _, pip := range p.PrefixListIds {
for _, rpip := range r.PrefixListIds {
if pip.PrefixListId == nil || rpip.PrefixListId == nil {
continue
}
if *pip.PrefixListId == *rpip.PrefixListId {
remaining--
}
Expand Down Expand Up @@ -1805,6 +1922,58 @@ resource "aws_security_group_rule" "test" {
`, rName, description)
}

func testAccAWSSecurityGroupRuleConfigDescriptionAllPortsToPort65535(rName, description string) string {
return fmt.Sprintf(`
resource "aws_security_group" "test" {
name = %q

tags {
Name = "tf-acc-test-ec2-security-group-rule"
}
}

resource "aws_security_group_rule" "test" {
cidr_blocks = ["0.0.0.0/0"]
description = %q
from_port = 0
protocol = -1
security_group_id = "${aws_security_group.test.id}"
to_port = 65535
type = "ingress"
}
`, rName, description)
}

func testAccAWSSecurityGroupRuleConfigMultipleRuleSearchingAllProtocolCrash(rName string) string {
return fmt.Sprintf(`
resource "aws_security_group" "test" {
name = %q

tags {
Name = "tf-acc-test-ec2-security-group-rule"
}
}

resource "aws_security_group_rule" "test1" {
cidr_blocks = ["10.0.0.0/8"]
from_port = 0
protocol = -1
security_group_id = "${aws_security_group.test.id}"
to_port = 65535
type = "ingress"
}

resource "aws_security_group_rule" "test2" {
cidr_blocks = ["172.168.0.0/16"]
from_port = 443
protocol = "tcp"
security_group_id = "${aws_security_group.test.id}"
to_port = 443
type = "ingress"
}
`, rName)
}

var testAccAWSSecurityGroupRuleRace = func() string {
var b bytes.Buffer
iterations := 50
Expand Down