Skip to content

Commit

Permalink
Refactor IPTable Rules (#2697)
Browse files Browse the repository at this point in the history
Co-authored-by: Joseph Chen <chenjoez@amazon.com>
  • Loading branch information
jchen6585 and Joseph Chen authored Dec 12, 2023
1 parent b0ad571 commit 21c4bd7
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 179 deletions.
5 changes: 5 additions & 0 deletions pkg/ipamd/ipamd.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ func (c *IPAMContext) nodeInit() error {
if err != nil {
return errors.Wrap(err, "ipamd init: failed to set up host network")
}
err = c.networkClient.CleanUpStaleAWSChains(c.enableIPv4, c.enableIPv6)
if err != nil {
// We should not error if clean up fails since these chains don't affect the rules
log.Debugf("Failed to clean up stale AWS chains: %v", err)
}

metadataResult, err := c.awsClient.DescribeAllENIs()
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions pkg/ipamd/ipamd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func TestNodeInit(t *testing.T) {
m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs, nil)
m.awsutils.EXPECT().GetPrimaryENImac().Return("")
m.network.EXPECT().SetupHostNetwork(cidrs, "", &primaryIP, false, true, false).Return(nil)
m.network.EXPECT().CleanUpStaleAWSChains(true, false).Return(nil)
m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid)
m.awsutils.EXPECT().RefreshSGIDs(gomock.Any()).AnyTimes().Return(nil)

Expand Down Expand Up @@ -234,6 +235,7 @@ func TestNodeInitwithPDenabledIPv4Mode(t *testing.T) {
m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs, nil)
m.awsutils.EXPECT().GetPrimaryENImac().Return("")
m.network.EXPECT().SetupHostNetwork(cidrs, "", &primaryIP, false, true, false).Return(nil)
m.network.EXPECT().CleanUpStaleAWSChains(true, false).Return(nil)
m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid)
m.awsutils.EXPECT().RefreshSGIDs(gomock.Any()).AnyTimes().Return(nil)

Expand Down Expand Up @@ -308,6 +310,7 @@ func TestNodeInitwithPDenabledIPv6Mode(t *testing.T) {

primaryIP := net.ParseIP(ipaddr01)
m.network.EXPECT().SetupHostNetwork(cidrs, eni1.MAC, &primaryIP, false, false, true).Return(nil)
m.network.EXPECT().CleanUpStaleAWSChains(false, true).Return(nil)
m.awsutils.EXPECT().GetIPv6PrefixesFromEC2(eni1.ENIID).AnyTimes().Return(eni1.IPv6Prefixes, nil)
m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid)
m.awsutils.EXPECT().GetPrimaryENImac().Return(eni1.MAC)
Expand Down
6 changes: 6 additions & 0 deletions pkg/iptableswrapper/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type IPTablesIface interface {
ClearChain(table, chain string) error
DeleteChain(table, chain string) error
ListChains(table string) ([]string, error)
ChainExists(table, chain string) (bool, error)
HasRandomFully() bool
}

Expand Down Expand Up @@ -98,6 +99,11 @@ func (i ipTables) ListChains(table string) ([]string, error) {
return i.ipt.ListChains(table)
}

// ChainExists implements IPTablesIface interface by calling iptables package
func (i ipTables) ChainExists(table, chain string) (bool, error) {
return i.ipt.ChainExists(table, chain)
}

// HasRandomFully implements IPTablesIface interface by calling iptables package
func (i ipTables) HasRandomFully() bool {
return i.ipt.HasRandomFully()
Expand Down
8 changes: 8 additions & 0 deletions pkg/iptableswrapper/mocks/iptables_maps.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ func (ipt *MockIptables) ListChains(table string) ([]string, error) {
return chains, nil
}

func (ipt *MockIptables) ChainExists(table, chain string) (bool, error) {
_, ok := ipt.DataplaneState[table][chain]
if ok {
return true, nil
}
return false, nil
}

func (ipt *MockIptables) HasRandomFully() bool {
// TODO: Work out how to write a test case for this
return true
Expand Down
15 changes: 15 additions & 0 deletions pkg/iptableswrapper/mocks/iptables_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions pkg/networkutils/mocks/network_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

160 changes: 102 additions & 58 deletions pkg/networkutils/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ type NetworkAPIs interface {
SetupENINetwork(eniIP string, mac string, deviceNumber int, subnetCIDR string) error
// UpdateHostIptablesRules updates the nat table iptables rules on the host
UpdateHostIptablesRules(vpcCIDRs []string, primaryMAC string, primaryAddr *net.IP, v4Enabled bool, v6Enabled bool) error
CleanUpStaleAWSChains(v4Enabled, v6Enabled bool) error
UseExternalSNAT() bool
GetExcludeSNATCIDRs() []string
GetExternalServiceCIDRs() []string
Expand Down Expand Up @@ -375,6 +376,51 @@ func (n *linuxNetwork) UpdateHostIptablesRules(vpcCIDRs []string, primaryMAC str
return n.updateHostIptablesRules(vpcCIDRs, primaryMAC, primaryAddr, v4Enabled, v6Enabled)
}

func (n *linuxNetwork) CleanUpStaleAWSChains(v4Enabled, v6Enabled bool) error {
ipProtocol := iptables.ProtocolIPv4
if v6Enabled {
ipProtocol = iptables.ProtocolIPv6
}

ipt, err := n.newIptables(ipProtocol)
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to create iptables")
}

exists, err := ipt.ChainExists("nat", "AWS-SNAT-CHAIN-1")
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to check if AWS-SNAT-CHAIN-1 exists")
}

if exists {
existingChains, err := ipt.ListChains("nat")
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to list iptables nat chains")
}

for _, chain := range existingChains {
if !strings.HasPrefix(chain, "AWS-CONNMARK-CHAIN") && !strings.HasPrefix(chain, "AWS-SNAT-CHAIN") {
continue
}
parsedChain := strings.Split(chain, "-")
chainNum, err := strconv.Atoi(parsedChain[len(parsedChain)-1])
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to convert string to int")
}
// Chains 1 --> x (0 indexed) will be stale
if chainNum > 0 {
// No need to clear the chain since computeStaleIptablesRules cleans up all rules already
log.Infof("Deleting stale chain: %s", chain)
err := ipt.DeleteChain("nat", chain)
if err != nil {
return errors.Wrapf(err, "stale chain cleanup: failed to delete chain %s", chain)
}
}
}
}
return nil
}

func (n *linuxNetwork) updateHostIptablesRules(vpcCIDRs []string, primaryMAC string, primaryAddr *net.IP, v4Enabled bool,
v6Enabled bool) error {
primaryIntf, err := findPrimaryInterfaceName(primaryMAC)
Expand Down Expand Up @@ -434,15 +480,13 @@ func (n *linuxNetwork) buildIptablesSNATRules(vpcCIDRs []string, primaryAddr *ne
log.Debugf("Total CIDRs to program - %d", len(allCIDRs))
// build IPTABLES chain for SNAT of non-VPC outbound traffic and excluded CIDRs
var chains []string
for i := 0; i <= len(allCIDRs); i++ {
chain := fmt.Sprintf("AWS-SNAT-CHAIN-%d", i)
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)
chain := "AWS-SNAT-CHAIN-0"
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)

// build SNAT rules for outbound non-VPC traffic
var iptableRules []iptablesRule
Expand All @@ -456,23 +500,20 @@ func (n *linuxNetwork) buildIptablesSNATRules(vpcCIDRs []string, primaryAddr *ne
"-m", "comment", "--comment", "AWS SNAT CHAIN", "-j", "AWS-SNAT-CHAIN-0",
}})

for i, cidr := range allCIDRs {
curChain := chains[i]
curName := fmt.Sprintf("[%d] AWS-SNAT-CHAIN", i)
nextChain := chains[i+1]
for _, cidr := range allCIDRs {
comment := "AWS SNAT CHAIN"
if cidr.isExclusion {
comment += " EXCLUSION"
}
log.Debugf("Setup Host Network: iptables -A %s ! -d %s -t nat -j %s", curChain, cidr, nextChain)
log.Debugf("Setup Host Network: iptables -A %s -d %s -t nat -j %s", chain, cidr, "RETURN")

iptableRules = append(iptableRules, iptablesRule{
name: curName,
name: chain,
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: curChain,
chain: chain,
rule: []string{
"!", "-d", cidr.cidr, "-m", "comment", "--comment", comment, "-j", nextChain,
"-d", cidr.cidr, "-m", "comment", "--comment", comment, "-j", "RETURN",
}})
}

Expand All @@ -494,22 +535,21 @@ func (n *linuxNetwork) buildIptablesSNATRules(vpcCIDRs []string, primaryAddr *ne
}
}

lastChain := chains[len(chains)-1]
iptableRules = append(iptableRules, iptablesRule{
name: "last SNAT rule for non-VPC outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: lastChain,
rule: snatRule,
})

snatStaleRules, err := computeStaleIptablesRules(ipt, "nat", "AWS-SNAT-CHAIN", iptableRules, chains)
if err != nil {
return []iptablesRule{}, err
}

iptableRules = append(iptableRules, snatStaleRules...)

iptableRules = append(iptableRules, iptablesRule{
name: "last SNAT rule for non-VPC outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: chain,
rule: snatRule,
})

iptableRules = append(iptableRules, iptablesRule{
name: "connmark for primary ENI",
shouldExist: n.nodePortSupportEnabled,
Expand Down Expand Up @@ -556,16 +596,15 @@ func (n *linuxNetwork) buildIptablesConnmarkRules(vpcCIDRs []string, ipt iptable
excludeCIDRs := sets.NewString(n.excludeSNATCIDRs...)

log.Debugf("Total CIDRs to exempt from connmark rules - %d", len(allCIDRs))

var chains []string
for i := 0; i <= len(allCIDRs); i++ {
chain := fmt.Sprintf("AWS-CONNMARK-CHAIN-%d", i)
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)
chain := "AWS-CONNMARK-CHAIN-0"
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)

var iptableRules []iptablesRule
log.Debugf("Setup Host Network: iptables -t nat -A PREROUTING -i %s+ -m comment --comment \"AWS, outbound connections\" -j AWS-CONNMARK-CHAIN-0", n.vethPrefix)
Expand All @@ -590,37 +629,23 @@ func (n *linuxNetwork) buildIptablesConnmarkRules(vpcCIDRs []string, ipt iptable
"-j", "AWS-CONNMARK-CHAIN-0",
}})

for i, cidr := range allCIDRs {
curChain := chains[i]
curName := fmt.Sprintf("[%d] AWS-SNAT-CHAIN", i)
nextChain := chains[i+1]
for _, cidr := range allCIDRs {
comment := "AWS CONNMARK CHAIN, VPC CIDR"
if excludeCIDRs.Has(cidr) {
comment = "AWS CONNMARK CHAIN, EXCLUDED CIDR"
}
log.Debugf("Setup Host Network: iptables -A %s ! -d %s -t nat -j %s", curChain, cidr, nextChain)
log.Debugf("Setup Host Network: iptables -A %s -d %s -t nat -j %s", chain, cidr, "RETURN")

iptableRules = append(iptableRules, iptablesRule{
name: curName,
name: chain,
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: curChain,
chain: chain,
rule: []string{
"!", "-d", cidr, "-m", "comment", "--comment", comment, "-j", nextChain,
"-d", cidr, "-m", "comment", "--comment", comment, "-j", "RETURN",
}})
}

iptableRules = append(iptableRules, iptablesRule{
name: "connmark rule for external outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: chains[len(chains)-1],
rule: []string{
"-m", "comment", "--comment", "AWS, CONNMARK", "-j", "CONNMARK",
"--set-xmark", fmt.Sprintf("%#x/%#x", n.mainENIMark, n.mainENIMark),
},
})

// Force delete existing restore mark rule so that the subsequent rule gets added to the end
iptableRules = append(iptableRules, iptablesRule{
name: "connmark to fwmark copy",
Expand Down Expand Up @@ -652,14 +677,24 @@ func (n *linuxNetwork) buildIptablesConnmarkRules(vpcCIDRs []string, ipt iptable
}
iptableRules = append(iptableRules, connmarkStaleRules...)

iptableRules = append(iptableRules, iptablesRule{
name: "connmark rule for external outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: chain,
rule: []string{
"-m", "comment", "--comment", "AWS, CONNMARK", "-j", "CONNMARK",
"--set-xmark", fmt.Sprintf("%#x/%#x", n.mainENIMark, n.mainENIMark),
},
})

log.Debugf("iptableRules: %v", iptableRules)
return iptableRules, nil
}

func (n *linuxNetwork) updateIptablesRules(iptableRules []iptablesRule, ipt iptableswrapper.IPTablesIface) error {
for _, rule := range iptableRules {
log.Debugf("execute iptable rule : %s", rule.name)

exists, err := ipt.Exists(rule.table, rule.chain, rule.rule...)
log.Debugf("rule %v exists %v, err %v", rule, exists, err)
if err != nil {
Expand All @@ -668,10 +703,19 @@ func (n *linuxNetwork) updateIptablesRules(iptableRules []iptablesRule, ipt ipta
}

if !exists && rule.shouldExist {
err = ipt.Append(rule.table, rule.chain, rule.rule...)
if err != nil {
log.Errorf("host network setup: failed to add %v, %v", rule, err)
return errors.Wrapf(err, "host network setup: failed to add %v", rule)
if rule.name == "AWS-CONNMARK-CHAIN-0" || rule.name == "AWS-SNAT-CHAIN-0" {
// All CIDR rules must go before the SNAT/Mark rule
err = ipt.Insert(rule.table, rule.chain, 1, rule.rule...)
if err != nil {
log.Errorf("host network setup: failed to insert %v, %v", rule, err)
return errors.Wrapf(err, "host network setup: failed to add %v", rule)
}
} else {
err = ipt.Append(rule.table, rule.chain, rule.rule...)
if err != nil {
log.Errorf("host network setup: failed to add %v, %v", rule, err)
return errors.Wrapf(err, "host network setup: failed to add %v", rule)
}
}
} else if exists && !rule.shouldExist {
err = ipt.Delete(rule.table, rule.chain, rule.rule...)
Expand Down Expand Up @@ -726,7 +770,7 @@ func computeStaleIptablesRules(ipt iptableswrapper.IPTablesIface, table, chainPr
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to list rules from table %s with chain prefix %s", table, chainPrefix)
}
activeChains := sets.NewString(chains...)
log.Debugf("Setup Host Network: computing stale iptables rules for %s table with chain prefix %s")
log.Debugf("Setup Host Network: computing stale iptables rules for %s table with chain prefix %s", table, chainPrefix)
for _, staleRule := range existingRules {
if len(staleRule.rule) == 0 && activeChains.Has(staleRule.chain) {
log.Debugf("Setup Host Network: active chain found: %s", staleRule.chain)
Expand Down
Loading

0 comments on commit 21c4bd7

Please sign in to comment.