Skip to content

Commit

Permalink
Code review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
narrieta@microsoft committed Sep 26, 2024
1 parent 99fd28f commit f80d169
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def _cleanup_metadata_protocol_certificates():

def _reset_firewall_rules():
"""
Removes MetadataServer firewall rule so IMDS can be used. Enables
WireServer firewall rule based on if firewall is configured to be on.
Removes MetadataServer firewall rule so IMDS can be used.
"""
try:
_remove_firewall(dst_ip=_KNOWN_METADATASERVER_IP, uid=os.getuid(), wait=_get_firewall_will_wait())
Expand Down
2 changes: 1 addition & 1 deletion azurelinuxagent/ga/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _operation(self):
self._report(event.warn, "An error occurred while setting up the firewall: {0}", ustr(e))

def _report(self, report_function, message, *args):
# Report the first 6 messages, then stop reporting for 12 hours
# Report the first 3 messages, then stop reporting for 12 hours
if datetime.datetime.now() < self._report_after:
return

Expand Down
32 changes: 27 additions & 5 deletions azurelinuxagent/ga/firewall_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from azurelinuxagent.common import logger
from azurelinuxagent.common import event
from azurelinuxagent.common.event import add_event, WALAEventOperation
from azurelinuxagent.common.event import WALAEventOperation
from azurelinuxagent.common.utils import shellutil

from azurelinuxagent.common.future import ustr
Expand Down Expand Up @@ -283,9 +283,7 @@ def __init__(self, wire_server_address):
except Exception as exception:
if isinstance(exception, OSError) and exception.errno == errno.ENOENT: # pylint: disable=no-member
raise FirewallManagerNotAvailableError("iptables is not available")
message = "Unable to determine version of iptables; will not use -w option. --version output: {0}".format(ustr(exception))
logger.warn(message)
add_event(op=WALAEventOperation.Firewall, is_success=False, message=message, log_event=False)
event.warn(WALAEventOperation.Firewall, "Unable to determine version of iptables; will not use -w option. --version output: {0}", ustr(exception))
use_wait_option = False

if use_wait_option:
Expand Down Expand Up @@ -338,6 +336,16 @@ class FirewallCmd(_FirewallManagerMultipleRules):
"""
FirewallManager based on the firewalld command-line tool.
"""
def __init__(self, wire_server_address):
super(FirewallCmd, self).__init__(wire_server_address)

try:
self._version = shellutil.run_command(["firewall-cmd", "--version"]).strip()
except Exception as exception:
if isinstance(exception, OSError) and exception.errno == errno.ENOENT: # pylint: disable=no-member
raise FirewallManagerNotAvailableError("nft is not available")
self._version = "unknown"

def _get_state_command(self):
return ["firewall-cmd", "--permanent", "--direct", "--get-all-passthroughs"]

Expand Down Expand Up @@ -369,12 +377,26 @@ class NfTables(FirewallManager):
"""
FirewallManager based on the nft command-line tool.
"""
def __init__(self, wire_server_address):
super(NfTables, self).__init__(wire_server_address)

try:
self._version = shellutil.run_command(["nft", "--version"]).strip()
except Exception as exception:
if isinstance(exception, OSError) and exception.errno == errno.ENOENT: # pylint: disable=no-member
raise FirewallManagerNotAvailableError("nft is not available")
self._version = "unknown"

@property
def version(self):
return self._version

def setup(self):
shellutil.run_command(["nft", "add", "table", "ip", "walinuxagent"])
shellutil.run_command(["nft", "add", "chain", "ip", "walinuxagent", "output", "{", "type", "filter", "hook", "output", "priority", "0", ";", "policy", "accept", ";", "}"])
shellutil.run_command([
"nft", "add", "rule", "ip", "walinuxagent", "output", "ip", "daddr", self._wire_server_address,
"tcp", " dport", "!=", "53",
"tcp", "dport", "!=", "53",
"skuid", "!=", str(os.getuid()),
"ct", "state", "invalid,new", "counter", "drop"])

Expand Down
5 changes: 4 additions & 1 deletion azurelinuxagent/ga/persist_firewall_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def setup(self):
# setup permanent firewalld rules
firewall_manager = FirewallCmd(self._dst_ip)

firewall_manager.remove_legacy_rule()
try:
firewall_manager.remove_legacy_rule()
except Exception as error:
event.error(WALAEventOperation.Firewall, "Unable to remove legacy firewall rule. Error: {0}", ustr(error))

try:
if firewall_manager.check():
Expand Down
116 changes: 113 additions & 3 deletions tests/ga/test_firewall_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,56 @@
#
# Requires Python 2.6+ and Openssl 1.0+
#
import contextlib
import os
import unittest

from azurelinuxagent.ga.firewall_manager import IpTables, FirewallCmd
from tests.lib.tools import AgentTestCase
from tests.lib.mock_firewall_command import MockIpTables, MockFirewallCmd
from azurelinuxagent.common.utils import shellutil
from azurelinuxagent.ga.firewall_manager import FirewallManager, IpTables, FirewallCmd, NfTables, FirewallStateError, FirewallManagerNotAvailableError
from tests.lib.tools import AgentTestCase, patch
from tests.lib.mock_firewall_command import MockIpTables, MockFirewallCmd, MockNft


@contextlib.contextmanager
def firewall_command_exists_mock(iptables_exist=True, firewallcmd_exist=True, nft_exists=True):
"""
Mocks the shellutil.run_command method to fake calls to the iptables/firewall-cmd/nft commands. If ech of those commands should exists,
the call is faked to return success. Otherwise, the call is faked to invoke a non-existing command.
"""
commands = {
"iptables": iptables_exist,
"firewall-cmd": firewallcmd_exist,
"nft": nft_exists
}

original_run_command = shellutil.run_command

def mock_run_command(command, *args, **kwargs):
command_exists = commands.get(command[0])
if command_exists is not None:
command = ['sh', '-c', "exit 0"] if command_exists else ["fake-command-that-does-not-exist"]
return original_run_command(command, *args, **kwargs)

with patch("azurelinuxagent.ga.firewall_manager.shellutil.run_command", side_effect=mock_run_command) as patcher:
yield patcher


class TestFirewallManager(AgentTestCase):
def test_create_should_prefer_iptables_when_both_iptables_and_nftables_exist(self):
with firewall_command_exists_mock(iptables_exist=True, nft_exists=True):
firewall = FirewallManager.create('168.63.129.16')
self.assertIsInstance(firewall, IpTables)

def test_create_should_use_nftables_when_iptables_does_not_exist(self):
with firewall_command_exists_mock(iptables_exist=False, nft_exists=True):
firewall = FirewallManager.create('168.63.129.16')
self.assertIsInstance(firewall, NfTables)

def test_create_should_raise_FirewallManagerNotAvailableError_when_both_iptables_and_nftables_do_not_exist(self):
with firewall_command_exists_mock(iptables_exist=False, nft_exists=False):
with self.assertRaises(FirewallManagerNotAvailableError):
FirewallManager.create('168.63.129.16')


class _TestFirewallCommand(AgentTestCase):
"""
Expand Down Expand Up @@ -104,6 +149,11 @@ def _test_remove_legacy_rule_should_delete_the_legacy_rule(self, firewall_cmd_ty


class TestIpTables(_TestFirewallCommand):
def test_it_should_raise_FirewallManagerNotAvailableError_when_the_command_is_not_available(self):
with firewall_command_exists_mock(iptables_exist=False):
with self.assertRaises(FirewallManagerNotAvailableError):
IpTables('168.63.129.16')

def test_setup_should_set_all_the_firewall_rules(self):
self._test_setup_should_set_all_the_firewall_rules(IpTables, MockIpTables)

Expand Down Expand Up @@ -135,6 +185,11 @@ def test_it_should_not_use_the_wait_option_on_iptables_versions_less_than_1_4_21


class TestFirewallCmd(_TestFirewallCommand):
def test_it_should_raise_FirewallManagerNotAvailableError_when_the_command_is_not_available(self):
with firewall_command_exists_mock(firewallcmd_exist=False):
with self.assertRaises(FirewallManagerNotAvailableError):
FirewallCmd('168.63.129.16')

def test_setup_should_set_all_the_firewall_rules(self):
self._test_setup_should_set_all_the_firewall_rules(FirewallCmd, MockFirewallCmd)

Expand All @@ -151,5 +206,60 @@ def test_remove_legacy_rule_should_delete_the_legacy_rule(self):
self._test_remove_legacy_rule_should_delete_the_legacy_rule(FirewallCmd, MockFirewallCmd)


class TestNft(AgentTestCase):
def test_it_should_raise_FirewallManagerNotAvailableError_when_the_command_is_not_available(self):
with firewall_command_exists_mock(nft_exists=False):
with self.assertRaises(FirewallManagerNotAvailableError):
NfTables('168.63.129.16')

def test_setup_should_set_the_walinuxagent_table(self):
with MockNft() as mock_nft:
firewall = NfTables('168.63.129.16')
firewall.setup()

self.assertEqual(
[
mock_nft.get_add_command("table"),
mock_nft.get_add_command("chain"),
mock_nft.get_add_command("rule"),
],
mock_nft.call_list,
"Expected exactly 3 calls, to the add the walinuxagent table, output chain, and wireserver rule")

def test_remove_should_delete_the_walinuxagent_table(self):
with MockNft() as mock_nft:
firewall = NfTables('168.63.129.16')
firewall.remove()

self.assertEqual([mock_nft.get_delete_command()], mock_nft.call_list, "Expected a call to delete the walinuxagent table")

def test_check_should_verify_all_rules(self):
with MockNft() as mock_nft:
_, walinuxagent_table = mock_nft.get_return_value(mock_nft.get_list_command("table"))

firewall = NfTables('168.63.129.16')

# Remove the clause for DNS and verify check() fails
stdout = walinuxagent_table.replace('{ "match": {"op": "!=", "left": { "payload": { "protocol": "tcp", "field": "dport" } }, "right": 53}},', '')
mock_nft.set_return_value("list", "table", (0, stdout))
with self.assertRaises(FirewallStateError) as context:
firewall.check()
self.assertIn("['No expression excludes the DNS port']", str(context.exception), "Expected an error message indicating the DNS port is not excluded")

# Remove the clause for root and verify check() fails
stdout = walinuxagent_table.replace('{ "match": {"op": "!=", "left": { "meta": { "key": "skuid" } }, "right": ' + str(os.getuid()) + '}},', '')
mock_nft.set_return_value("list", "table", (0, stdout))
with self.assertRaises(FirewallStateError) as context:
firewall.check()
self.assertIn('["No expression excludes the Agent\'s UID"]', str(context.exception), "Expected an error message indicating the Agent's UID is not excluded")

# Remove the "drop" clause and verify check() fails
stdout = walinuxagent_table.replace('{ "drop": null }', '{ "accept": null }')
mock_nft.set_return_value("list", "table", (0, stdout))
with self.assertRaises(FirewallStateError) as context:
firewall.check()
self.assertIn("['The drop action is missing']", str(context.exception), "Expected an error message indicating the Agent's UID is not excluded")


if __name__ == '__main__':
unittest.main()
144 changes: 144 additions & 0 deletions tests/lib/mock_firewall_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,147 @@ def get_drop_command(option):
@staticmethod
def get_legacy_command(option):
return "firewall-cmd --permanent --direct {0} ipv4 -t security -I OUTPUT -d 168.63.129.16 -p tcp --destination-port 53 -j ACCEPT".format(option)


class MockNft(object):
"""
Intercepts calls to shellutil.run_command and mocks the behavior of the nft command-line utility using a pre-defined set of return values.
"""
def __init__(self):
self._call_list = []
self._original_run_command = shellutil.run_command
self._run_command_patcher = patch("azurelinuxagent.ga.firewall_manager.shellutil.run_command", side_effect=self._mock_run_command)
#
# Return values for each nft command-line indexed by command name ("add", "delete", "list"). Each item is a (exit_code, stdout) tuple.
# These default values indicate success, and can be overridden with the set_*_return_values() methods.
#
self._return_values = {
"add": {
"table": (0, ''), # nft add table ip walinuxagent
"chain": (0, ''), # nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; }
"rule": (0, ''), # nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != 0 ct state invalid,new counter drop
},
"delete": {
"table": (0, ''), # nft delete table walinuxagent
},
"list": {
"tables": (0, # nft --json list tables
'''
{
"nftables": [
{ "metainfo": { "version": "1.0.2", "release_name": "Lester Gooch", "json_schema_version": 1 } },
{ "table": { "family": "ip", "name": "walinuxagent", "handle": 2 } }
]
}
'''),
"table": (0, # nft --json list table walinuxagent
'''
{
"nftables": [
{ "metainfo": { "version": "1.0.2", "release_name": "Lester Gooch", "json_schema_version": 1 } },
{ "table": { "family": "ip", "name": "walinuxagent", "handle": 2 } },
{ "chain": { "family": "ip", "table": "walinuxagent", "name": "output", "handle": 1, "type": "filter", "hook": "output", "prio": 0, "policy": "accept" } },
{
"rule": {
"family": "ip", "table": "walinuxagent", "chain": "output", "handle": 2,
"expr": [
{ "match": {"op": "==", "left": { "payload": { "protocol": "ip", "field": "daddr" } }, "right": "168.63.129.16"}},
{ "match": {"op": "!=", "left": { "payload": { "protocol": "tcp", "field": "dport" } }, "right": 53}},
{ "match": {"op": "!=", "left": { "meta": { "key": "skuid" } }, "right": ''' + str(os.getuid()) +'''}},
{ "match": {"op": "in", "left": { "ct": { "key": "state" } }, "right": [ "invalid", "new" ]}},
{ "counter": {"packets": 0, "bytes": 0}},
{ "drop": null }
]
}
}
]
}
''')
}
}

def __enter__(self):
self._run_command_patcher.start()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self._run_command_patcher.stop()

def _mock_run_command(self, command, *args, **kwargs):
if command[0] == 'nft':
command_string = " ".join(command)
exit_code, stdout = self.get_return_value(command_string)
script = \
"""
cat << ..
{0}
..
exit {1}
""".format(stdout, exit_code)
command = ['sh', '-c', script]
self._call_list.append(command_string)
return self._original_run_command(command, *args, **kwargs)

@property
def call_list(self):
"""
Returns the list of commands that were executed by the mock
"""
return self._call_list

def set_return_value(self, command, target, return_value):
"""
Changes the return values for the mocked command
"""
self._return_values[command][target] = return_value

def get_return_value(self, command):
"""
Possible commands are:
nft add table ip walinuxagent
nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; }
nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != 0 ct state invalid,new counter drop
nft delete table walinuxagent
nft --json list tables
nft --json list table walinuxagent
"""
r = r"nft add (?P<target>table|chain|rule)" + \
r"(ip walinuxagent output " + \
r"(\{ type filter hook output priority 0 ; policy accept ; })" + \
r"|" + \
r"(ip daddr 168.63.129.16 tcp dport != 53 skuid != \d+ ct state invalid,new counter drop)" + \
r")?"
match = re.match(r, command)
if match is not None:
target = match.group("target")
return self._return_values["add"][target]
if command == "nft delete table walinuxagent":
return self._return_values["delete"]["table"]
match = re.match(r"nft --json list (?P<target>tables|table)( walinuxagent)?", command)
if match is not None:
target = match.group("target")
return self._return_values["list"][target]
raise Exception("Unexpected command: {0}".format(command))

@staticmethod
def get_add_command(target):
if target == "table":
return "nft add table ip walinuxagent"
if target == "chain":
return "nft add chain ip walinuxagent output { type filter hook output priority 0 ; policy accept ; }"
if target == "rule":
return "nft add rule ip walinuxagent output ip daddr 168.63.129.16 tcp dport != 53 skuid != {0} ct state invalid,new counter drop".format(os.getuid())
raise Exception("Unexpected command target: {0}".format(target))

@staticmethod
def get_delete_command():
return "nft delete table walinuxagent"

@staticmethod
def get_list_command(target):
if target == "tables":
return "nft --json list tables"
if target == "table":
return "nft --json list table walinuxagent"
raise Exception("Unexpected command target: {0}".format(target))

0 comments on commit f80d169

Please sign in to comment.