# Patch summary (commentary only; everything from the first "diff --git" header onward is unchanged).
#
# Purpose: let ImdsClient fall back to the endpoint returned by
# ProtocolUtil.get_wireserver_endpoint() when a request to the primary
# IMDS endpoint (169.254.169.254) fails with a timed-out IOError.
#
# azurelinuxagent/common/protocol/imds.py
#   * Imports get_protocol_util and replaces BASE_URI with BASE_METADATA_URI;
#     the path after /metadata/ is now supplied by the caller
#     ('instance', 'instance/compute', ...).
#   * image_origin: the failure warning becomes
#     logger.periodic_warn(logger.EVERY_FIFTEEN_MINUTES, ...).
#   * ImdsClient.__init__: compiles a regex for the
#     "... IOError timed out -- N attempts made" HttpError text and stores
#     the result of get_protocol_util().
#   * Removes the compute_url / instance_url properties; adds
#     _get_metadata_url, _http_get (calls restutil.http_get with
#     use_proxy=False) and get_metadata.
#   * get_metadata(resource_path, is_health) returns a 2-tuple:
#       (True, resp.read())      when the request succeeded;
#       (False, error message)   when both endpoints raise a timed-out
#                                HttpError, or when
#                                restutil.request_failed(resp) is true.
#     An HttpError whose text does not match the timeout regex is re-raised.
#   * get_compute raises HttpError(message) on a False result; validate
#     returns (False, message).
#
# azurelinuxagent/common/protocol/util.py
#   * Renames _get_wireserver_endpoint to get_wireserver_endpoint and updates
#     the two call sites shown.
#
# tests/protocol/test_imds.py
#   * The expected validation URL changes from '.../metadata/instance/?api-version=...'
#     to '.../metadata/instance?api-version=...'.
#   * Adds test_endpoint_fallback plus five _mock_http_get_* helpers that
#     raise or return per endpoint ("169.254.169.254" vs "foo.bar").
#
# NOTE(review): this copy of the patch appears to have lost its line breaks
#   (hunk headers and hunk bodies share lines), so it presumably will not
#   apply as-is -- confirm against the original commit.
# NOTE(review): in test_endpoint_fallback the section commented
#   "primary IMDS endpoint unreachable and http error in secondary IMDS
#   endpoint" occurs twice with identical statements; the second copy looks
#   redundant -- confirm and drop it in the real patch.
# NOTE(review): get_metadata logs "Unable to connect to primary IMDS
#   endpoint" before testing the timeout regex, so the warning is also
#   emitted for HttpErrors that are then re-raised -- confirm this is intended.
# NOTE(review): the regex leaves the '.' after "HTTP Failed" unescaped, so it
#   matches any character there; harmless for the messages shown, but
#   re.escape or '\.' would be stricter.
# NOTE(review): `re` (imds.py) and Mock / MagicMock / patch (test_imds.py)
#   are used by added lines, but their imports are not visible in these
#   hunks -- verify they already exist in the target files.
# NOTE(review): what get_wireserver_endpoint() does when the endpoint file
#   cannot be read is not visible here; get_metadata formats whatever it
#   returns straight into the URL -- verify the failure path.
# NOTE(review): the unchanged context line `if endpoint == None:` in util.py
#   would idiomatically be `is None`; the try/except +
#   assertTrue(False, ...) pattern in the new test could be
#   assertRaises / self.fail.
diff --git a/azurelinuxagent/common/protocol/imds.py b/azurelinuxagent/common/protocol/imds.py index d10bd2b746..3e03a077f0 100644 --- a/azurelinuxagent/common/protocol/imds.py +++ b/azurelinuxagent/common/protocol/imds.py @@ -8,11 +8,12 @@ from azurelinuxagent.common.future import ustr import azurelinuxagent.common.logger as logger from azurelinuxagent.common.protocol.restapi import DataContract, set_properties +from azurelinuxagent.common.protocol.util import get_protocol_util from azurelinuxagent.common.utils.flexible_version import FlexibleVersion IMDS_ENDPOINT = '169.254.169.254' APIVERSION = '2018-02-01' -BASE_URI = "http://{0}/metadata/instance/{1}?api-version={2}" +BASE_METADATA_URI = "http://{0}/metadata/{1}?api-version={2}" IMDS_IMAGE_ORIGIN_UNKNOWN = 0 IMDS_IMAGE_ORIGIN_CUSTOM = 1 @@ -227,7 +228,8 @@ def image_origin(self): return IMDS_IMAGE_ORIGIN_PLATFORM except Exception as e: - logger.warn("Could not determine the image origin from IMDS: {0}", str(e)) + logger.periodic_warn(logger.EVERY_FIFTEEN_MINUTES, + "[PERIODIC] Could not determine the image origin from IMDS: {0}".format(str(e))) return IMDS_IMAGE_ORIGIN_UNKNOWN @@ -242,15 +244,48 @@ def __init__(self, version=APIVERSION): 'User-Agent': restutil.HTTP_USER_AGENT_HEALTH, 'Metadata': True, } - pass + self._regex_imds_ioerror = re.compile(r".*HTTP Failed. 
GET http://[^ ]+ -- IOError timed out -- [0-9]+ attempts made") + self._protocol_util = get_protocol_util() - @property - def compute_url(self): - return BASE_URI.format(IMDS_ENDPOINT, 'compute', self._api_version) + def _get_metadata_url(self, endpoint, resource_path): + return BASE_METADATA_URI.format(endpoint, resource_path, self._api_version) - @property - def instance_url(self): - return BASE_URI.format(IMDS_ENDPOINT, '', self._api_version) + def _http_get(self, endpoint, resource_path, headers): + url = self._get_metadata_url(endpoint, resource_path) + return restutil.http_get(url, headers=headers, use_proxy=False) + + def get_metadata(self, resource_path, is_health): + """ + Get metadata from IMDS, falling back to Wireserver endpoint if necessary. + + :param str resource_path: path of IMDS resource + :param bool is_health: True if for health/heartbeat, False otherwise + :return: Tuple + is_request_success: True when connection succeeds, False otherwise + response: response from IMDS on request success, failure message otherwise + """ + headers = self._health_headers if is_health else self._headers + endpoint = IMDS_ENDPOINT + try: + resp = self._http_get(endpoint=endpoint, resource_path=resource_path, headers=headers) + except HttpError as e: + logger.periodic_warn(logger.EVERY_FIFTEEN_MINUTES, + "[PERIODIC] Unable to connect to primary IMDS endpoint {0}".format(endpoint)) + if not self._regex_imds_ioerror.match(str(e)): + raise + endpoint = self._protocol_util.get_wireserver_endpoint() + try: + resp = self._http_get(endpoint=endpoint, resource_path=resource_path, headers=headers) + except HttpError as e: + logger.periodic_warn(logger.EVERY_FIFTEEN_MINUTES, + "[PERIODIC] Unable to connect to backup IMDS endpoint {0}".format(endpoint)) + if not self._regex_imds_ioerror.match(str(e)): + raise + return False, "IMDS error in /metadata/{0}: Unable to connect to endpoint".format(resource_path) + if restutil.request_failed(resp): + return False, "IMDS error in 
/metadata/{0}: {1}".format( + resource_path, restutil.read_response_error(resp)) + return True, resp.read() def get_compute(self): """ @@ -260,13 +295,12 @@ def get_compute(self): :rtype: ComputeInfo """ - resp = restutil.http_get(self.compute_url, headers=self._headers) - - if restutil.request_failed(resp): - raise HttpError("{0} - GET: {1}".format(resp.status, self.compute_url)) + # ensure we get a 200 + success, resp = self.get_metadata('instance/compute', is_health=False) + if not success: + raise HttpError(resp) - data = resp.read() - data = json.loads(ustr(data, encoding="utf-8")) + data = json.loads(ustr(resp, encoding="utf-8")) compute_info = ComputeInfo() set_properties('compute', compute_info, data) @@ -284,14 +318,13 @@ def validate(self): """ # ensure we get a 200 - resp = restutil.http_get(self.instance_url, headers=self._health_headers) - if restutil.request_failed(resp): - return False, "{0}".format(restutil.read_response_error(resp)) + success, resp = self.get_metadata('instance', is_health=True) + if not success: + return False, resp # ensure the response is valid json - data = resp.read() try: - json_data = json.loads(ustr(data, encoding="utf-8")) + json_data = json.loads(ustr(resp, encoding="utf-8")) except Exception as e: return False, "JSON parsing failed: {0}".format(ustr(e)) diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py index 186ff76ba5..692b54890c 100644 --- a/azurelinuxagent/common/protocol/util.py +++ b/azurelinuxagent/common/protocol/util.py @@ -152,7 +152,7 @@ def _get_tag_file_path(self): conf.get_lib_dir(), TAG_FILE_NAME) - def _get_wireserver_endpoint(self): + def get_wireserver_endpoint(self): try: file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) return fileutil.read_file(file_path) @@ -182,7 +182,7 @@ def _detect_wire_protocol(self): endpoint = self.dhcp_handler.endpoint else: logger.info("_detect_wire_protocol: DHCP not available") - endpoint = 
self._get_wireserver_endpoint() + endpoint = self.get_wireserver_endpoint() if endpoint == None: endpoint = conf_endpoint logger.info("Using hardcoded WireServer endpoint {0}", endpoint) @@ -239,7 +239,7 @@ def _get_protocol(self): protocol_name = fileutil.read_file(protocol_file_path) if protocol_name == prots.WireProtocol: - endpoint = self._get_wireserver_endpoint() + endpoint = self.get_wireserver_endpoint() return WireProtocol(endpoint) elif protocol_name == prots.MetadataProtocol: return MetadataProtocol() diff --git a/tests/protocol/test_imds.py b/tests/protocol/test_imds.py index 327aa020cf..adabcf40fa 100644 --- a/tests/protocol/test_imds.py +++ b/tests/protocol/test_imds.py @@ -5,7 +5,7 @@ import azurelinuxagent.common.protocol.imds as imds from azurelinuxagent.common.exception import HttpError -from azurelinuxagent.common.future import ustr +from azurelinuxagent.common.future import ustr, httpclient from azurelinuxagent.common.protocol.restapi import set_properties from azurelinuxagent.common.utils import restutil from tests.ga.test_update import ResponseMock @@ -337,13 +337,138 @@ def _assert_validation(self, http_status_code, http_response, expected_valid, ex self.assertEqual(restutil.HTTP_USER_AGENT_HEALTH, kw_args['headers']['User-Agent']) self.assertTrue('Metadata' in kw_args['headers']) self.assertEqual(True, kw_args['headers']['Metadata']) - self.assertEqual('http://169.254.169.254/metadata/instance/?api-version=2018-02-01', + self.assertEqual('http://169.254.169.254/metadata/instance?api-version=2018-02-01', positional_args[0]) self.assertEqual(expected_valid, validate_response[0]) self.assertTrue(expected_response in validate_response[1], "Expected: '{0}', Actual: '{1}'" .format(expected_response, validate_response[1])) + @patch("azurelinuxagent.common.protocol.util.ProtocolUtil") + def test_endpoint_fallback(self, ProtocolUtil): + # http error status codes are tested in test_response_validation, none of which + # should trigger a fallback. 
This is confirmed as _assert_validation will count + # http GET calls and enforces a single GET call (fallback would cause 2) and + # checks the url called. + + test_subject = imds.ImdsClient() + ProtocolUtil().get_wireserver_endpoint.return_value = "foo.bar" + + # ensure user-agent gets set correctly + for is_health, expected_useragent in [(False, restutil.HTTP_USER_AGENT), (True, restutil.HTTP_USER_AGENT_HEALTH)]: + # set a different resource path for health query to make debugging unit test easier + resource_path = 'something/health' if is_health else 'something' + + # both endpoints unreachable + test_subject._http_get = Mock(side_effect=self._mock_http_get_unreachable_both) + conn_success, response = test_subject.get_metadata(resource_path=resource_path, is_health=is_health) + self.assertFalse(conn_success) + self.assertEqual('IMDS error in /metadata/{0}: Unable to connect to endpoint'.format(resource_path), response) + self.assertEqual(2, test_subject._http_get.call_count) + for _, kwargs in test_subject._http_get.call_args_list: + self.assertTrue('User-Agent' in kwargs['headers']) + self.assertEqual(expected_useragent, kwargs['headers']['User-Agent']) + + # primary IMDS endpoint unreachable and success in secondary IMDS endpoint + test_subject._http_get = Mock(side_effect=self._mock_http_get_unreachable_primary_with_ok) + conn_success, response = test_subject.get_metadata(resource_path=resource_path, is_health=is_health) + self.assertTrue(conn_success) + self.assertEqual('Mock success response', response) + self.assertEqual(2, test_subject._http_get.call_count) + for _, kwargs in test_subject._http_get.call_args_list: + self.assertTrue('User-Agent' in kwargs['headers']) + self.assertEqual(expected_useragent, kwargs['headers']['User-Agent']) + + # primary IMDS endpoint unreachable and http error in secondary IMDS endpoint + test_subject._http_get = Mock(side_effect=self._mock_http_get_unreachable_primary_with_fail) + conn_success, response = 
test_subject.get_metadata(resource_path=resource_path, is_health=is_health) + self.assertFalse(conn_success) + self.assertEqual('IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path), response) + self.assertEqual(2, test_subject._http_get.call_count) + for _, kwargs in test_subject._http_get.call_args_list: + self.assertTrue('User-Agent' in kwargs['headers']) + self.assertEqual(expected_useragent, kwargs['headers']['User-Agent']) + + # primary IMDS endpoint unreachable and http error in secondary IMDS endpoint + test_subject._http_get = Mock(side_effect=self._mock_http_get_unreachable_primary_with_fail) + conn_success, response = test_subject.get_metadata(resource_path=resource_path, is_health=is_health) + self.assertFalse(conn_success) + self.assertEqual('IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path), response) + self.assertEqual(2, test_subject._http_get.call_count) + for _, kwargs in test_subject._http_get.call_args_list: + self.assertTrue('User-Agent' in kwargs['headers']) + self.assertEqual(expected_useragent, kwargs['headers']['User-Agent']) + + # primary IMDS endpoint with non-fallback HTTPError + test_subject._http_get = Mock(side_effect=self._mock_http_get_nonfallback_primary) + try: + test_subject.get_metadata(resource_path=resource_path, is_health=is_health) + self.assertTrue(False, 'Expected HttpError but no except raised') + except HttpError as e: + self.assertEqual('[HttpError] HTTP Failed. 
GET http://169.254.169.254/metadata/{0} -- IOError incomplete read -- 6 attempts made'.format(resource_path), str(e)) + self.assertEqual(1, test_subject._http_get.call_count) + for _, kwargs in test_subject._http_get.call_args_list: + self.assertTrue('User-Agent' in kwargs['headers']) + self.assertEqual(expected_useragent, kwargs['headers']['User-Agent']) + except Exception as e: + self.assertTrue(False, 'Expected HttpError but got {0}'.format(str(e))) + + # primary IMDS endpoint unreachable and non-timeout HTTPError in secondary IMDS endpoint + test_subject._http_get = Mock(side_effect=self._mock_http_get_unreachable_primary_with_except) + try: + test_subject.get_metadata(resource_path=resource_path, is_health=is_health) + self.assertTrue(False, 'Expected HttpError but no except raised') + except HttpError as e: + self.assertEqual('[HttpError] HTTP Failed. GET http://foo.bar/metadata/{0} -- IOError incomplete read -- 6 attempts made'.format(resource_path), str(e)) + self.assertEqual(2, test_subject._http_get.call_count) + for _, kwargs in test_subject._http_get.call_args_list: + self.assertTrue('User-Agent' in kwargs['headers']) + self.assertEqual(expected_useragent, kwargs['headers']['User-Agent']) + except Exception as e: + self.assertTrue(False, 'Expected HttpError but got {0}'.format(str(e))) + + def _mock_http_get_unreachable_both(self, *_, **kwargs): + raise HttpError("HTTP Failed. GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made" + .format(kwargs['endpoint'], kwargs['resource_path'])) + + def _mock_http_get_unreachable_primary_with_ok(self, *_, **kwargs): + if "169.254.169.254" == kwargs['endpoint']: + raise HttpError("HTTP Failed. 
GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made" + .format(kwargs['endpoint'], kwargs['resource_path'])) + elif "foo.bar" == kwargs['endpoint']: + resp = MagicMock() + resp.status = httpclient.OK + resp.read.return_value = 'Mock success response' + return resp + raise Exception("Unexpected endpoint called") + + def _mock_http_get_unreachable_primary_with_fail(self, *_, **kwargs): + if "169.254.169.254" == kwargs['endpoint']: + raise HttpError("HTTP Failed. GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made" + .format(kwargs['endpoint'], kwargs['resource_path'])) + elif "foo.bar" == kwargs['endpoint']: + resp = MagicMock() + resp.status = httpclient.NOT_FOUND + resp.reason = 'reason' + resp.read.return_value = 'Mock not found' + return resp + raise Exception("Unexpected endpoint called") + + def _mock_http_get_nonfallback_primary(self, *_, **kwargs): + if "169.254.169.254" == kwargs['endpoint']: + raise HttpError("HTTP Failed. GET http://{0}/metadata/{1} -- IOError incomplete read -- 6 attempts made" + .format(kwargs['endpoint'], kwargs['resource_path'])) + raise Exception("Unexpected endpoint called") + + def _mock_http_get_unreachable_primary_with_except(self, *_, **kwargs): + if "169.254.169.254" == kwargs['endpoint']: + raise HttpError("HTTP Failed. GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made" + .format(kwargs['endpoint'], kwargs['resource_path'])) + elif "foo.bar" == kwargs['endpoint']: + raise HttpError("HTTP Failed. GET http://{0}/metadata/{1} -- IOError incomplete read -- 6 attempts made" + .format(kwargs['endpoint'], kwargs['resource_path'])) + raise Exception("Unexpected endpoint called") + if __name__ == '__main__': unittest.main()