From d18634727a62b0418a62499bec67fb4970169798 Mon Sep 17 00:00:00 2001 From: nam Date: Fri, 6 Mar 2020 13:18:59 -0800 Subject: [PATCH 1/5] Enable WireClient tests --- azurelinuxagent/common/protocol/wire.py | 5 - tests/protocol/mock_wire_protocol.py | 47 +-- tests/protocol/mockwiredata.py | 23 +- tests/protocol/test_hostplugin.py | 282 +++++++-------- tests/protocol/test_wire.py | 459 ++++++++++++------------ 5 files changed, 422 insertions(+), 394 deletions(-) diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py index 8bd0951fd3..81a5c10039 100644 --- a/azurelinuxagent/common/protocol/wire.py +++ b/azurelinuxagent/common/protocol/wire.py @@ -520,11 +520,6 @@ def __init__(self, endpoint): logger.info("Wire server endpoint:{0}", endpoint) self._endpoint = endpoint self._goal_state = None - self._hosting_env = None - self._shared_conf = None - self._remote_access = None - self._certs = None - self._ext_conf = None self._host_plugin = None self.status_blob = StatusBlob(self) self.goal_state_flusher = StateFlusher(conf.get_lib_dir()) diff --git a/tests/protocol/mock_wire_protocol.py b/tests/protocol/mock_wire_protocol.py index cafd738606..d2657493af 100644 --- a/tests/protocol/mock_wire_protocol.py +++ b/tests/protocol/mock_wire_protocol.py @@ -18,7 +18,7 @@ import contextlib import re from azurelinuxagent.common.protocol.wire import WireProtocol -from azurelinuxagent.common.utils.restutil import KNOWN_WIRESERVER_IP, http_get +from azurelinuxagent.common.utils.restutil import KNOWN_WIRESERVER_IP, http_request, DEFAULT_RETRIES, RETRY_CODES, DELAY_IN_SECONDS from tests.tools import patch from tests.protocol.mockwiredata import WireProtocolData @@ -28,21 +28,22 @@ def create(mock_wire_data_file): Creates a mock WireProtocol object that will return the data specified by 'mock_wire_data_file' (which must follow the structure of the data files defined in tests/protocol/mockwiredata.py). - NOTE: This function creates mocks for azurelinuxagent.common.utils.restutil.http_get and + NOTE: This function creates mocks for azurelinuxagent.common.utils.restutil.http_request and azurelinuxagent.common.protocol.wire.CryptUtil. These mocks can be stopped using the methods stop_mock_http_get() and stop_mock_crypt_util(). The return value is an instance of WireProtocol augmented with these properties/methods: * mock_wire_data - the WireProtocolData constructed from the mock_wire_data_file parameter. - * stop_mock_http_get() - stops the mock for restutil.http_get + * stop_mock_http_request() - stops the mock for restutil.http_request * stop_mock_crypt_util() - stops the mock for CrypUtil + * stop() - stops both mocks """ - def stop_mock_http_get(): - if stop_mock_http_get.mock is not None: - stop_mock_http_get.mock.stop() - stop_mock_http_get.mock = None - stop_mock_http_get.mock = None + def stop_mock_http_request(): + if stop_mock_http_request.mock is not None: + stop_mock_http_request.mock.stop() + stop_mock_http_request.mock = None + stop_mock_http_request.mock = None def stop_mock_crypt_util(): if stop_mock_crypt_util.mock is not None: @@ -50,26 +51,33 @@ def stop_mock_crypt_util(): stop_mock_crypt_util.mock = None stop_mock_crypt_util.mock = None + def stop(): + stop_mock_crypt_util() + stop_mock_http_request() + protocol = WireProtocol(KNOWN_WIRESERVER_IP) protocol.mock_wire_data = WireProtocolData(mock_wire_data_file) - protocol.stop_mock_http_get = stop_mock_http_get + protocol.stop_mock_http_request = stop_mock_http_request protocol.stop_mock_crypt_util = stop_mock_crypt_util + protocol.stop = stop try: - # To minimize the impact of mocking restutil.http_get we only use the mock data for requests + # To minimize the impact of mocking restutil.http_request we only use the mock data for requests # to the wireserver or requests starting with "mock-goal-state" - original_http_get = http_get - mock_data_re = re.compile(r'https?://(mock-goal-state|{0}).*'.format(KNOWN_WIRESERVER_IP.replace(r'.', r'\.')), re.IGNORECASE) - def mock_http_get(url, *args, **kwargs): - if mock_data_re.match(url) is None: - return original_http_get(url, *args, **kwargs) - return protocol.mock_wire_data.mock_http_get(url, *args, **kwargs) + original_http_request = http_request + + def mock_http_request(method, url, data, headers=None, use_proxy=False, max_retry=DEFAULT_RETRIES, retry_codes=RETRY_CODES, retry_delay=DELAY_IN_SECONDS): + if method == 'GET' and mock_data_re.match(url) is not None: + return protocol.mock_wire_data.mock_http_get(url, headers, use_proxy, max_retry, retry_codes, retry_delay) + elif method == 'POST': + return protocol.mock_wire_data.mock_http_post(url, data, headers, use_proxy, max_retry, retry_codes, retry_delay) + return original_http_request(method, url, data, headers, use_proxy, max_retry, retry_codes, retry_delay) - p = patch("azurelinuxagent.common.utils.restutil.http_get", side_effect=mock_http_get) + p = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=mock_http_request) p.start() - stop_mock_http_get.mock = p + stop_mock_http_request.mock = p p = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util) p.start() @@ -80,5 +88,4 @@ def mock_http_get(url, *args, **kwargs): yield protocol finally: - protocol.stop_mock_crypt_util() - protocol.stop_mock_http_get() + protocol.stop() diff --git a/tests/protocol/mockwiredata.py b/tests/protocol/mockwiredata.py index 5c8833eead..cdf1b9af3c 100644 --- a/tests/protocol/mockwiredata.py +++ b/tests/protocol/mockwiredata.py @@ -86,6 +86,7 @@ def __init__(self, data_files=DATA_FILE): self.call_counts = { "comp=versions": 0, "/versions": 0, + "/HealthService": 0, "goalstate": 0, "hostingenvuri": 0, "sharedconfiguri": 0, @@ -135,13 +136,10 @@ def mock_http_get(self, url, *args, **kwargs): resp = MagicMock() resp.status = httpclient.OK - # wire server versions - if "comp=versions" in url: + if "comp=versions" in url: # wire server versions content = self.version_info self.call_counts["comp=versions"] += 1 - - # HostPlugin versions - elif "/versions" in url: + elif "/versions" in url: # HostPlugin versions content = '["2015-09-01"]' self.call_counts["/versions"] += 1 elif "goalstate" in url: @@ -201,6 +199,21 @@ def mock_http_get(self, url, *args, **kwargs): resp.read = Mock(return_value=content.encode("utf-8")) return resp + def mock_http_post(self, url, *args, **kwargs): + content = None + + resp = MagicMock() + resp.status = httpclient.OK + + if url.endswith('/HealthService'): + self.call_counts['/HealthService'] += 1 + content = '' + else: + raise Exception("Bad url {0}".format(url)) + + resp.read = Mock(return_value=content.encode("utf-8")) + return resp + def mock_crypt_util(self, *args, **kw): #Partially patch instance method of class CryptUtil cryptutil = CryptUtil(*args, **kw) diff --git a/tests/protocol/test_hostplugin.py b/tests/protocol/test_hostplugin.py index 58c7e25d1b..496ea91aa1 100644 --- a/tests/protocol/test_hostplugin.py +++ b/tests/protocol/test_hostplugin.py @@ -33,7 +33,7 @@ from azurelinuxagent.common.utils import restutil from tests.protocol import mock_wire_protocol from tests.protocol.mockwiredata import WireProtocolData, DATA_FILE, DATA_FILE_NO_EXT -from tests.protocol.test_wire import MockResponse +from tests.protocol.test_wire import MockResponse as TestWireMockResponse from tests.tools import AgentTestCase, PY_VERSION_MAJOR, Mock, patch if sys.version_info[0] == 3: @@ -403,6 +403,7 @@ def test_validate_http_request(self): with mock_wire_protocol.create(DATA_FILE) as protocol: test_goal_state = protocol.client._goal_state + plugin = protocol.client.get_host_plugin() status_blob = protocol.client.status_blob status_blob.data = faux_status @@ -417,9 +418,6 @@ def test_validate_http_request(self): with patch.object(restutil, "http_request") as patch_http: patch_http.return_value = Mock(status=httpclient.OK) - protocol.client.get_goal_state = Mock(return_value=test_goal_state) - plugin = protocol.client.get_host_plugin() - with patch.object(plugin, 'get_api_versions') as patch_api: patch_api.return_value = API_VERSION plugin.put_vm_status(status_blob, sas_url, block_blob_type) @@ -564,151 +562,153 @@ def test_validate_get_extension_artifacts(self): self.assertTrue(k in actual_headers) self.assertEqual(expected_headers[k], actual_headers[k]) - @patch("azurelinuxagent.common.utils.restutil.http_get") - def test_health(self, patch_http_get): + def test_health(self): host_plugin = self._init_host() - patch_http_get.return_value = MockResponse('', 200) - result = host_plugin.get_health() - self.assertEqual(1, patch_http_get.call_count) - self.assertTrue(result) - - patch_http_get.return_value = MockResponse('', 500) - result = host_plugin.get_health() - self.assertFalse(result) - - patch_http_get.side_effect = IOError('client IO error') - try: - host_plugin.get_health() - self.fail('IO error expected to be raised') - except IOError: - # expected - pass - - @patch("azurelinuxagent.common.utils.restutil.http_get", - return_value=MockResponse(status_code=200, body=b'')) - @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_versions") - def test_ensure_health_service_called(self, patch_http_get, patch_report_versions): + with patch("azurelinuxagent.common.utils.restutil.http_get") as patch_http_get: + patch_http_get.return_value = MockResponse('', 200) + result = host_plugin.get_health() + self.assertEqual(1, patch_http_get.call_count) + self.assertTrue(result) + + patch_http_get.return_value = MockResponse('', 500) + result = host_plugin.get_health() + self.assertFalse(result) + + patch_http_get.side_effect = IOError('client IO error') + try: + host_plugin.get_health() + self.fail('IO error expected to be raised') + except IOError: + # expected + pass + + def test_ensure_health_service_called(self): host_plugin = self._init_host() - host_plugin.get_api_versions() - self.assertEqual(1, patch_http_get.call_count) - self.assertEqual(1, patch_report_versions.call_count) + with patch("azurelinuxagent.common.utils.restutil.http_get", return_value=TestWireMockResponse(status_code=200, body=b'')) as patch_http_get: + with patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_versions") as patch_report_versions: + host_plugin.get_api_versions() + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(1, patch_report_versions.call_count) - @patch("azurelinuxagent.common.utils.restutil.http_get") - @patch("azurelinuxagent.common.utils.restutil.http_post") - @patch("azurelinuxagent.common.utils.restutil.http_put") - def test_put_status_healthy_signal(self, patch_http_put, patch_http_post, patch_http_get): + def test_put_status_healthy_signal(self): host_plugin = self._init_host() - status_blob = self._init_status_blob() - # get_api_versions - patch_http_get.return_value = MockResponse(api_versions, 200) - # put status blob - patch_http_put.return_value = MockResponse(None, 201) - - host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) - self.assertEqual(1, patch_http_get.call_count) - self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) - - self.assertEqual(2, patch_http_put.call_count) - self.assertEqual(hostplugin_status_url, patch_http_put.call_args_list[0][0][0]) - self.assertEqual(hostplugin_status_url, patch_http_put.call_args_list[1][0][0]) - - self.assertEqual(2, patch_http_post.call_count) - - # signal for /versions - self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) - jstr = patch_http_post.call_args_list[0][0][1] - obj = json.loads(jstr) - self.assertEqual(1, len(obj['Observations'])) - self.assertTrue(obj['Observations'][0]['IsHealthy']) - self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) - - # signal for /status - self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) - jstr = patch_http_post.call_args_list[1][0][1] - obj = json.loads(jstr) - self.assertEqual(1, len(obj['Observations'])) - self.assertTrue(obj['Observations'][0]['IsHealthy']) - self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) - - @patch("azurelinuxagent.common.utils.restutil.http_get") - @patch("azurelinuxagent.common.utils.restutil.http_post") - @patch("azurelinuxagent.common.utils.restutil.http_put") - def test_put_status_unhealthy_signal_transient(self, patch_http_put, patch_http_post, patch_http_get): + + with patch("azurelinuxagent.common.utils.restutil.http_get") as patch_http_get: + with patch("azurelinuxagent.common.utils.restutil.http_post") as patch_http_post: + with patch("azurelinuxagent.common.utils.restutil.http_put") as patch_http_put: + status_blob = self._init_status_blob() + # get_api_versions + patch_http_get.return_value = MockResponse(api_versions, 200) + # put status blob + patch_http_put.return_value = MockResponse(None, 201) + + host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) + + self.assertEqual(2, patch_http_put.call_count) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args_list[0][0][0]) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args_list[1][0][0]) + + self.assertEqual(2, patch_http_post.call_count) + + # signal for /versions + self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) + jstr = patch_http_post.call_args_list[0][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) + + # signal for /status + self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) + jstr = patch_http_post.call_args_list[1][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) + + def test_put_status_unhealthy_signal_transient(self): host_plugin = self._init_host() - status_blob = self._init_status_blob() - # get_api_versions - patch_http_get.return_value = MockResponse(api_versions, 200) - # put status blob - patch_http_put.return_value = MockResponse(None, 500) - - with self.assertRaises(HttpError): - host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) - - self.assertEqual(1, patch_http_get.call_count) - self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) - - self.assertEqual(1, patch_http_put.call_count) - self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0]) - - self.assertEqual(2, patch_http_post.call_count) - - # signal for /versions - self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) - jstr = patch_http_post.call_args_list[0][0][1] - obj = json.loads(jstr) - self.assertEqual(1, len(obj['Observations'])) - self.assertTrue(obj['Observations'][0]['IsHealthy']) - self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) - - # signal for /status - self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) - jstr = patch_http_post.call_args_list[1][0][1] - obj = json.loads(jstr) - self.assertEqual(1, len(obj['Observations'])) - self.assertTrue(obj['Observations'][0]['IsHealthy']) - self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) - - @patch("azurelinuxagent.common.utils.restutil.http_get") - @patch("azurelinuxagent.common.utils.restutil.http_post") - @patch("azurelinuxagent.common.utils.restutil.http_put") - def test_put_status_unhealthy_signal_permanent(self, patch_http_put, patch_http_post, patch_http_get): + + with patch("azurelinuxagent.common.utils.restutil.http_get") as patch_http_get: + with patch("azurelinuxagent.common.utils.restutil.http_post") as patch_http_post: + with patch("azurelinuxagent.common.utils.restutil.http_put") as patch_http_put: + status_blob = self._init_status_blob() + # get_api_versions + patch_http_get.return_value = MockResponse(api_versions, 200) + # put status blob + patch_http_put.return_value = MockResponse(None, 500) + + with self.assertRaises(HttpError): + host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) + + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) + + self.assertEqual(1, patch_http_put.call_count) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0]) + + self.assertEqual(2, patch_http_post.call_count) + + # signal for /versions + self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) + jstr = patch_http_post.call_args_list[0][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) + + # signal for /status + self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) + jstr = patch_http_post.call_args_list[1][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) + + def test_put_status_unhealthy_signal_permanent(self): host_plugin = self._init_host() - status_blob = self._init_status_blob() - # get_api_versions - patch_http_get.return_value = MockResponse(api_versions, 200) - # put status blob - patch_http_put.return_value = MockResponse(None, 500) - - host_plugin.status_error_state.is_triggered = Mock(return_value=True) - - with self.assertRaises(HttpError): - host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) - - self.assertEqual(1, patch_http_get.call_count) - self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) - - self.assertEqual(1, patch_http_put.call_count) - self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0]) - - self.assertEqual(2, patch_http_post.call_count) - - # signal for /versions - self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) - jstr = patch_http_post.call_args_list[0][0][1] - obj = json.loads(jstr) - self.assertEqual(1, len(obj['Observations'])) - self.assertTrue(obj['Observations'][0]['IsHealthy']) - self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) - - # signal for /status - self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) - jstr = patch_http_post.call_args_list[1][0][1] - obj = json.loads(jstr) - self.assertEqual(1, len(obj['Observations'])) - self.assertFalse(obj['Observations'][0]['IsHealthy']) - self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) + + with patch("azurelinuxagent.common.utils.restutil.http_get") as patch_http_get: + with patch("azurelinuxagent.common.utils.restutil.http_post") as patch_http_post: + with patch("azurelinuxagent.common.utils.restutil.http_put") as patch_http_put: + status_blob = self._init_status_blob() + # get_api_versions + patch_http_get.return_value = MockResponse(api_versions, 200) + # put status blob + patch_http_put.return_value = MockResponse(None, 500) + + host_plugin.status_error_state.is_triggered = Mock(return_value=True) + + with self.assertRaises(HttpError): + host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) + + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) + + self.assertEqual(1, patch_http_put.call_count) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0]) + + self.assertEqual(2, patch_http_post.call_count) + + # signal for /versions + self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) + jstr = patch_http_post.call_args_list[0][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) + + # signal for /status + self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) + jstr = patch_http_post.call_args_list[1][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertFalse(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.should_report", return_value=True) @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_extension_artifact") diff --git a/tests/protocol/test_wire.py b/tests/protocol/test_wire.py index 4ecc039952..7aef0b9675 100644 --- a/tests/protocol/test_wire.py +++ b/tests/protocol/test_wire.py @@ -24,7 +24,7 @@ import contextlib from azurelinuxagent.common.exception import InvalidContainerError, ResourceGoneError, ProtocolError, \ - ExtensionDownloadError + ExtensionDownloadError, HttpError from azurelinuxagent.common.future import httpclient from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol from azurelinuxagent.common.protocol.goal_state import ExtensionsConfig @@ -209,26 +209,28 @@ def test_get_host_ga_plugin(self, *args): self.assertEqual(goal_state.container_id, host_plugin.container_id) self.assertEqual(goal_state.role_config_name, host_plugin.role_config_name) - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - def test_upload_status_blob_default(self, *args): - """ - Default status blob method is HostPlugin. - """ - with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type=testtype) as protocol: - protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready") + def test_upload_status_blob_should_use_the_host_channel_by_default(self, *_): + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + protocol.client.get_host_plugin() # force initialization of the host plugin - with patch.object(WireClient, "get_goal_state") as patch_get_goal_state: - with patch.object(HostPluginProtocol, "put_vm_status") as patch_host_ga_plugin_upload: - with patch.object(StatusBlob, "upload") as patch_default_upload: - HostPluginProtocol.set_default_channel(False) - protocol.client.upload_status_blob() + original_http_request = restutil.http_request - # do not call the direct method unless host plugin fails - patch_default_upload.assert_not_called() - # host plugin always fetches a goal state - patch_get_goal_state.assert_called_once_with() - # host plugin uploads the status blob - patch_host_ga_plugin_upload.assert_called_once_with(ANY, testurl, 'BlockBlob') + def http_request(method, url, *args, **kwargs): + if method == 'PUT': + if protocol.get_endpoint() in url and url.endswith('/status'): + http_request.urls.append(url) + return MockResponse(body=b'', status_code=200) + self.fail('The upload status request was sent to the wrong uri: {0}'.format(uri)) + return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] + + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) as mock_request: + HostPluginProtocol.set_default_channel(False) + protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready") + + protocol.client.upload_status_blob() + + self.assertEqual(len(http_request.urls), 1, 'Expected one upload request to the host: [{0}]'.format(http_request.urls)) def test_upload_status_blob_host_ga_plugin(self, *_): with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type=testtype) as protocol: @@ -486,6 +488,9 @@ def test_report_event_large_event(self, patch_send_event, *args): class TestWireClient(AgentTestCase): + @staticmethod + def _is_extension_artifact_host_request(expected_url, actual_url, **kwargs): + return actual_url.endswith('/extensionArtifact') and kwargs['headers']['x-ms-artifact-location'] == expected_url def test_get_ext_conf_without_uri(self, *args): with mock_wire_protocol.create(mockwiredata.DATA_FILE_NO_EXT) as protocol: @@ -516,227 +521,235 @@ def test_get_ext_conf_with_uri(self, *args): self.assertEqual("BlockBlob", ext_conf.status_upload_blob_type) self.assertEqual(None, ext_conf.artifacts_profile_blob) - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") - @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") - def test_download_ext_handler_pkg_should_not_invoke_host_channel_when_direct_channel_succeeds(self, - mock_get_artifact_request, - *args): - mock_get_artifact_request.return_value = "dummy_url", "dummy_header" - protocol = WireProtocol("foo.bar") - HostPluginProtocol.set_default_channel(False) - - mock_successful_response = MockResponse(body=b"OK", status_code=200) - destination = os.path.join(self.tmp_dir, "tmp_file") - - # Direct channel succeeds - with patch("azurelinuxagent.common.utils.restutil._http_request", return_value=mock_successful_response): - with patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state") as mock_update_goal_state: - with patch("azurelinuxagent.common.protocol.wire.WireClient.stream", wraps=protocol.client.stream) \ - as patch_direct: - with patch( - "azurelinuxagent.common.protocol.wire.WireProtocol._download_ext_handler_pkg_through_host", - wraps=protocol._download_ext_handler_pkg_through_host) as patch_host: - ret = protocol.download_ext_handler_pkg("uri", destination) - self.assertEquals(ret, True) - - self.assertEquals(patch_host.call_count, 0) - self.assertEquals(patch_direct.call_count, 1) - self.assertEquals(mock_update_goal_state.call_count, 0) - - self.assertEquals(HostPluginProtocol.is_default_channel(), False) - - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") - @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") - def test_download_ext_handler_pkg_should_use_host_channel_when_direct_channel_fails(self, mock_get_artifact_request, - *args): - mock_get_artifact_request.return_value = "dummy_url", "dummy_header" - protocol = WireProtocol("foo.bar") - HostPluginProtocol.set_default_channel(False) - - mock_failed_response = MockResponse(body=b"", status_code=httpclient.GONE) - mock_successful_response = MockResponse(body=b"OK", status_code=200) - destination = os.path.join(self.tmp_dir, "tmp_file") - - # Direct channel fails, host channel succeeds. Goal state should not have been updated and host channel - # should have been set as default. - with patch("azurelinuxagent.common.utils.restutil._http_request", - side_effect=[mock_failed_response, mock_successful_response]): - with patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state") as mock_update_goal_state: - with patch("azurelinuxagent.common.protocol.wire.WireClient.stream", wraps=protocol.client.stream) \ - as patch_direct: - with patch( - "azurelinuxagent.common.protocol.wire.WireProtocol._download_ext_handler_pkg_through_host", - wraps=protocol._download_ext_handler_pkg_through_host) as patch_host: - ret = protocol.download_ext_handler_pkg("uri", destination) - self.assertEquals(ret, True) - - self.assertEquals(patch_host.call_count, 1) - # The host channel calls the direct function under the covers - self.assertEquals(patch_direct.call_count, 1 + patch_host.call_count) - self.assertEquals(mock_update_goal_state.call_count, 0) - - self.assertEquals(HostPluginProtocol.is_default_channel(), True) - - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") - @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") - def test_download_ext_handler_pkg_should_retry_the_host_channel_after_refreshing_host_plugin(self, - mock_get_artifact_request, - *args): - mock_get_artifact_request.return_value = "dummy_url", "dummy_header" - protocol = WireProtocol("foo.bar") - HostPluginProtocol.set_default_channel(False) - - mock_failed_response = MockResponse(body=b"", status_code=httpclient.GONE) - mock_successful_response = MockResponse(body=b"OK", status_code=200) - destination = os.path.join(self.tmp_dir, "tmp_file") - - # Direct channel fails, host channel fails due to stale goal state, host channel succeeds after refresh. - # As a consequence, goal state should have been updated and host channel should have been set as default. - with patch("azurelinuxagent.common.utils.restutil._http_request", - side_effect=[mock_failed_response, mock_failed_response, mock_successful_response]): - with patch( - "azurelinuxagent.common.protocol.wire.WireClient.update_host_plugin_from_goal_state") as mock_update_host_plugin_from_goal_state: - with patch("azurelinuxagent.common.protocol.wire.WireClient.stream", wraps=protocol.client.stream) \ - as patch_direct: - with patch( - "azurelinuxagent.common.protocol.wire.WireProtocol._download_ext_handler_pkg_through_host", - wraps=protocol._download_ext_handler_pkg_through_host) as patch_host: - ret = protocol.download_ext_handler_pkg("uri", destination) - self.assertEquals(ret, True) - - self.assertEquals(patch_host.call_count, 2) - # The host channel calls the direct function under the covers - self.assertEquals(patch_direct.call_count, 1 + patch_host.call_count) - self.assertEquals(mock_update_host_plugin_from_goal_state.call_count, 1) - - self.assertEquals(HostPluginProtocol.is_default_channel(), True) - - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") - @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") - def test_download_ext_handler_pkg_should_not_change_default_channel_if_host_fails(self, mock_get_artifact_request, - *args): - mock_get_artifact_request.return_value = "dummy_url", "dummy_header" - protocol = WireProtocol("foo.bar") - HostPluginProtocol.set_default_channel(False) - - mock_failed_response = MockResponse(body=b"", status_code=httpclient.GONE) - destination = os.path.join(self.tmp_dir, "tmp_file") + def test_download_ext_handler_pkg_should_not_invoke_host_channel_when_direct_channel_succeeds(self): + extension_url = 'https://fake_host/fake_extension.zip' + target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') - # Everything fails. Goal state should have been updated and host channel should not have been set as default. - with patch("azurelinuxagent.common.utils.restutil._http_request", return_value=mock_failed_response): - with patch( - "azurelinuxagent.common.protocol.wire.WireClient.update_host_plugin_from_goal_state") as mock_update_host_plugin_from_goal_state: - with patch("azurelinuxagent.common.protocol.wire.WireClient.stream", wraps=protocol.client.stream) \ - as patch_direct: - with patch( - "azurelinuxagent.common.protocol.wire.WireProtocol._download_ext_handler_pkg_through_host", - wraps=protocol._download_ext_handler_pkg_through_host) as patch_host: - ret = protocol.download_ext_handler_pkg("uri", destination) - self.assertEquals(ret, False) - - self.assertEquals(patch_host.call_count, 2) - # The host channel calls the direct function under the covers - self.assertEquals(patch_direct.call_count, 1 + patch_host.call_count) - self.assertEquals(mock_update_host_plugin_from_goal_state.call_count, 1) - - self.assertEquals(HostPluginProtocol.is_default_channel(), False) - - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") - @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") - def test_fetch_manifest_should_not_invoke_host_channel_when_direct_channel_succeeds(self, mock_get_artifact_request, - *args): - mock_get_artifact_request.return_value = "dummy_url", "dummy_header" - client = WireClient("foo.bar") + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + original_http_request = restutil.http_request - HostPluginProtocol.set_default_channel(False) - mock_successful_response = MockResponse(body=b"OK", status_code=200) + def http_request(method, url, *args, **kwargs): + if method == 'GET': + if url == extension_url: + http_request.urls.append(url) + return MockResponse(body=b'', status_code=200) + elif TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + self.fail('The host channel should not have been used') + return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] - # Direct channel succeeds - with patch("azurelinuxagent.common.utils.restutil._http_request", return_value=mock_successful_response): - with patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state") as mock_update_goal_state: - with patch("azurelinuxagent.common.protocol.wire.WireClient.fetch", wraps=client.fetch) as patch_direct: - with patch("azurelinuxagent.common.protocol.wire.WireClient.fetch_manifest_through_host", - wraps=client.fetch_manifest_through_host) as patch_host: - ret = client.fetch_manifest([VMAgentManifestUri(uri="uri1")]) - self.assertEquals(ret, "OK") + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) - self.assertEquals(patch_host.call_count, 0) - # The host channel calls the direct function under the covers - self.assertEquals(patch_direct.call_count, 1) - self.assertEquals(mock_update_goal_state.call_count, 0) + success = protocol.download_ext_handler_pkg(extension_url, target_file) - self.assertEquals(HostPluginProtocol.is_default_channel(), False) + self.assertEquals(success, True, 'The download should have succeeded') + self.assertEquals(len(http_request.urls), 1, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], extension_url, "The extension should have been downloaded over the direct channel") + self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') + self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The host channel should not have been set as the default") - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") - @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") - def test_fetch_manifest_should_use_host_channel_when_direct_channel_fails(self, mock_get_artifact_request, *args): - mock_get_artifact_request.return_value = "dummy_url", "dummy_header" - client = WireClient("foo.bar") + def test_download_ext_handler_pkg_should_use_host_channel_when_direct_channel_fails_and_set_host_as_default(self): + extension_url = 'https://fake_host/fake_extension.zip' + target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') - HostPluginProtocol.set_default_channel(False) - - mock_failed_response = MockResponse(body=b"", status_code=httpclient.GONE) - mock_successful_response = MockResponse(body=b"OK", status_code=200) + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + original_http_request = restutil.http_request + + def http_request(method, url, *args, **kwargs): + if method == 'GET': + if url == extension_url: + http_request.urls.append(url) + raise HttpError("Exception to fake an error on the direct channel") + elif TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + http_request.urls.append(url) + return MockResponse(body=b'', status_code=200) + return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] + + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) + + success = protocol.download_ext_handler_pkg(extension_url, target_file) + + self.assertEquals(success, True, 'The download should have succeeded') + self.assertEquals(len(http_request.urls), 2, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The retry attempt should have been over the host channel") + self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') + self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") + + def test_download_ext_handler_pkg_should_retry_the_host_channel_after_refreshing_host_plugin(self): + extension_url = 'https://fake_host/fake_extension.zip' + target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') - # Direct channel fails, host channel succeeds. Goal state should not have been updated and host channel - # should have been set as default - with patch("azurelinuxagent.common.utils.restutil._http_request", - side_effect=[mock_failed_response, mock_successful_response]): - with patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state") as mock_update_goal_state: - with patch("azurelinuxagent.common.protocol.wire.WireClient.fetch", wraps=client.fetch) as patch_direct: - with patch("azurelinuxagent.common.protocol.wire.WireClient.fetch_manifest_through_host", - wraps=client.fetch_manifest_through_host) as patch_host: - ret = client.fetch_manifest([VMAgentManifestUri(uri="uri1")]) - self.assertEquals(ret, "OK") + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + protocol.client.get_host_plugin() # force initialization of the host plugin + + original_http_request = restutil.http_request + + def http_request(method, url, *args, **kwargs): + if method == 'GET': + if url == extension_url: + http_request.urls.append(url) + raise HttpError("Exception to fake an error on the direct channel") + elif TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + http_request.urls.append(url) + # fake a stale goal state then succeed once the goal state has been refreshed + if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.urls): + raise ResourceGoneError("Exception to fake an error on the host channel") + else: + return MockResponse(body=b'', status_code=200) + elif url.endswith('/machine/?comp=goalstate'): + http_request.urls.append(url) + return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] + + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) + + success = protocol.download_ext_handler_pkg(extension_url, target_file) + + self.assertEquals(success, True, 'The download should have succeeded') + self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") + self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") + self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") + self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') + self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") + + def test_download_ext_handler_pkg_should_not_change_default_channel_when_all_channels_fail(self): + extension_url = 'https://fake_host/fake_extension.zip' - self.assertEquals(patch_host.call_count, 1) - # The host channel calls the direct function under the covers - self.assertEquals(patch_direct.call_count, 1 + patch_host.call_count) - self.assertEquals(mock_update_goal_state.call_count, 0) + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + protocol.client.get_host_plugin() # force initialization of the host plugin + + original_http_request = restutil.http_request + + def http_request(method, url, *args, **kwargs): + if method == 'GET': + if url == extension_url: + http_request.urls.append(url) + raise HttpError("Exception to fake error on direct channel") + if TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + http_request.urls.append(url) + raise ResourceGoneError("Exception to fake error on host channel") + elif url.endswith('/machine/?comp=goalstate'): + http_request.urls.append(url) + return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] + + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) + + success = protocol.download_ext_handler_pkg(extension_url, "/an-invalid-directory/an-invalid-file.zip") + + self.assertEquals(success, False, "The download should have failed") + self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") + self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") + self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") + self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The host channel should not have been set as the default") + + def test_fetch_manifest_should_not_invoke_host_channel_when_direct_channel_succeeds(self): + manifest_url = 'https://fake_host/fake_manifest.xml' + manifest_xml = '' - self.assertEquals(HostPluginProtocol.is_default_channel(), True) + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + original_http_request = restutil.http_request - # Reset default channel - HostPluginProtocol.set_default_channel(False) + def http_request(method, url, *args, **kwargs): + if method == 'GET': + if url == manifest_url: + return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) + elif url.endswith('/extensionArtifact'): + self.fail('The Host GA Plugin should not have been invoked') + return original_http_request(method, url, *args, **kwargs) - @skip_if_predicate_true(lambda: True, "Needs to be re-enabled before release 2.2.47") - @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") - @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") - def test_fetch_manifest_should_retry_the_host_channel_after_refreshing_the_host_plugin(self, - mock_get_artifact_request, - *args): - mock_get_artifact_request.return_value = "dummy_url", "dummy_header" - client = WireClient("foo.bar") + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) as mock_request: + HostPluginProtocol.set_default_channel(False) - HostPluginProtocol.set_default_channel(False) + manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) - mock_failed_response = MockResponse(body=b"", status_code=httpclient.GONE) - mock_successful_response = MockResponse(body=b"OK", status_code=200) + self.assertEquals(manifest, manifest_xml) + self.assertTrue(any(args[0][0] == 'GET' and args[0][1] == manifest_url for args in mock_request.call_args_list)) # direct channel + self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The default channel should not have changed") - # Direct channel fails, host channel fails due to stale goal state, host channel succeeds after refresh. - # As a consequence, goal state should have been updated and host channel should have been set as default. - with patch("azurelinuxagent.common.utils.restutil._http_request", - side_effect=[mock_failed_response, mock_failed_response, mock_successful_response]): - with patch( - "azurelinuxagent.common.protocol.wire.WireClient.update_host_plugin_from_goal_state") as mock_update_host_plugin_from_goal_state: - with patch("azurelinuxagent.common.protocol.wire.WireClient.fetch", wraps=client.fetch) as patch_direct: - with patch("azurelinuxagent.common.protocol.wire.WireClient.fetch_manifest_through_host", - wraps=client.fetch_manifest_through_host) as patch_host: - ret = client.fetch_manifest([VMAgentManifestUri(uri="uri1")]) - self.assertEquals(ret, "OK") + def test_fetch_manifest_should_use_host_channel_when_direct_channel_fails_and_set_it_to_default(self): + manifest_url = 'https://fake_host/fake_manifest.xml' + manifest_xml = '' - self.assertEquals(patch_host.call_count, 2) - # The host channel calls the direct function under the covers - self.assertEquals(patch_direct.call_count, 1 + patch_host.call_count) - self.assertEquals(mock_update_host_plugin_from_goal_state.call_count, 1) + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + original_http_request = restutil.http_request + + def http_request(method, url, *args, **kwargs): + if method == 'GET': + if url == manifest_url: + http_request.urls.append(url) + raise ResourceGoneError("Exception to fake an error on the direct channel") + elif TestWireClient._is_extension_artifact_host_request(manifest_url, url, **kwargs): + http_request.urls.append(url) + return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) + return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] + + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) + + try: + manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) + + self.assertEquals(manifest, manifest_xml) + self.assertEquals(len(http_request.urls), 2, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], manifest_url, "The first attempt should have been over the direct channel") + self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The retry should have been over the host channel") + self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host should have been set as the default channel") + finally: + HostPluginProtocol.set_default_channel(False) # Reset default channel + + def test_fetch_manifest_should_retry_the_host_channel_after_refreshing_the_host_plugin_and_set_the_host_as_default(self): + manifest_url = 'https://fake_host/fake_manifest.xml' + manifest_xml = '' - self.assertEquals(HostPluginProtocol.is_default_channel(), True) + with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + protocol.client.get_host_plugin() # force initialization of the host plugin + + original_http_request = restutil.http_request + + def http_request(method, url, *args, **kwargs): + if method == 'GET': + if url == manifest_url: + http_request.urls.append(url) + raise HttpError("Exception to fake an error on the direct channel") + elif TestWireClient._is_extension_artifact_host_request(manifest_url, url, **kwargs): + http_request.urls.append(url) + # fake a stale goal state then succeed once the goal state has been refreshed + if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.urls): + raise ResourceGoneError("Exception to fake an error on the host channel") + else: + return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) + elif url.endswith('/machine/?comp=goalstate'): + http_request.urls.append(url) + return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] + + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) + + try: + manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) + + self.assertEquals(manifest, manifest_xml) + self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], manifest_url, "The first attempt should have been over the direct channel") + self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") + self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") + self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") + self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host should have been set as the default channel") + finally: + HostPluginProtocol.set_default_channel(False) # Reset default channel @patch("azurelinuxagent.common.protocol.wire.WireClient.get_goal_state") @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") From 9b8370159ec6f53cb57e0d64bd0bf16e7e5df500 Mon Sep 17 00:00:00 2001 From: nam Date: Sat, 7 Mar 2020 15:57:15 -0800 Subject: [PATCH 2/5] Review feedback --- tests/protocol/test_wire.py | 83 ++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/tests/protocol/test_wire.py b/tests/protocol/test_wire.py index 7aef0b9675..5da5084c8f 100644 --- a/tests/protocol/test_wire.py +++ b/tests/protocol/test_wire.py @@ -211,8 +211,6 @@ def test_get_host_ga_plugin(self, *args): def test_upload_status_blob_should_use_the_host_channel_by_default(self, *_): with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - protocol.client.get_host_plugin() # force initialization of the host plugin - original_http_request = restutil.http_request def http_request(method, url, *args, **kwargs): @@ -538,16 +536,16 @@ def http_request(method, url, *args, **kwargs): return original_http_request(method, url, *args, **kwargs) http_request.urls = [] - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): - HostPluginProtocol.set_default_channel(False) + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) - success = protocol.download_ext_handler_pkg(extension_url, target_file) + success = protocol.download_ext_handler_pkg(extension_url, target_file) - self.assertEquals(success, True, 'The download should have succeeded') - self.assertEquals(len(http_request.urls), 1, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], extension_url, "The extension should have been downloaded over the direct channel") - self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') - self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The host channel should not have been set as the default") + self.assertEquals(success, True, 'The download should have succeeded') + self.assertEquals(len(http_request.urls), 1, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], extension_url, "The extension should have been downloaded over the direct channel") + self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') + self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The host channel should not have been set as the default") def test_download_ext_handler_pkg_should_use_host_channel_when_direct_channel_fails_and_set_host_as_default(self): extension_url = 'https://fake_host/fake_extension.zip' @@ -567,24 +565,26 @@ def http_request(method, url, *args, **kwargs): return original_http_request(method, url, *args, **kwargs) http_request.urls = [] - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): - HostPluginProtocol.set_default_channel(False) + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) - success = protocol.download_ext_handler_pkg(extension_url, target_file) + success = protocol.download_ext_handler_pkg(extension_url, target_file) - self.assertEquals(success, True, 'The download should have succeeded') - self.assertEquals(len(http_request.urls), 2, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") - self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The retry attempt should have been over the host channel") - self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') - self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") + self.assertEquals(success, True, 'The download should have succeeded') + self.assertEquals(len(http_request.urls), 2, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The retry attempt should have been over the host channel") + self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') + self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") def test_download_ext_handler_pkg_should_retry_the_host_channel_after_refreshing_host_plugin(self): extension_url = 'https://fake_host/fake_extension.zip' target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - protocol.client.get_host_plugin() # force initialization of the host plugin + # initialization of the host plugin triggers a request for the goal state; do it here so that this request does not + # confuse the mock below. + protocol.client.get_host_plugin() original_http_request = restutil.http_request @@ -597,7 +597,7 @@ def http_request(method, url, *args, **kwargs): http_request.urls.append(url) # fake a stale goal state then succeed once the goal state has been refreshed if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.urls): - raise ResourceGoneError("Exception to fake an error on the host channel") + raise ResourceGoneError("Exception to fake a stale goal") else: return MockResponse(body=b'', status_code=200) elif url.endswith('/machine/?comp=goalstate'): @@ -605,25 +605,27 @@ def http_request(method, url, *args, **kwargs): return original_http_request(method, url, *args, **kwargs) http_request.urls = [] - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): - HostPluginProtocol.set_default_channel(False) + with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + HostPluginProtocol.set_default_channel(False) - success = protocol.download_ext_handler_pkg(extension_url, target_file) + success = protocol.download_ext_handler_pkg(extension_url, target_file) - self.assertEquals(success, True, 'The download should have succeeded') - self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") - self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") - self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") - self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") - self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') - self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") + self.assertEquals(success, True, 'The download should have succeeded') + self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") + self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") + self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") + self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') + self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") def test_download_ext_handler_pkg_should_not_change_default_channel_when_all_channels_fail(self): extension_url = 'https://fake_host/fake_extension.zip' with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - protocol.client.get_host_plugin() # force initialization of the host plugin + # initialization of the host plugin triggers a request for the goal state; do it here so that this request does not + # confuse the mock below. + protocol.client.get_host_plugin() original_http_request = restutil.http_request @@ -663,18 +665,21 @@ def test_fetch_manifest_should_not_invoke_host_channel_when_direct_channel_succe def http_request(method, url, *args, **kwargs): if method == 'GET': if url == manifest_url: + http_request.urls.append(url) return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) elif url.endswith('/extensionArtifact'): self.fail('The Host GA Plugin should not have been invoked') return original_http_request(method, url, *args, **kwargs) + http_request.urls = [] with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) as mock_request: HostPluginProtocol.set_default_channel(False) manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) - self.assertEquals(manifest, manifest_xml) - self.assertTrue(any(args[0][0] == 'GET' and args[0][1] == manifest_url for args in mock_request.call_args_list)) # direct channel + self.assertEquals(manifest, manifest_xml, 'The expected manifest was not downloaded') + self.assertEquals(len(http_request.urls), 1, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) + self.assertEquals(http_request.urls[0], manifest_url, "The manifest should have been downloaded over the direct channel") self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The default channel should not have changed") def test_fetch_manifest_should_use_host_channel_when_direct_channel_fails_and_set_it_to_default(self): @@ -701,7 +706,7 @@ def http_request(method, url, *args, **kwargs): try: manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) - self.assertEquals(manifest, manifest_xml) + self.assertEquals(manifest, manifest_xml, 'The expected manifest was not downloaded') self.assertEquals(len(http_request.urls), 2, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) self.assertEquals(http_request.urls[0], manifest_url, "The first attempt should have been over the direct channel") self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The retry should have been over the host channel") @@ -714,7 +719,9 @@ def test_fetch_manifest_should_retry_the_host_channel_after_refreshing_the_host_ manifest_xml = '' with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - protocol.client.get_host_plugin() # force initialization of the host plugin + # initialization of the host plugin triggers a request for the goal state; do it here so that this request does not + # confuse the mock below. + protocol.client.get_host_plugin() original_http_request = restutil.http_request @@ -727,7 +734,7 @@ def http_request(method, url, *args, **kwargs): http_request.urls.append(url) # fake a stale goal state then succeed once the goal state has been refreshed if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.urls): - raise ResourceGoneError("Exception to fake an error on the host channel") + raise ResourceGoneError("Exception to fake a stale goal state") else: return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) elif url.endswith('/machine/?comp=goalstate'): From fbd87d9fa6cec4c3eba78388ceecfe21e46eb4f7 Mon Sep 17 00:00:00 2001 From: nam Date: Mon, 9 Mar 2020 07:42:42 -0700 Subject: [PATCH 3/5] Created mock_http_request --- tests/common/test_event.py | 5 +- tests/ga/test_monitor.py | 28 +- tests/ga/test_remoteaccess.py | 7 +- .../{mock_wire_protocol.py => mocks.py} | 77 ++++- tests/protocol/test_hostplugin.py | 16 +- tests/protocol/test_wire.py | 303 ++++++++---------- tests/utils/event_logger_tools.py | 5 +- 7 files changed, 224 insertions(+), 217 deletions(-) rename tests/protocol/{mock_wire_protocol.py => mocks.py} (52%) diff --git a/tests/common/test_event.py b/tests/common/test_event.py index d092b31d81..f438d65c2a 100644 --- a/tests/common/test_event.py +++ b/tests/common/test_event.py @@ -28,7 +28,8 @@ WALAEventOperation, parse_xml_event, parse_json_event, AGENT_EVENT_FILE_EXTENSION, EVENTS_DIRECTORY from azurelinuxagent.common.future import ustr from azurelinuxagent.common.protocol.goal_state import GoalState -from tests.protocol import mockwiredata, mock_wire_protocol +from tests.protocol import mockwiredata +from tests.protocol.mocks import mock_wire_protocol from azurelinuxagent.common.version import CURRENT_AGENT, CURRENT_VERSION, AGENT_EXECUTION_MODE from azurelinuxagent.common.osutil import get_osutil from tests.tools import AgentTestCase, data_dir, load_data, Mock, patch, skip_if_predicate_true @@ -98,7 +99,7 @@ def create_event_and_return_container_id(): self.fail("Could not find Contained ID on event") - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: contained_id = create_event_and_return_container_id() # The expect value comes from DATA_FILE self.assertEquals(contained_id, 'c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2', "Incorrect container ID") diff --git a/tests/ga/test_monitor.py b/tests/ga/test_monitor.py index 53eae306f6..51b8c19421 100644 --- a/tests/ga/test_monitor.py +++ b/tests/ga/test_monitor.py @@ -29,28 +29,20 @@ from azurelinuxagent.common import event, logger from azurelinuxagent.common.cgroup import CGroup, CpuCgroup, MemoryCgroup -from azurelinuxagent.common.cgroupconfigurator import CGroupConfigurator from azurelinuxagent.common.cgroupstelemetry import CGroupsTelemetry, MetricValue from azurelinuxagent.common.datacontract import get_properties -from azurelinuxagent.common.event import EventLogger, WALAEventOperation, EVENTS_DIRECTORY +from azurelinuxagent.common.event import WALAEventOperation, EVENTS_DIRECTORY from azurelinuxagent.common.exception import HttpError -from azurelinuxagent.common.future import ustr from azurelinuxagent.common.logger import Logger from azurelinuxagent.common.osutil import get_osutil -from azurelinuxagent.common.osutil.default import BASE_CGROUPS, DefaultOSUtil -from azurelinuxagent.common.protocol.wire import ExtHandler, ExtHandlerProperties from azurelinuxagent.common.protocol.wire import WireProtocol from azurelinuxagent.common.telemetryevent import TelemetryEvent, TelemetryEventParam from azurelinuxagent.common.utils import fileutil, restutil from azurelinuxagent.common.version import AGENT_VERSION, CURRENT_VERSION, CURRENT_AGENT, DISTRO_NAME, DISTRO_VERSION, DISTRO_CODE_NAME -from azurelinuxagent.ga.exthandlers import ExtHandlerInstance -from azurelinuxagent.ga.monitor import generate_extension_metrics_telemetry_dictionary, get_monitor_handler, \ - MonitorHandler -from tests.common.test_cgroupstelemetry import make_new_cgroup +from azurelinuxagent.ga.monitor import generate_extension_metrics_telemetry_dictionary, get_monitor_handler, MonitorHandler from tests.protocol.mockwiredata import DATA_FILE -from tests.protocol import mock_wire_protocol -from tests.tools import Mock, MagicMock, patch, AgentTestCase, data_dir, are_cgroups_enabled, i_am_root, \ - skip_if_predicate_false, is_trusty_in_travis, skip_if_predicate_true, clear_singleton_instances, PropertyMock +from tests.protocol.mocks import mock_wire_protocol +from tests.tools import Mock, MagicMock, patch, AgentTestCase, clear_singleton_instances, PropertyMock from tests.utils.event_logger_tools import EventLoggerTools @@ -339,7 +331,7 @@ def _get_event_data(duration, is_success, message, name, op, version, eventId=1) def test_collect_and_send_events(self, mock_lib_dir, patch_send_event, *_): mock_lib_dir.return_value = self.lib_dir - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: monitor_handler = TestEventMonitoring._create_monitor_handler(protocol) self._create_extension_event(message="Message-Test") @@ -406,7 +398,7 @@ def test_collect_and_send_events(self, mock_lib_dir, patch_send_event, *_): def test_collect_and_send_events_with_small_events(self, mock_lib_dir, patch_send_event, *_): mock_lib_dir.return_value = self.lib_dir - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: monitor_handler = TestEventMonitoring._create_monitor_handler(protocol) sizes = [15, 15, 15, 15] # get the powers of 2 - 2**16 is the limit @@ -425,7 +417,7 @@ def test_collect_and_send_events_with_small_events(self, mock_lib_dir, patch_sen def test_collect_and_send_events_with_large_events(self, mock_lib_dir, patch_send_event, *_): mock_lib_dir.return_value = self.lib_dir - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: monitor_handler = TestEventMonitoring._create_monitor_handler(protocol) sizes = [17, 17, 17] # get the powers of 2 @@ -446,7 +438,7 @@ def test_collect_and_send_with_http_post_returning_503(self, mock_lib_dir, *_): mock_lib_dir.return_value = self.lib_dir fileutil.mkdir(self.event_dir) - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: monitor_handler = TestEventMonitoring._create_monitor_handler(protocol) sizes = [1, 2, 3] # get the powers of 2, and multiple by 1024. @@ -472,7 +464,7 @@ def test_collect_and_send_with_send_event_generating_exception(self, mock_lib_di mock_lib_dir.return_value = self.lib_dir fileutil.mkdir(self.event_dir) - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: monitor_handler = TestEventMonitoring._create_monitor_handler(protocol) sizes = [1, 2, 3] # get the powers of 2, and multiple by 1024. @@ -496,7 +488,7 @@ def test_collect_and_send_with_call_wireserver_returns_http_error(self, mock_lib mock_lib_dir.return_value = self.lib_dir fileutil.mkdir(self.event_dir) - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: monitor_handler = TestEventMonitoring._create_monitor_handler(protocol) sizes = [1, 2, 3] # get the powers of 2, and multiple by 1024. diff --git a/tests/ga/test_remoteaccess.py b/tests/ga/test_remoteaccess.py index 7d8544efa5..0cc394fa5e 100644 --- a/tests/ga/test_remoteaccess.py +++ b/tests/ga/test_remoteaccess.py @@ -18,7 +18,8 @@ from azurelinuxagent.common.protocol.goal_state import GoalState, RemoteAccess from tests.tools import AgentTestCase, load_data, patch, Mock -from tests.protocol import mockwiredata, mock_wire_protocol +from tests.protocol import mockwiredata +from tests.protocol.mocks import mock_wire_protocol class TestRemoteAccess(AgentTestCase): @@ -33,7 +34,7 @@ def test_parse_remote_access(self): self.assertEquals("2019-01-01", remote_access.user_list.users[0].expiration, "Expiration does not match.") def test_goal_state_with_no_remote_access(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: self.assertIsNone(protocol.client.get_remote_access()) def test_parse_two_remote_access_accounts(self): @@ -74,7 +75,7 @@ def test_parse_zero_remote_access_accounts(self): self.assertEquals(0, len(remote_access.user_list.users), "User count does not match.") def test_update_remote_access_conf_remote_access(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE_REMOTE_ACCESS) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE_REMOTE_ACCESS) as protocol: self.assertIsNotNone(protocol.client.get_remote_access()) self.assertEquals(1, len(protocol.client.get_remote_access().user_list.users)) self.assertEquals('testAccount', protocol.client.get_remote_access().user_list.users[0].name) diff --git a/tests/protocol/mock_wire_protocol.py b/tests/protocol/mocks.py similarity index 52% rename from tests/protocol/mock_wire_protocol.py rename to tests/protocol/mocks.py index d2657493af..51e8feb701 100644 --- a/tests/protocol/mock_wire_protocol.py +++ b/tests/protocol/mocks.py @@ -14,16 +14,16 @@ # # Requires Python 2.6+ and Openssl 1.0+ # - import contextlib import re from azurelinuxagent.common.protocol.wire import WireProtocol -from azurelinuxagent.common.utils.restutil import KNOWN_WIRESERVER_IP, http_request, DEFAULT_RETRIES, RETRY_CODES, DELAY_IN_SECONDS +from azurelinuxagent.common.utils import restutil from tests.tools import patch from tests.protocol.mockwiredata import WireProtocolData + @contextlib.contextmanager -def create(mock_wire_data_file): +def mock_wire_protocol(mock_wire_data_file): """ Creates a mock WireProtocol object that will return the data specified by 'mock_wire_data_file' (which must follow the structure of the data files defined in tests/protocol/mockwiredata.py). @@ -55,7 +55,7 @@ def stop(): stop_mock_crypt_util() stop_mock_http_request() - protocol = WireProtocol(KNOWN_WIRESERVER_IP) + protocol = WireProtocol(restutil.KNOWN_WIRESERVER_IP) protocol.mock_wire_data = WireProtocolData(mock_wire_data_file) protocol.stop_mock_http_request = stop_mock_http_request protocol.stop_mock_crypt_util = stop_mock_crypt_util @@ -64,24 +64,24 @@ def stop(): try: # To minimize the impact of mocking restutil.http_request we only use the mock data for requests # to the wireserver or requests starting with "mock-goal-state" - mock_data_re = re.compile(r'https?://(mock-goal-state|{0}).*'.format(KNOWN_WIRESERVER_IP.replace(r'.', r'\.')), re.IGNORECASE) + mock_data_re = re.compile(r'https?://(mock-goal-state|{0}).*'.format(restutil.KNOWN_WIRESERVER_IP.replace(r'.', r'\.')), re.IGNORECASE) - original_http_request = http_request + original_http_request = restutil.http_request - def mock_http_request(method, url, data, headers=None, use_proxy=False, max_retry=DEFAULT_RETRIES, retry_codes=RETRY_CODES, retry_delay=DELAY_IN_SECONDS): + def http_request(method, url, data, headers=None, use_proxy=False, max_retry=restutil.DEFAULT_RETRIES, retry_codes=restutil.RETRY_CODES, retry_delay=restutil.DELAY_IN_SECONDS): if method == 'GET' and mock_data_re.match(url) is not None: return protocol.mock_wire_data.mock_http_get(url, headers, use_proxy, max_retry, retry_codes, retry_delay) elif method == 'POST': return protocol.mock_wire_data.mock_http_post(url, data, headers, use_proxy, max_retry, retry_codes, retry_delay) return original_http_request(method, url, data, headers, use_proxy, max_retry, retry_codes, retry_delay) - p = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=mock_http_request) - p.start() - stop_mock_http_request.mock = p + patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) + patched.start() + stop_mock_http_request.mock = patched - p = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util) - p.start() - stop_mock_crypt_util.mock = p + patched = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util) + patched.start() + stop_mock_crypt_util.mock = patched protocol.detect() @@ -89,3 +89,54 @@ def mock_http_request(method, url, data, headers=None, use_proxy=False, max_retr finally: protocol.stop() + + +@contextlib.contextmanager +def mock_http_request(http_get_handler=None, http_post_handler=None, http_put_handler=None): + """ + Creates a Mock of restutil.http_request that executes the handler given for the corresponding HTTP method. + + The return value of the handler function is interpreted similarly to the "return_value" argument of patch(): if it + is an exception the exception is raised or, if it is any object other than None, the value is returned by the mock. + + If the handler function returns None the call is passed to the original restutil.http_request. + + The patch maintains a list of "tracked" urls. When the handler function returns a value than is not None the url + for the request is automatically added to the tracked list. The handler function can add other items to this list + using the track_url() method on the mock. + + The returned Mock is augmented with these 2 methods: + + * track_url(url) - adds the given item to the list of tracked urls. + * get_tracked_urls() - returns the list of tracked urls. + """ + tracked_urls = [] + original_http_request = restutil.http_request + + def http_request(method, url, *args, **kwargs): + handler = None + if method == 'GET': + handler = http_get_handler + elif method == 'POST': + handler = http_post_handler + elif method == 'PUT': + handler = http_put_handler + + if handler is not None: + return_value = handler(url, *args, **kwargs) + if return_value is not None: + tracked_urls.append(url) + if isinstance(return_value, Exception): + raise return_value + return return_value + + return original_http_request(method, url, *args, **kwargs) + + patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) + patched.track_url = lambda url: tracked_urls.append(url) + patched.get_tracked_urls = lambda: tracked_urls + patched.start() + try: + yield patched + finally: + patched.stop() diff --git a/tests/protocol/test_hostplugin.py b/tests/protocol/test_hostplugin.py index 496ea91aa1..5dffee01a9 100644 --- a/tests/protocol/test_hostplugin.py +++ b/tests/protocol/test_hostplugin.py @@ -31,8 +31,8 @@ from azurelinuxagent.common.future import ustr from azurelinuxagent.common.protocol.hostplugin import API_VERSION from azurelinuxagent.common.utils import restutil -from tests.protocol import mock_wire_protocol -from tests.protocol.mockwiredata import WireProtocolData, DATA_FILE, DATA_FILE_NO_EXT +from tests.protocol.mocks import mock_wire_protocol +from tests.protocol.mockwiredata import DATA_FILE, DATA_FILE_NO_EXT from tests.protocol.test_wire import MockResponse as TestWireMockResponse from tests.tools import AgentTestCase, PY_VERSION_MAJOR, Mock, patch @@ -65,7 +65,7 @@ class TestHostPlugin(AgentTestCase): def _init_host(self): - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: test_goal_state = protocol.client.get_goal_state() host_plugin = wire.HostPluginProtocol(wireserver_url, test_goal_state.container_id, @@ -155,7 +155,7 @@ def _validate_hostplugin_args(self, args, goal_state, exp_method, exp_url, exp_d @staticmethod @contextlib.contextmanager def create_mock_protocol(): - with mock_wire_protocol.create(DATA_FILE_NO_EXT) as protocol: + with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol: # These tests use mock wire data that dont have any extensions (extension config will be empty). # Populate the upload blob and set an initial empty status before returning the protocol. ext_conf = protocol.client._goal_state.ext_conf @@ -401,7 +401,7 @@ def test_put_status_error_reporting(self, patch_add_event): def test_validate_http_request(self): """Validate correct set of data is sent to HostGAPlugin when reporting VM status""" - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: test_goal_state = protocol.client._goal_state plugin = protocol.client.get_host_plugin() @@ -435,7 +435,7 @@ def test_validate_http_request(self): self.assertEqual(health_service_url, patch_http.call_args_list[1][0][1]) def test_validate_block_blob(self): - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: test_goal_state = protocol.client._goal_state host_client = wire.HostPluginProtocol(wireserver_url, @@ -478,7 +478,7 @@ def test_validate_block_blob(self): def test_validate_page_blobs(self): """Validate correct set of data is sent for page blobs""" - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: test_goal_state = protocol.client._goal_state host_client = wire.HostPluginProtocol(wireserver_url, @@ -537,7 +537,7 @@ def test_validate_page_blobs(self): exp_method, exp_url, exp_data) def test_validate_get_extension_artifacts(self): - with mock_wire_protocol.create(DATA_FILE) as protocol: + with mock_wire_protocol(DATA_FILE) as protocol: test_goal_state = protocol.client._goal_state expected_url = hostplugin.URI_FORMAT_GET_EXTENSION_ARTIFACT.format(wireserver_url, hostplugin.HOST_PLUGIN_PORT) diff --git a/tests/protocol/test_wire.py b/tests/protocol/test_wire.py index 5da5084c8f..8fbad1a970 100644 --- a/tests/protocol/test_wire.py +++ b/tests/protocol/test_wire.py @@ -34,7 +34,8 @@ from azurelinuxagent.common.utils import restutil from azurelinuxagent.common.version import CURRENT_VERSION, DISTRO_NAME, DISTRO_VERSION from tests.ga.test_monitor import random_generator -from tests.protocol import mockwiredata, mock_wire_protocol +from tests.protocol import mockwiredata +from tests.protocol.mocks import mock_wire_protocol, mock_http_request from tests.protocol.mockwiredata import DATA_FILE_NO_EXT from tests.protocol.mockwiredata import WireProtocolData from tests.tools import ANY, MagicMock, Mock, patch, AgentTestCase, skip_if_predicate_true @@ -61,7 +62,7 @@ def get_event(message, duration=30000, evt_type="", is_internal=False, is_succes @contextlib.contextmanager def create_mock_protocol(artifacts_profile_blob=None, status_upload_blob=None, status_upload_blob_type=None): - with mock_wire_protocol.create(DATA_FILE_NO_EXT) as protocol: + with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol: # These tests use mock wire data that dont have any extensions (extension config will be empty). # Populate the upload blob and artifacts profile blob. ext_conf = ExtensionsConfig(None) @@ -195,7 +196,7 @@ def test_call_storage_kwargs(self, *args): self.assertTrue(c == (True if i != 3 else False)) def test_status_blob_parsing(self, *args): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: self.assertEqual(protocol.client.get_ext_conf().status_upload_blob, 'https://test.blob.core.windows.net/vhds/test-cs12.test-cs12.test-cs12.status?' 'sr=b&sp=rw&se=9999-01-01&sk=key1&sv=2014-02-14&' @@ -203,32 +204,27 @@ def test_status_blob_parsing(self, *args): self.assertEqual(protocol.client.get_ext_conf().status_upload_blob_type, u'BlockBlob') def test_get_host_ga_plugin(self, *args): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: host_plugin = protocol.client.get_host_plugin() goal_state = protocol.client.get_goal_state() self.assertEqual(goal_state.container_id, host_plugin.container_id) self.assertEqual(goal_state.role_config_name, host_plugin.role_config_name) def test_upload_status_blob_should_use_the_host_channel_by_default(self, *_): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'PUT': - if protocol.get_endpoint() in url and url.endswith('/status'): - http_request.urls.append(url) - return MockResponse(body=b'', status_code=200) - self.fail('The upload status request was sent to the wrong uri: {0}'.format(uri)) - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] - - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) as mock_request: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: + def handler(url, *_, **__): + if protocol.get_endpoint() in url and url.endswith('/status'): + return MockResponse(body=b'', status_code=200) + self.fail('The upload status request was sent to the wrong url: {0}'.format(url)) + + with mock_http_request(http_put_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) protocol.client.status_blob.vm_status = VMStatus(message="Ready", status="Ready") protocol.client.upload_status_blob() - self.assertEqual(len(http_request.urls), 1, 'Expected one upload request to the host: [{0}]'.format(http_request.urls)) + urls = http_request.get_tracked_urls() + self.assertEqual(len(urls), 1, 'Expected one post request to the host: [{0}]'.format(urls)) def test_upload_status_blob_host_ga_plugin(self, *_): with create_mock_protocol(status_upload_blob=testurl, status_upload_blob_type=testtype) as protocol: @@ -266,7 +262,7 @@ def test_upload_status_blob_reports_prepare_error(self, *_): def test_get_in_vm_artifacts_profile_blob_not_available(self, *_): # Test when artifacts_profile_blob is null/None - with mock_wire_protocol.create(DATA_FILE_NO_EXT) as protocol: + with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol: protocol.client._goal_state.ext_conf = ExtensionsConfig(None) self.assertEqual(None, protocol.client.get_artifacts_profile()) @@ -491,7 +487,7 @@ def _is_extension_artifact_host_request(expected_url, actual_url, **kwargs): return actual_url.endswith('/extensionArtifact') and kwargs['headers']['x-ms-artifact-location'] == expected_url def test_get_ext_conf_without_uri(self, *args): - with mock_wire_protocol.create(mockwiredata.DATA_FILE_NO_EXT) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE_NO_EXT) as protocol: ext_conf = protocol.client.get_ext_conf() self.assertEqual(0, len(ext_conf.ext_handlers.extHandlers)) @@ -523,27 +519,23 @@ def test_download_ext_handler_pkg_should_not_invoke_host_channel_when_direct_cha extension_url = 'https://fake_host/fake_extension.zip' target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'GET': - if url == extension_url: - http_request.urls.append(url) - return MockResponse(body=b'', status_code=200) - elif TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): - self.fail('The host channel should not have been used') - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: + def handler(url, *_, **kwargs): + if url == extension_url: + return MockResponse(body=b'', status_code=200) + if TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + self.fail('The host channel should not have been used') + return None - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + with mock_http_request(http_get_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) success = protocol.download_ext_handler_pkg(extension_url, target_file) + urls = http_request.get_tracked_urls() self.assertEquals(success, True, 'The download should have succeeded') - self.assertEquals(len(http_request.urls), 1, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], extension_url, "The extension should have been downloaded over the direct channel") + self.assertEquals(len(urls), 1, "Unexpected number of HTTP requests: [{0}]".format(urls)) + self.assertEquals(urls[0], extension_url, "The extension should have been downloaded over the direct channel") self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The host channel should not have been set as the default") @@ -551,29 +543,24 @@ def test_download_ext_handler_pkg_should_use_host_channel_when_direct_channel_fa extension_url = 'https://fake_host/fake_extension.zip' target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'GET': - if url == extension_url: - http_request.urls.append(url) - raise HttpError("Exception to fake an error on the direct channel") - elif TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): - http_request.urls.append(url) - return MockResponse(body=b'', status_code=200) - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] - - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: + def handler(url, *_, **kwargs): + if url == extension_url: + return HttpError("Exception to fake an error on the direct channel") + if TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + return MockResponse(body=b'', status_code=200) + return None + + with mock_http_request(http_get_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) success = protocol.download_ext_handler_pkg(extension_url, target_file) + urls = http_request.get_tracked_urls() self.assertEquals(success, True, 'The download should have succeeded') - self.assertEquals(len(http_request.urls), 2, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") - self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The retry attempt should have been over the host channel") + self.assertEquals(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls)) + self.assertEquals(urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(urls[1].endswith('/extensionArtifact'), "The retry attempt should have been over the host channel") self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") @@ -581,135 +568,115 @@ def test_download_ext_handler_pkg_should_retry_the_host_channel_after_refreshing extension_url = 'https://fake_host/fake_extension.zip' target_file = os.path.join(self.tmp_dir, 'fake_extension.zip') - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: # initialization of the host plugin triggers a request for the goal state; do it here so that this request does not - # confuse the mock below. + # confuse the mock below, which needs to track those requests. protocol.client.get_host_plugin() - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'GET': - if url == extension_url: - http_request.urls.append(url) - raise HttpError("Exception to fake an error on the direct channel") - elif TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): - http_request.urls.append(url) - # fake a stale goal state then succeed once the goal state has been refreshed - if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.urls): - raise ResourceGoneError("Exception to fake a stale goal") - else: - return MockResponse(body=b'', status_code=200) - elif url.endswith('/machine/?comp=goalstate'): - http_request.urls.append(url) - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] - - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + def handler(url, *args, **kwargs): + if url == extension_url: + return HttpError("Exception to fake an error on the direct channel") + if TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + # fake a stale goal state then succeed once the goal state has been refreshed + if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.get_tracked_urls()): + return ResourceGoneError("Exception to fake a stale goal") + return MockResponse(body=b'', status_code=200) + if url.endswith('/machine/?comp=goalstate'): + http_request.track_url(url) + return None + + with mock_http_request(http_get_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) success = protocol.download_ext_handler_pkg(extension_url, target_file) + urls = http_request.get_tracked_urls() self.assertEquals(success, True, 'The download should have succeeded') - self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") - self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") - self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") - self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") + self.assertEquals(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls)) + self.assertEquals(urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") + self.assertTrue(urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") + self.assertTrue(urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") self.assertTrue(os.path.exists(target_file), 'The extension package was not downloaded') self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host channel should have been set as the default") def test_download_ext_handler_pkg_should_not_change_default_channel_when_all_channels_fail(self): extension_url = 'https://fake_host/fake_extension.zip' - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: # initialization of the host plugin triggers a request for the goal state; do it here so that this request does not - # confuse the mock below. + # confuse the mock below, which needs to track those requests. protocol.client.get_host_plugin() - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'GET': - if url == extension_url: - http_request.urls.append(url) - raise HttpError("Exception to fake error on direct channel") - if TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): - http_request.urls.append(url) - raise ResourceGoneError("Exception to fake error on host channel") - elif url.endswith('/machine/?comp=goalstate'): - http_request.urls.append(url) - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] - - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + def handler(url, *args, **kwargs): + if url == extension_url: + return HttpError("Exception to fake error on direct channel") + if TestWireClient._is_extension_artifact_host_request(extension_url, url, **kwargs): + return ResourceGoneError("Exception to fake error on host channel") + if url.endswith('/machine/?comp=goalstate'): + http_request.track_url(url) # keep track of goal state requests + return None + + with mock_http_request(http_get_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) success = protocol.download_ext_handler_pkg(extension_url, "/an-invalid-directory/an-invalid-file.zip") + urls = http_request.get_tracked_urls() self.assertEquals(success, False, "The download should have failed") - self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], extension_url, "The first attempt should have been over the direct channel") - self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") - self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") - self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") + self.assertEquals(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls)) + self.assertEquals(urls[0], extension_url, "The first attempt should have been over the direct channel") + self.assertTrue(urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") + self.assertTrue(urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") + self.assertTrue(urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The host channel should not have been set as the default") def test_fetch_manifest_should_not_invoke_host_channel_when_direct_channel_succeeds(self): manifest_url = 'https://fake_host/fake_manifest.xml' manifest_xml = '' - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'GET': - if url == manifest_url: - http_request.urls.append(url) - return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) - elif url.endswith('/extensionArtifact'): - self.fail('The Host GA Plugin should not have been invoked') - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: + def handler(url, *_, **__): + if url == manifest_url: + return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) + if url.endswith('/extensionArtifact'): + self.fail('The Host GA Plugin should not have been invoked') + return None - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) as mock_request: + with mock_http_request(http_get_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) + urls = http_request.get_tracked_urls() self.assertEquals(manifest, manifest_xml, 'The expected manifest was not downloaded') - self.assertEquals(len(http_request.urls), 1, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], manifest_url, "The manifest should have been downloaded over the direct channel") + self.assertEquals(len(urls), 1, "Unexpected number of HTTP requests: [{0}]".format(urls)) + self.assertEquals(urls[0], manifest_url, "The manifest should have been downloaded over the direct channel") self.assertEquals(HostPluginProtocol.is_default_channel(), False, "The default channel should not have changed") def test_fetch_manifest_should_use_host_channel_when_direct_channel_fails_and_set_it_to_default(self): manifest_url = 'https://fake_host/fake_manifest.xml' manifest_xml = '' - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'GET': - if url == manifest_url: - http_request.urls.append(url) - raise ResourceGoneError("Exception to fake an error on the direct channel") - elif TestWireClient._is_extension_artifact_host_request(manifest_url, url, **kwargs): - http_request.urls.append(url) - return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] - - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: + def handler(url, *args, **kwargs): + if url == manifest_url: + return ResourceGoneError("Exception to fake an error on the direct channel") + if TestWireClient._is_extension_artifact_host_request(manifest_url, url, **kwargs): + return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) + return None + + with mock_http_request(http_get_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) try: manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) + urls = http_request.get_tracked_urls() self.assertEquals(manifest, manifest_xml, 'The expected manifest was not downloaded') - self.assertEquals(len(http_request.urls), 2, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], manifest_url, "The first attempt should have been over the direct channel") - self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The retry should have been over the host channel") + self.assertEquals(len(urls), 2, "Unexpected number of HTTP requests: [{0}]".format(urls)) + self.assertEquals(urls[0], manifest_url, "The first attempt should have been over the direct channel") + self.assertTrue(urls[1].endswith('/extensionArtifact'), "The retry should have been over the host channel") self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host should have been set as the default channel") finally: HostPluginProtocol.set_default_channel(False) # Reset default channel @@ -718,42 +685,36 @@ def test_fetch_manifest_should_retry_the_host_channel_after_refreshing_the_host_ manifest_url = 'https://fake_host/fake_manifest.xml' manifest_xml = '' - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: # initialization of the host plugin triggers a request for the goal state; do it here so that this request does not - # confuse the mock below. + # confuse the mock below, which needs to track those requests. protocol.client.get_host_plugin() - original_http_request = restutil.http_request - - def http_request(method, url, *args, **kwargs): - if method == 'GET': - if url == manifest_url: - http_request.urls.append(url) - raise HttpError("Exception to fake an error on the direct channel") - elif TestWireClient._is_extension_artifact_host_request(manifest_url, url, **kwargs): - http_request.urls.append(url) - # fake a stale goal state then succeed once the goal state has been refreshed - if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.urls): - raise ResourceGoneError("Exception to fake a stale goal state") - else: - return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) - elif url.endswith('/machine/?comp=goalstate'): - http_request.urls.append(url) - return original_http_request(method, url, *args, **kwargs) - http_request.urls = [] - - with patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request): + def handler(url, *_, **kwargs): + if url == manifest_url: + return HttpError("Exception to fake an error on the direct channel") + if TestWireClient._is_extension_artifact_host_request(manifest_url, url, **kwargs): + # fake a stale goal state then succeed once the goal state has been refreshed + if not any(url.endswith('/machine/?comp=goalstate') for url in http_request.get_tracked_urls()): + return ResourceGoneError("Exception to fake a stale goal state") + return MockResponse(body=manifest_xml.encode('utf-8'), status_code=200) + elif url.endswith('/machine/?comp=goalstate'): + http_request.track_url(url) # keep track of goal state requests + return None + + with mock_http_request(http_get_handler=handler) as http_request: HostPluginProtocol.set_default_channel(False) try: manifest = protocol.client.fetch_manifest([VMAgentManifestUri(uri=manifest_url)]) + urls = http_request.get_tracked_urls() self.assertEquals(manifest, manifest_xml) - self.assertEquals(len(http_request.urls), 4, "Unexpected number of HTTP requests: [{0}]".format(http_request.urls)) - self.assertEquals(http_request.urls[0], manifest_url, "The first attempt should have been over the direct channel") - self.assertTrue(http_request.urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") - self.assertTrue(http_request.urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") - self.assertTrue(http_request.urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") + self.assertEquals(len(urls), 4, "Unexpected number of HTTP requests: [{0}]".format(urls)) + self.assertEquals(urls[0], manifest_url, "The first attempt should have been over the direct channel") + self.assertTrue(urls[1].endswith('/extensionArtifact'), "The second attempt should have been over the host channel") + self.assertTrue(urls[2].endswith('/machine/?comp=goalstate'), "The host channel should have been refreshed the goal state") + self.assertTrue(urls[3].endswith('/extensionArtifact'), "The third attempt should have been over the host channel") self.assertEquals(HostPluginProtocol.is_default_channel(), True, "The host should have been set as the default channel") finally: HostPluginProtocol.set_default_channel(False) # Reset default channel @@ -891,7 +852,7 @@ def test_get_artifacts_profile_should_refresh_the_host_plugin_and_not_change_def self.assertEquals(HostPluginProtocol.is_default_channel(), False) def test_send_request_using_appropriate_channel_should_not_invoke_host_channel_when_direct_channel_succeeds(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: protocol.client.get_host_plugin().set_default_channel(False) def direct_func(*args): @@ -912,7 +873,7 @@ def host_func(*args): self.assertEquals(0, host_func.counter) def test_send_request_using_appropriate_channel_should_not_use_direct_channel_when_host_channel_is_default(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: protocol.client.get_host_plugin().set_default_channel(True) def direct_func(*args): @@ -933,7 +894,7 @@ def host_func(*args): self.assertEquals(1, host_func.counter) def test_send_request_using_appropriate_channel_should_use_host_channel_when_direct_channel_fails(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: host = protocol.client.get_host_plugin() host.set_default_channel(False) @@ -957,7 +918,7 @@ def host_func(*args): self.assertEquals(True, host.is_default_channel()) def test_send_request_using_appropriate_channel_should_retry_the_host_channel_after_reloading_goal_state(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: protocol.client.get_host_plugin().set_default_channel(False) def direct_func(*args): @@ -991,7 +952,7 @@ class UpdateGoalStateTestCase(AgentTestCase): """ def test_it_should_update_the_goal_state_and_the_host_plugin_when_the_incarnation_changes(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: protocol.client.get_host_plugin() # if the incarnation changes the behavior is the same for forced and non-forced updates @@ -1048,7 +1009,7 @@ def test_it_should_update_the_goal_state_and_the_host_plugin_when_the_incarnatio self.assertEqual(protocol.client.get_host_plugin().role_config_name, new_role_config_name) def test_non_forced_update_should_not_update_the_goal_state_nor_the_host_plugin_when_the_incarnation_does_not_change(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: protocol.client.get_host_plugin() # The container id, role config name and shared config can change without the incarnation changing; capture the initial @@ -1072,7 +1033,7 @@ def test_non_forced_update_should_not_update_the_goal_state_nor_the_host_plugin_ self.assertEqual(protocol.client.get_host_plugin().role_config_name, role_config_name) def test_forced_update_should_update_the_goal_state_and_the_host_plugin_when_the_incarnation_does_not_change(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: protocol.client.get_host_plugin() # The container id, role config name and shared config can change without the incarnation changing @@ -1101,7 +1062,7 @@ class UpdateHostPluginFromGoalStateTestCase(AgentTestCase): """ def test_it_should_update_the_host_plugin_with_or_without_incarnation_changes(self): - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol: protocol.client.get_host_plugin() # the behavior should be the same whether the incarnation changes or not diff --git a/tests/utils/event_logger_tools.py b/tests/utils/event_logger_tools.py index 49aecdf348..626d71d9ef 100644 --- a/tests/utils/event_logger_tools.py +++ b/tests/utils/event_logger_tools.py @@ -20,7 +20,8 @@ import azurelinuxagent.common.event as event from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, DISTRO_CODE_NAME import tests.tools as tools -from tests.protocol import mockwiredata, mock_wire_protocol +from tests.protocol import mockwiredata +from tests.protocol.mocks import mock_wire_protocol class EventLoggerTools(object): @@ -52,7 +53,7 @@ def initialize_event_logger(event_dir): mock_imds_client = tools.Mock() mock_imds_client.get_compute = tools.Mock(return_value=mock_imds_info) - with mock_wire_protocol.create(mockwiredata.DATA_FILE) as mock_protocol: + with mock_wire_protocol(mockwiredata.DATA_FILE) as mock_protocol: with tools.patch("azurelinuxagent.common.event.get_imds_client", return_value=mock_imds_client): event.initialize_event_logger_vminfo_common_parameters(mock_protocol) From cb1a1e284fceb620b05fb01c4f53189322f0847e Mon Sep 17 00:00:00 2001 From: nam Date: Mon, 9 Mar 2020 07:50:15 -0700 Subject: [PATCH 4/5] Remove import --- tests/daemon/test_daemon.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/daemon/test_daemon.py b/tests/daemon/test_daemon.py index 9b4255910d..e17dd7f459 100644 --- a/tests/daemon/test_daemon.py +++ b/tests/daemon/test_daemon.py @@ -24,7 +24,6 @@ from azurelinuxagent.daemon import * from azurelinuxagent.daemon.main import OPENSSL_FIPS_ENVIRONMENT from azurelinuxagent.pa.provision.default import ProvisionHandler -from tests.protocol import mock_wire_protocol, mockwiredata from tests.tools import AgentTestCase, Mock, patch From f85571573fe767de4a4344b2186991fd9d1f02b8 Mon Sep 17 00:00:00 2001 From: nam Date: Mon, 9 Mar 2020 08:36:30 -0700 Subject: [PATCH 5/5] Use kwargs --- tests/protocol/mocks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/protocol/mocks.py b/tests/protocol/mocks.py index 51e8feb701..673683fb51 100644 --- a/tests/protocol/mocks.py +++ b/tests/protocol/mocks.py @@ -68,12 +68,12 @@ def stop(): original_http_request = restutil.http_request - def http_request(method, url, data, headers=None, use_proxy=False, max_retry=restutil.DEFAULT_RETRIES, retry_codes=restutil.RETRY_CODES, retry_delay=restutil.DELAY_IN_SECONDS): + def http_request(method, url, data, **kwargs): if method == 'GET' and mock_data_re.match(url) is not None: - return protocol.mock_wire_data.mock_http_get(url, headers, use_proxy, max_retry, retry_codes, retry_delay) + return protocol.mock_wire_data.mock_http_get(url, **kwargs) elif method == 'POST': - return protocol.mock_wire_data.mock_http_post(url, data, headers, use_proxy, max_retry, retry_codes, retry_delay) - return original_http_request(method, url, data, headers, use_proxy, max_retry, retry_codes, retry_delay) + return protocol.mock_wire_data.mock_http_post(url, data, **kwargs) + return original_http_request(method, url, data, **kwargs) patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request) patched.start()