Skip to content

Commit

Permalink
Merge pull request #32594 from hashicorp/b-aws_ec2_traffic_mirror_fil…
Browse files Browse the repository at this point in the history
…ter_rule-panic

r/aws_ec2_traffic_mirror_filter_rule: Fix crash when updating `rule_number`
  • Loading branch information
ewbankkit authored Jul 19, 2023
2 parents 0a175e3 + 80821ed commit 36a8ecb
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .changelog/32594.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_ec2_traffic_mirror_filter_rule: Fix crash when updating `rule_number`
```
154 changes: 154 additions & 0 deletions internal/service/ec2/sweep.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func init() {
"aws_db_proxy",
"aws_directory_service_directory",
"aws_ec2_client_vpn_endpoint",
"aws_ec2_traffic_mirror_session",
"aws_ec2_transit_gateway_vpc_attachment",
"aws_eks_cluster",
"aws_elb",
Expand Down Expand Up @@ -259,6 +260,27 @@ func init() {
},
})

resource.AddTestSweepers("aws_ec2_traffic_mirror_filter", &resource.Sweeper{
Name: "aws_ec2_traffic_mirror_filter",
F: sweepTrafficMirrorFilters,
Dependencies: []string{
"aws_ec2_traffic_mirror_session",
},
})

resource.AddTestSweepers("aws_ec2_traffic_mirror_session", &resource.Sweeper{
Name: "aws_ec2_traffic_mirror_session",
F: sweepTrafficMirrorSessions,
})

resource.AddTestSweepers("aws_ec2_traffic_mirror_target", &resource.Sweeper{
Name: "aws_ec2_traffic_mirror_target",
F: sweepTrafficMirrorTargets,
Dependencies: []string{
"aws_ec2_traffic_mirror_session",
},
})

resource.AddTestSweepers("aws_ec2_transit_gateway_peering_attachment", &resource.Sweeper{
Name: "aws_ec2_transit_gateway_peering_attachment",
F: sweepTransitGatewayPeeringAttachments,
Expand Down Expand Up @@ -1741,6 +1763,138 @@ func sweepSubnets(region string) error {
return nil
}

func sweepTrafficMirrorFilters(region string) error {
ctx := sweep.Context(region)
client, err := sweep.SharedRegionalSweepClient(ctx, region)
if err != nil {
return fmt.Errorf("error getting client: %s", err)
}
conn := client.EC2Conn(ctx)
input := &ec2.DescribeTrafficMirrorFiltersInput{}
sweepResources := make([]sweep.Sweepable, 0)

err = conn.DescribeTrafficMirrorFiltersPagesWithContext(ctx, input, func(page *ec2.DescribeTrafficMirrorFiltersOutput, lastPage bool) bool {
if page == nil {
return !lastPage
}

for _, v := range page.TrafficMirrorFilters {
r := ResourceTrafficMirrorFilter()
d := r.Data(nil)
d.SetId(aws.StringValue(v.TrafficMirrorFilterId))

sweepResources = append(sweepResources, sweep.NewSweepResource(r, d, client))
}

return !lastPage
})

if sweep.SkipSweepError(err) {
log.Printf("[WARN] Skipping EC2 Traffic Mirror Filter sweep for %s: %s", region, err)
return nil
}

if err != nil {
return fmt.Errorf("error listing EC2 Traffic Mirror Filters (%s): %w", region, err)
}

err = sweep.SweepOrchestrator(ctx, sweepResources)

if err != nil {
return fmt.Errorf("error sweeping EC2 Traffic Mirror Filters (%s): %w", region, err)
}

return nil
}

func sweepTrafficMirrorSessions(region string) error {
ctx := sweep.Context(region)
client, err := sweep.SharedRegionalSweepClient(ctx, region)
if err != nil {
return fmt.Errorf("error getting client: %s", err)
}
conn := client.EC2Conn(ctx)
input := &ec2.DescribeTrafficMirrorSessionsInput{}
sweepResources := make([]sweep.Sweepable, 0)

err = conn.DescribeTrafficMirrorSessionsPagesWithContext(ctx, input, func(page *ec2.DescribeTrafficMirrorSessionsOutput, lastPage bool) bool {
if page == nil {
return !lastPage
}

for _, v := range page.TrafficMirrorSessions {
r := ResourceTrafficMirrorSession()
d := r.Data(nil)
d.SetId(aws.StringValue(v.TrafficMirrorSessionId))

sweepResources = append(sweepResources, sweep.NewSweepResource(r, d, client))
}

return !lastPage
})

if sweep.SkipSweepError(err) {
log.Printf("[WARN] Skipping EC2 Traffic Mirror Session sweep for %s: %s", region, err)
return nil
}

if err != nil {
return fmt.Errorf("error listing EC2 Traffic Mirror Sessions (%s): %w", region, err)
}

err = sweep.SweepOrchestrator(ctx, sweepResources)

if err != nil {
return fmt.Errorf("error sweeping EC2 Traffic Mirror Sessions (%s): %w", region, err)
}

return nil
}

func sweepTrafficMirrorTargets(region string) error {
ctx := sweep.Context(region)
client, err := sweep.SharedRegionalSweepClient(ctx, region)
if err != nil {
return fmt.Errorf("error getting client: %s", err)
}
conn := client.EC2Conn(ctx)
input := &ec2.DescribeTrafficMirrorTargetsInput{}
sweepResources := make([]sweep.Sweepable, 0)

err = conn.DescribeTrafficMirrorTargetsPagesWithContext(ctx, input, func(page *ec2.DescribeTrafficMirrorTargetsOutput, lastPage bool) bool {
if page == nil {
return !lastPage
}

for _, v := range page.TrafficMirrorTargets {
r := ResourceTrafficMirrorTarget()
d := r.Data(nil)
d.SetId(aws.StringValue(v.TrafficMirrorTargetId))

sweepResources = append(sweepResources, sweep.NewSweepResource(r, d, client))
}

return !lastPage
})

if sweep.SkipSweepError(err) {
log.Printf("[WARN] Skipping EC2 Traffic Mirror Target sweep for %s: %s", region, err)
return nil
}

if err != nil {
return fmt.Errorf("error listing EC2 Traffic Mirror Targets (%s): %w", region, err)
}

err = sweep.SweepOrchestrator(ctx, sweepResources)

if err != nil {
return fmt.Errorf("error sweeping EC2 Traffic Mirror Targets (%s): %w", region, err)
}

return nil
}

func sweepTransitGateways(region string) error {
ctx := sweep.Context(region)
client, err := sweep.SharedRegionalSweepClient(ctx, region)
Expand Down
10 changes: 5 additions & 5 deletions internal/service/ec2/vpc_traffic_mirror_filter_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,16 @@ func resourceTrafficMirrorFilterRuleUpdate(ctx context.Context, d *schema.Resour
}
}

if d.HasChange("source_cidr_block") {
input.SourceCidrBlock = aws.String(d.Get("source_cidr_block").(string))
}

if d.HasChange("rule_action") {
input.RuleAction = aws.String(d.Get("rule_action").(string))
}

if d.HasChange("rule_number") {
input.RuleNumber = aws.Int64(int64(d.Get("rule_action").(int)))
input.RuleNumber = aws.Int64(int64(d.Get("rule_number").(int)))
}

if d.HasChange("source_cidr_block") {
input.SourceCidrBlock = aws.String(d.Get("source_cidr_block").(string))
}

if d.HasChange("source_port_range") {
Expand Down
36 changes: 29 additions & 7 deletions internal/service/ec2/vpc_traffic_mirror_filter_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ func TestAccVPCTrafficMirrorFilterRule_basic(t *testing.T) {
resourceName := "aws_ec2_traffic_mirror_filter_rule.test"
dstCidr := "10.0.0.0/8"
srcCidr := "0.0.0.0/0"
ruleNum := 1
ruleNum1 := 1
ruleNum2 := 2
action := "accept"
direction := "ingress"
description := "test rule"
Expand All @@ -45,14 +46,14 @@ func TestAccVPCTrafficMirrorFilterRule_basic(t *testing.T) {
Steps: []resource.TestStep{
//create
{
Config: testAccVPCTrafficMirrorFilterRuleConfig_basic(dstCidr, srcCidr, action, direction, ruleNum),
Config: testAccVPCTrafficMirrorFilterRuleConfig_basic(dstCidr, srcCidr, action, direction, ruleNum1),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckTrafficMirrorFilterRuleExists(ctx, resourceName),
acctest.MatchResourceAttrRegionalARN(resourceName, "arn", ec2.ServiceName, regexp.MustCompile(`traffic-mirror-filter-rule/tmfr-.+`)),
resource.TestMatchResourceAttr(resourceName, "traffic_mirror_filter_id", regexp.MustCompile("tmf-.*")),
resource.TestCheckResourceAttr(resourceName, "destination_cidr_block", dstCidr),
resource.TestCheckResourceAttr(resourceName, "rule_action", action),
resource.TestCheckResourceAttr(resourceName, "rule_number", strconv.Itoa(ruleNum)),
resource.TestCheckResourceAttr(resourceName, "rule_number", strconv.Itoa(ruleNum1)),
resource.TestCheckResourceAttr(resourceName, "source_cidr_block", srcCidr),
resource.TestCheckResourceAttr(resourceName, "traffic_direction", direction),
resource.TestCheckResourceAttr(resourceName, "description", ""),
Expand All @@ -63,13 +64,34 @@ func TestAccVPCTrafficMirrorFilterRule_basic(t *testing.T) {
},
// Add all optionals
{
Config: testAccVPCTrafficMirrorFilterRuleConfig_full(dstCidr, srcCidr, action, direction, description, ruleNum, srcPortFrom, srcPortTo, dstPortFrom, dstPortTo, protocol),
Config: testAccVPCTrafficMirrorFilterRuleConfig_full(dstCidr, srcCidr, action, direction, description, ruleNum1, srcPortFrom, srcPortTo, dstPortFrom, dstPortTo, protocol),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckTrafficMirrorFilterRuleExists(ctx, resourceName),
resource.TestMatchResourceAttr(resourceName, "traffic_mirror_filter_id", regexp.MustCompile("tmf-.*")),
resource.TestCheckResourceAttr(resourceName, "destination_cidr_block", dstCidr),
resource.TestCheckResourceAttr(resourceName, "rule_action", action),
resource.TestCheckResourceAttr(resourceName, "rule_number", strconv.Itoa(ruleNum)),
resource.TestCheckResourceAttr(resourceName, "rule_number", strconv.Itoa(ruleNum1)),
resource.TestCheckResourceAttr(resourceName, "source_cidr_block", srcCidr),
resource.TestCheckResourceAttr(resourceName, "traffic_direction", direction),
resource.TestCheckResourceAttr(resourceName, "description", description),
resource.TestCheckResourceAttr(resourceName, "destination_port_range.#", "1"),
resource.TestCheckResourceAttr(resourceName, "destination_port_range.0.from_port", strconv.Itoa(dstPortFrom)),
resource.TestCheckResourceAttr(resourceName, "destination_port_range.0.to_port", strconv.Itoa(dstPortTo)),
resource.TestCheckResourceAttr(resourceName, "source_port_range.#", "1"),
resource.TestCheckResourceAttr(resourceName, "source_port_range.0.from_port", strconv.Itoa(srcPortFrom)),
resource.TestCheckResourceAttr(resourceName, "source_port_range.0.to_port", strconv.Itoa(srcPortTo)),
resource.TestCheckResourceAttr(resourceName, "protocol", strconv.Itoa(protocol)),
),
},
// Updates
{
Config: testAccVPCTrafficMirrorFilterRuleConfig_full(dstCidr, srcCidr, action, direction, description, ruleNum2, srcPortFrom, srcPortTo, dstPortFrom, dstPortTo, protocol),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckTrafficMirrorFilterRuleExists(ctx, resourceName),
resource.TestMatchResourceAttr(resourceName, "traffic_mirror_filter_id", regexp.MustCompile("tmf-.*")),
resource.TestCheckResourceAttr(resourceName, "destination_cidr_block", dstCidr),
resource.TestCheckResourceAttr(resourceName, "rule_action", action),
resource.TestCheckResourceAttr(resourceName, "rule_number", strconv.Itoa(ruleNum2)),
resource.TestCheckResourceAttr(resourceName, "source_cidr_block", srcCidr),
resource.TestCheckResourceAttr(resourceName, "traffic_direction", direction),
resource.TestCheckResourceAttr(resourceName, "description", description),
Expand All @@ -84,13 +106,13 @@ func TestAccVPCTrafficMirrorFilterRule_basic(t *testing.T) {
},
// remove optionals
{
Config: testAccVPCTrafficMirrorFilterRuleConfig_basic(dstCidr, srcCidr, action, direction, ruleNum),
Config: testAccVPCTrafficMirrorFilterRuleConfig_basic(dstCidr, srcCidr, action, direction, ruleNum1),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckTrafficMirrorFilterRuleExists(ctx, resourceName),
resource.TestMatchResourceAttr(resourceName, "traffic_mirror_filter_id", regexp.MustCompile("tmf-.*")),
resource.TestCheckResourceAttr(resourceName, "destination_cidr_block", dstCidr),
resource.TestCheckResourceAttr(resourceName, "rule_action", action),
resource.TestCheckResourceAttr(resourceName, "rule_number", strconv.Itoa(ruleNum)),
resource.TestCheckResourceAttr(resourceName, "rule_number", strconv.Itoa(ruleNum1)),
resource.TestCheckResourceAttr(resourceName, "source_cidr_block", srcCidr),
resource.TestCheckResourceAttr(resourceName, "traffic_direction", direction),
resource.TestCheckResourceAttr(resourceName, "description", ""),
Expand Down

0 comments on commit 36a8ecb

Please sign in to comment.