Skip to content

Commit

Permalink
Stop UTs from attempting to call to IMDS (#2513)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinclark19a authored Feb 22, 2022
1 parent acd6301 commit 899a38e
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 74 deletions.
4 changes: 1 addition & 3 deletions azurelinuxagent/ga/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,9 +1225,7 @@ def _send_heartbeat_telemetry(self, protocol):
auto_update_enabled = 1 if conf.get_autoupdate_enabled() else 0
# Include VMSize in the heartbeat message because the kusto table does not have
# a separate column for it (or architecture).
# Temporarily disable vmsize because it is breaking UTs. TODO: Re-enable when this is fixed.
# vmsize = self._get_vm_size(protocol)
vmsize = "unknown"
vmsize = self._get_vm_size(protocol)

telemetry_msg = "{0};{1};{2};{3};{4};{5}".format(self._heartbeat_counter, self._heartbeat_id, dropped_packets,
self._heartbeat_update_goal_state_error_count,
Expand Down
1 change: 1 addition & 0 deletions tests/ga/test_exthandlers_exthandlerinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def setUp(self):

def tearDown(self):
self.mock_get_base_dir.stop()
super(ExtHandlerInstanceTestCase, self).tearDown()

def test_rm_ext_handler_dir_should_remove_the_extension_packages(self):
os.mkdir(self.extension_directory)
Expand Down
126 changes: 63 additions & 63 deletions tests/ga/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
READONLY_FILE_GLOBS, ExtensionsSummary, AgentUpgradeType
from tests.protocol.mocks import mock_wire_protocol, MockHttpResponse
from tests.protocol.mockwiredata import DATA_FILE, DATA_FILE_MULTIPLE_EXT
from tests.tools import AgentTestCase, data_dir, DEFAULT, patch, load_bin_data, Mock, MagicMock, \
clear_singleton_instances, mock_sleep, skip_if_predicate_true
from tests.tools import AgentTestCase, AgentTestCaseWithGetVmSizeMock, data_dir, DEFAULT, patch, load_bin_data, Mock, MagicMock, \
clear_singleton_instances, mock_sleep
from tests.protocol import mockwiredata
from tests.protocol.HttpRequestPredicates import HttpRequestPredicates

Expand Down Expand Up @@ -156,13 +156,13 @@ def check_running(*val, **__):
type(update_handler).is_running = True


class UpdateTestCase(AgentTestCase):
class UpdateTestCase(AgentTestCaseWithGetVmSizeMock):
_test_suite_tmp_dir = None
_agent_zip_dir = None

@classmethod
def setUpClass(cls):
AgentTestCase.setUpClass()
super(UpdateTestCase, cls).setUpClass()
# copy data_dir/ga/WALinuxAgent-0.0.0.0.zip to _test_suite_tmp_dir/waagent-zip/WALinuxAgent-<AGENT_VERSION>.zip
sample_agent_zip = "WALinuxAgent-0.0.0.0.zip"
test_agent_zip = sample_agent_zip.replace("0.0.0.0", AGENT_VERSION)
Expand All @@ -175,7 +175,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
AgentTestCase.tearDownClass()
super(UpdateTestCase, cls).tearDownClass()
shutil.rmtree(UpdateTestCase._test_suite_tmp_dir)

@staticmethod
Expand Down Expand Up @@ -1555,62 +1555,6 @@ def test_update_happens_when_extensions_disabled(self, _):
self.update_handler._download_agent_if_upgrade_available = Mock(return_value=True)
self._test_run(invocations=0, calls=0, enable_updates=True, sleep_interval=(300,))

@patch("azurelinuxagent.common.logger.info")
@patch("azurelinuxagent.ga.update.add_event")
def test_telemetry_heartbeat_creates_event(self, patch_add_event, patch_info, *_):
update_handler = get_update_handler()
mock_protocol = WireProtocol("foo.bar")

update_handler.last_telemetry_heartbeat = datetime.utcnow() - timedelta(hours=1)
update_handler._send_heartbeat_telemetry(mock_protocol)
self.assertEqual(1, patch_add_event.call_count)
self.assertTrue(any(call_args[0] == "[HEARTBEAT] Agent {0} is running as the goal state agent {1}"
for call_args in patch_info.call_args), "The heartbeat was not written to the agent's log")


@skip_if_predicate_true(lambda: True, "Enable this test when VMSize bug hanging Uts is fixed.")
@patch("azurelinuxagent.ga.update.add_event")
@patch("azurelinuxagent.common.protocol.imds.ImdsClient")
def test_telemetry_heartbeat_retries_failed_vm_size_fetch(self, mock_imds_factory, patch_add_event, *_):

def validate_single_heartbeat_event_matches_vm_size(vm_size):
heartbeat_event_kwargs = [
kwargs for _, kwargs in patch_add_event.call_args_list
if kwargs.get('op', None) == WALAEventOperation.HeartBeat
]

self.assertEqual(1, len(heartbeat_event_kwargs), "Expected exactly one HeartBeat event, got {0}"\
.format(heartbeat_event_kwargs))

telemetry_message = heartbeat_event_kwargs[0].get("message", "")
self.assertTrue(telemetry_message.endswith(vm_size),
"Expected HeartBeat message ('{0}') to end with the test vmSize value, {1}."\
.format(telemetry_message, vm_size))

with mock_wire_protocol(mockwiredata.DATA_FILE) as mock_protocol:
update_handler = get_update_handler()
update_handler.protocol_util.get_protocol = Mock(return_value=mock_protocol)

# Zero out the _vm_size parameter for test resiliency
update_handler._vm_size = None

mock_imds_client = mock_imds_factory.return_value = Mock()

# First force a vmSize retrieval failure
mock_imds_client.get_compute.side_effect = HttpError(msg="HTTP Test Failure")
update_handler._last_telemetry_heartbeat = datetime.utcnow() - timedelta(hours=1)
update_handler._send_heartbeat_telemetry(mock_protocol)

validate_single_heartbeat_event_matches_vm_size("unknown")
patch_add_event.reset_mock()

# Now provide a vmSize
mock_imds_client.get_compute = lambda: ComputeInfo(vmSize="TestVmSizeValue")
update_handler._last_telemetry_heartbeat = datetime.utcnow() - timedelta(hours=1)
update_handler._send_heartbeat_telemetry(mock_protocol)

validate_single_heartbeat_event_matches_vm_size("TestVmSizeValue")

@staticmethod
def _get_test_ext_handler_instance(protocol, name="OSTCExtensions.ExampleHandlerLinux", version="1.0.0"):
eh = Extension(name=name)
Expand Down Expand Up @@ -2461,9 +2405,9 @@ def test_it_should_not_downgrade_below_daemon_version(self):
@patch('azurelinuxagent.ga.update.get_collect_logs_handler')
@patch('azurelinuxagent.ga.update.get_monitor_handler')
@patch('azurelinuxagent.ga.update.get_env_handler')
class MonitorThreadTest(AgentTestCase):
class MonitorThreadTest(AgentTestCaseWithGetVmSizeMock):
def setUp(self):
AgentTestCase.setUp(self)
super(MonitorThreadTest, self).setUp()
self.event_patch = patch('azurelinuxagent.common.event.add_event')
currentThread().setName("ExtHandler")
protocol = Mock()
Expand Down Expand Up @@ -2869,6 +2813,62 @@ def test_it_should_process_goal_state_only_on_new_goal_state(self):
self.assertEqual(3, exthandlers_handler.report_ext_handlers_status.call_count, "exthandlers_handler.report_ext_handlers_status() should have been called on a new goal state")
self.assertEqual(2, remote_access_handler.run.call_count, "remote_access_handler.run() should have been called on a new goal state")

class HeartbeatTestCase(AgentTestCase):

@patch("azurelinuxagent.common.logger.info")
@patch("azurelinuxagent.ga.update.add_event")
def test_telemetry_heartbeat_creates_event(self, patch_add_event, patch_info, *_):

with mock_wire_protocol(mockwiredata.DATA_FILE) as mock_protocol:
update_handler = get_update_handler()

update_handler.last_telemetry_heartbeat = datetime.utcnow() - timedelta(hours=1)
update_handler._send_heartbeat_telemetry(mock_protocol)
self.assertEqual(1, patch_add_event.call_count)
self.assertTrue(any(call_args[0] == "[HEARTBEAT] Agent {0} is running as the goal state agent {1}"
for call_args in patch_info.call_args), "The heartbeat was not written to the agent's log")

@patch("azurelinuxagent.ga.update.add_event")
@patch("azurelinuxagent.common.protocol.imds.ImdsClient")
def test_telemetry_heartbeat_retries_failed_vm_size_fetch(self, mock_imds_factory, patch_add_event, *_):

def validate_single_heartbeat_event_matches_vm_size(vm_size):
heartbeat_event_kwargs = [
kwargs for _, kwargs in patch_add_event.call_args_list
if kwargs.get('op', None) == WALAEventOperation.HeartBeat
]

self.assertEqual(1, len(heartbeat_event_kwargs), "Expected exactly one HeartBeat event, got {0}"\
.format(heartbeat_event_kwargs))

telemetry_message = heartbeat_event_kwargs[0].get("message", "")
self.assertTrue(telemetry_message.endswith(vm_size),
"Expected HeartBeat message ('{0}') to end with the test vmSize value, {1}."\
.format(telemetry_message, vm_size))

with mock_wire_protocol(mockwiredata.DATA_FILE) as mock_protocol:
update_handler = get_update_handler()
update_handler.protocol_util.get_protocol = Mock(return_value=mock_protocol)

# Zero out the _vm_size parameter for test resiliency
update_handler._vm_size = None

mock_imds_client = mock_imds_factory.return_value = Mock()

# First force a vmSize retrieval failure
mock_imds_client.get_compute.side_effect = HttpError(msg="HTTP Test Failure")
update_handler._last_telemetry_heartbeat = datetime.utcnow() - timedelta(hours=1)
update_handler._send_heartbeat_telemetry(mock_protocol)

validate_single_heartbeat_event_matches_vm_size("unknown")
patch_add_event.reset_mock()

# Now provide a vmSize
mock_imds_client.get_compute = lambda: ComputeInfo(vmSize="TestVmSizeValue")
update_handler._last_telemetry_heartbeat = datetime.utcnow() - timedelta(hours=1)
update_handler._send_heartbeat_telemetry(mock_protocol)

validate_single_heartbeat_event_matches_vm_size("TestVmSizeValue")

class GoalStateIntervalTestCase(AgentTestCase):
def test_initial_goal_state_period_should_default_to_goal_state_period(self):
Expand Down
6 changes: 6 additions & 0 deletions tests/protocol/mockwiredata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from azurelinuxagent.common.utils.textutil import parse_doc, find, findall
from tests.protocol.HttpRequestPredicates import HttpRequestPredicates
from tests.tools import load_bin_data, load_data, MagicMock, Mock
from azurelinuxagent.common.protocol.imds import IMDS_ENDPOINT
from azurelinuxagent.common.exception import HttpError, ResourceGoneError
from azurelinuxagent.common.future import httpclient
from azurelinuxagent.common.utils.cryptutil import CryptUtil
Expand All @@ -37,6 +38,7 @@
"trans_prv": "wire/trans_prv",
"trans_cert": "wire/trans_cert",
"test_ext": "ext/sample_ext-1.3.0.zip",
"imds_info": "imds/valid.json",
"remote_access": None,
"in_vm_artifacts_profile": None,
"vm_settings": None,
Expand Down Expand Up @@ -161,6 +163,7 @@ def __init__(self, data_files=None):
self.in_vm_artifacts_profile = None
self.vm_settings = None
self.etag = None
self.imds_info = None

self.reload()

Expand All @@ -177,6 +180,7 @@ def reload(self):
self.ga_manifest = load_data(self.data_files.get("ga_manifest"))
self.trans_prv = load_data(self.data_files.get("trans_prv"))
self.trans_cert = load_data(self.data_files.get("trans_cert"))
self.imds_info = json.loads(load_data(self.data_files.get("imds_info")))
self.ext = load_bin_data(self.data_files.get("test_ext"))

vm_settings = self.data_files.get("vm_settings")
Expand Down Expand Up @@ -236,6 +240,8 @@ def mock_http_get(self, url, *_, **kwargs):
content = self.vm_settings
response_headers = [('ETag', self.etag)]
self.call_counts["vm_settings"] += 1
elif '{0}/metadata/compute'.format(IMDS_ENDPOINT) in url:
content = json.dumps(self.imds_info.get("compute", "{}"))

else:
# A stale GoalState results in a 400 from the HostPlugin
Expand Down
2 changes: 2 additions & 0 deletions tests/protocol/test_metadata_server_migration_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,7 @@ def tearDown(self):
os.remove(path)
# pylint: enable=redefined-builtin

super(TestMetadataServerMigrationUtil, self).tearDown()

if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions tests/protocol/test_protocol_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def tearDown(self):
if os.path.exists(endpoint_path):
os.remove(endpoint_path)

super(TestProtocolUtil, self).tearDown()

def test_get_protocol_util_should_return_same_object_for_same_thread(self, _):
protocol_util1 = get_protocol_util()
protocol_util2 = get_protocol_util()
Expand Down
16 changes: 16 additions & 0 deletions tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,22 @@ def create_script(script_file, contents):
os.chmod(script_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)


class AgentTestCaseWithGetVmSizeMock(AgentTestCase):

def setUp(self):

self._get_vm_size_patch = patch('azurelinuxagent.ga.update.UpdateHandler._get_vm_size', return_value="unknown")
self._get_vm_size_patch.start()

super(AgentTestCaseWithGetVmSizeMock, self).setUp()

def tearDown(self):

if self._get_vm_size_patch:
self._get_vm_size_patch.stop()

super(AgentTestCaseWithGetVmSizeMock, self).tearDown()

def load_data(name):
"""Load test data"""
path = os.path.join(data_dir, name)
Expand Down
6 changes: 1 addition & 5 deletions tests/utils/test_archive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the Apache License.
import os
import shutil
import tempfile
import zipfile
from datetime import datetime, timedelta
Expand All @@ -23,14 +22,11 @@

class TestArchive(AgentTestCase):
def setUp(self):
super(TestArchive, self).setUp()
prefix = "{0}_".format(self.__class__.__name__)

self.tmp_dir = tempfile.mkdtemp(prefix=prefix)

def tearDown(self):
if not debug and self.tmp_dir is not None:
shutil.rmtree(self.tmp_dir)

def _write_file(self, filename, contents=None):
full_name = os.path.join(self.tmp_dir, filename)
fileutil.mkdir(os.path.dirname(full_name))
Expand Down
5 changes: 2 additions & 3 deletions tests/utils/test_extension_process_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# Requires Python 2.6+ and Openssl 1.0+
#
import os
import shutil
import subprocess
import tempfile

Expand All @@ -39,8 +38,8 @@ def setUp(self):
def tearDown(self):
self.stderr.close()
self.stdout.close()
if self.tmp_dir is not None:
shutil.rmtree(self.tmp_dir)

super(TestProcessUtils, self).tearDown()

def test_wait_for_process_completion_or_timeout_should_terminate_cleanly(self):
process = subprocess.Popen(
Expand Down

0 comments on commit 899a38e

Please sign in to comment.