diff --git a/azurelinuxagent/agent.py b/azurelinuxagent/agent.py index 4229194cd7..e1dfc4ed59 100644 --- a/azurelinuxagent/agent.py +++ b/azurelinuxagent/agent.py @@ -129,7 +129,7 @@ def run_exthandlers(self): """ Run the update and extension handler """ - logger.set_prefix("Upd/Ext-Handler") + logger.set_prefix("ExtHandler") from azurelinuxagent.ga.update import get_update_handler update_handler = get_update_handler() update_handler.run() diff --git a/azurelinuxagent/common/errorstate.py b/azurelinuxagent/common/errorstate.py index 52e2ddd5ec..38aaa1f916 100644 --- a/azurelinuxagent/common/errorstate.py +++ b/azurelinuxagent/common/errorstate.py @@ -2,6 +2,8 @@ ERROR_STATE_DELTA_DEFAULT = timedelta(minutes=15) ERROR_STATE_DELTA_INSTALL = timedelta(minutes=5) +ERROR_STATE_HOST_PLUGIN_FAILURE = timedelta(minutes=5) + class ErrorState(object): def __init__(self, min_timedelta=ERROR_STATE_DELTA_DEFAULT): diff --git a/azurelinuxagent/common/event.py b/azurelinuxagent/common/event.py index 1c26a13990..1f55c23bbe 100644 --- a/azurelinuxagent/common/event.py +++ b/azurelinuxagent/common/event.py @@ -56,6 +56,7 @@ class WALAEventOperation: HealthCheck = "HealthCheck" HeartBeat = "HeartBeat" HostPlugin = "HostPlugin" + HostPluginHeartbeat = "HostPluginHeartbeat" HttpErrors = "HttpErrors" Install = "Install" InitializeHostPlugin = "InitializeHostPlugin" diff --git a/azurelinuxagent/common/protocol/healthservice.py b/azurelinuxagent/common/protocol/healthservice.py new file mode 100644 index 0000000000..29a44f2a2a --- /dev/null +++ b/azurelinuxagent/common/protocol/healthservice.py @@ -0,0 +1,150 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2018 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.6+ and Openssl 1.0+ +# + +import json + +from azurelinuxagent.common import logger +from azurelinuxagent.common.exception import HttpError +from azurelinuxagent.common.future import ustr +from azurelinuxagent.common.utils import restutil + + +class Observation(object): + def __init__(self, name, is_healthy, description='', value=''): + if name is None: + raise ValueError("Observation name must be provided") + + if is_healthy is None: + raise ValueError("Observation health must be provided") + + if value is None: + value = '' + + if description is None: + description = '' + + self.name = name + self.is_healthy = is_healthy + self.description = description + self.value = value + + @property + def as_obj(self): + return { + "ObservationName": self.name[:64], + "IsHealthy": self.is_healthy, + "Description": self.description[:128], + "Value": self.value[:128] + } + + +class HealthService(object): + + ENDPOINT = 'http://{0}:80/HealthService' + API = 'reporttargethealth' + VERSION = "1.0" + OBSERVER_NAME = 'WALinuxAgent' + HOST_PLUGIN_HEARTBEAT_OBSERVATION_NAME = 'GuestAgentPluginHeartbeat' + HOST_PLUGIN_STATUS_OBSERVATION_NAME = 'GuestAgentPluginStatus' + HOST_PLUGIN_VERSIONS_OBSERVATION_NAME = 'GuestAgentPluginVersions' + HOST_PLUGIN_ARTIFACT_OBSERVATION_NAME = 'GuestAgentPluginArtifact' + MAX_OBSERVATIONS = 10 + + def __init__(self, endpoint): + self.endpoint = HealthService.ENDPOINT.format(endpoint) + self.api = HealthService.API + self.version = HealthService.VERSION + self.source = HealthService.OBSERVER_NAME + self.observations = 
list() + + @property + def as_json(self): + data = { + "Api": self.api, + "Version": self.version, + "Source": self.source, + "Observations": [o.as_obj for o in self.observations] + } + return json.dumps(data) + + def report_host_plugin_heartbeat(self, is_healthy): + """ + Reports a signal for /health + :param is_healthy: whether the call succeeded + """ + self._observe(name=HealthService.HOST_PLUGIN_HEARTBEAT_OBSERVATION_NAME, + is_healthy=is_healthy) + self._report() + + def report_host_plugin_versions(self, is_healthy, response): + """ + Reports a signal for /versions + :param is_healthy: whether the api call succeeded + :param response: debugging information for failures + """ + self._observe(name=HealthService.HOST_PLUGIN_VERSIONS_OBSERVATION_NAME, + is_healthy=is_healthy, + value=response) + self._report() + + def report_host_plugin_extension_artifact(self, is_healthy, source, response): + """ + Reports a signal for /extensionArtifact + :param is_healthy: whether the api call succeeded + :param source: specifies the api caller for debugging failures + :param response: debugging information for failures + :return: + """ + self._observe(name=HealthService.HOST_PLUGIN_ARTIFACT_OBSERVATION_NAME, + is_healthy=is_healthy, + description=source, + value=response) + self._report() + + def report_host_plugin_status(self, is_healthy, response): + """ + Reports a signal for /status + :param is_healthy: whether the api call succeeded + :param response: debugging information for failures + :return: + """ + self._observe(name=HealthService.HOST_PLUGIN_STATUS_OBSERVATION_NAME, + is_healthy=is_healthy, + value=response) + self._report() + + def _observe(self, name, is_healthy, value='', description=''): + # ensure we keep the list size within bounds + if len(self.observations) >= HealthService.MAX_OBSERVATIONS: + del self.observations[:HealthService.MAX_OBSERVATIONS-1] + self.observations.append(Observation(name=name, + is_healthy=is_healthy, + value=value, + 
description=description)) + + def _report(self): + logger.verbose('HealthService: report observations') + try: + restutil.http_post(self.endpoint, self.as_json, headers={'Content-Type': 'application/json'}) + logger.verbose('HealthService: Reported observations to {0}: {1}', self.endpoint, self.as_json) + except HttpError as e: + logger.warn("HealthService: could not report observations: {0}", ustr(e)) + finally: + # these signals are not timestamped, so there is no value in persisting data + del self.observations[:] diff --git a/azurelinuxagent/common/protocol/hostplugin.py b/azurelinuxagent/common/protocol/hostplugin.py index f80836d2ec..c958c02013 100644 --- a/azurelinuxagent/common/protocol/hostplugin.py +++ b/azurelinuxagent/common/protocol/hostplugin.py @@ -18,13 +18,15 @@ # import base64 +import datetime import json -import traceback from azurelinuxagent.common import logger +from azurelinuxagent.common.errorstate import ErrorState, ERROR_STATE_HOST_PLUGIN_FAILURE from azurelinuxagent.common.exception import HttpError, ProtocolError, \ ResourceGoneError from azurelinuxagent.common.future import ustr, httpclient +from azurelinuxagent.common.protocol.healthservice import HealthService from azurelinuxagent.common.utils import restutil from azurelinuxagent.common.utils import textutil from azurelinuxagent.common.utils.textutil import remove_bom @@ -35,6 +37,7 @@ URI_FORMAT_GET_EXTENSION_ARTIFACT = "http://{0}:{1}/extensionArtifact" URI_FORMAT_PUT_VM_STATUS = "http://{0}:{1}/status" URI_FORMAT_PUT_LOG = "http://{0}:{1}/vmAgentLog" +URI_FORMAT_HEALTH = "http://{0}:{1}/health" API_VERSION = "2015-09-01" HEADER_CONTAINER_ID = "x-ms-containerid" HEADER_VERSION = "x-ms-version" @@ -47,6 +50,9 @@ class HostPluginProtocol(object): _is_default_channel = False + FETCH_REPORTING_PERIOD = datetime.timedelta(minutes=1) + STATUS_REPORTING_PERIOD = datetime.timedelta(minutes=1) + def __init__(self, endpoint, container_id, role_config_name): if endpoint is None: raise 
ProtocolError("HostGAPlugin: Endpoint not provided") @@ -58,6 +64,11 @@ def __init__(self, endpoint, container_id, role_config_name): self.deployment_id = None self.role_config_name = role_config_name self.manifest_uri = None + self.health_service = HealthService(endpoint) + self.fetch_error_state = ErrorState(min_timedelta=ERROR_STATE_HOST_PLUGIN_FAILURE) + self.status_error_state = ErrorState(min_timedelta=ERROR_STATE_HOST_PLUGIN_FAILURE) + self.fetch_last_timestamp = None + self.status_last_timestamp = None @staticmethod def is_default_channel(): @@ -77,25 +88,44 @@ def ensure_initialized(self): is_success=self.is_available) return self.is_available + def get_health(self): + """ + Call the /health endpoint + :return: True if 200 received, False otherwise + """ + url = URI_FORMAT_HEALTH.format(self.endpoint, + HOST_PLUGIN_PORT) + logger.verbose("HostGAPlugin: Getting health from [{0}]", url) + status_ok = False + try: + response = restutil.http_get(url, max_retry=1) + status_ok = restutil.request_succeeded(response) + except HttpError as e: + logger.verbose("HostGAPlugin: Exception getting health", ustr(e)) + return status_ok + def get_api_versions(self): url = URI_FORMAT_GET_API_VERSIONS.format(self.endpoint, HOST_PLUGIN_PORT) - logger.verbose("HostGAPlugin: Getting API versions at [{0}]".format( - url)) + logger.verbose("HostGAPlugin: Getting API versions at [{0}]" + .format(url)) return_val = [] + error_response = '' + is_healthy = False try: headers = {HEADER_CONTAINER_ID: self.container_id} response = restutil.http_get(url, headers) if restutil.request_failed(response): - logger.error( - "HostGAPlugin: Failed Get API versions: {0}".format( - restutil.read_response_error(response))) + error_response = restutil.read_response_error(response) + logger.error("HostGAPlugin: Failed Get API versions: {0}".format(error_response)) else: return_val = ustr(remove_bom(response.read()), encoding='utf-8') - + is_healthy = True except HttpError as e: 
logger.error("HostGAPlugin: Exception Get API versions: {0}".format(e)) + self.health_service.report_host_plugin_versions(is_healthy=is_healthy, response=error_response) + return return_val def get_artifact_request(self, artifact_url, artifact_manifest_url=None): @@ -117,6 +147,55 @@ def get_artifact_request(self, artifact_url, artifact_manifest_url=None): return url, headers + def report_fetch_health(self, uri, is_healthy=True, source='', response=''): + + if uri != URI_FORMAT_GET_EXTENSION_ARTIFACT.format(self.endpoint, HOST_PLUGIN_PORT): + return + + if self.should_report(is_healthy, + self.fetch_error_state, + self.fetch_last_timestamp, + HostPluginProtocol.FETCH_REPORTING_PERIOD): + self.fetch_last_timestamp = datetime.datetime.utcnow() + health_signal = self.fetch_error_state.is_triggered() is False + self.health_service.report_host_plugin_extension_artifact(is_healthy=health_signal, + source=source, + response=response) + + def report_status_health(self, is_healthy, response=''): + if self.should_report(is_healthy, + self.status_error_state, + self.status_last_timestamp, + HostPluginProtocol.STATUS_REPORTING_PERIOD): + self.status_last_timestamp = datetime.datetime.utcnow() + health_signal = self.status_error_state.is_triggered() is False + self.health_service.report_host_plugin_status(is_healthy=health_signal, + response=response) + + @staticmethod + def should_report(is_healthy, error_state, last_timestamp, period): + """ + Determine whether a health signal should be reported + :param is_healthy: whether the current measurement is healthy + :param error_state: the error state which is tracking time since failure + :param last_timestamp: the last measurement time stamp + :param period: the reporting period + :return: True if the signal should be reported, False otherwise + """ + + if is_healthy: + # we only reset the error state upon success, since we want to keep + # reporting the failure; this is different to other uses of error states + # which do not 
have a separate periodicity + error_state.reset() + else: + error_state.incr() + + if last_timestamp is None: + last_timestamp = datetime.datetime.utcnow() - period + + return datetime.datetime.utcnow() >= (last_timestamp + period) + def put_vm_log(self, content): raise NotImplementedError("Unimplemented") @@ -156,16 +235,20 @@ def _put_block_blob_status(self, sas_url, status_blob): url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT) response = restutil.http_put(url, - data=self._build_status_data( - sas_url, - status_blob.get_block_blob_headers(len(status_blob.data)), - bytearray(status_blob.data, encoding='utf-8')), - headers=self._build_status_headers()) + data=self._build_status_data( + sas_url, + status_blob.get_block_blob_headers(len(status_blob.data)), + bytearray(status_blob.data, encoding='utf-8')), + headers=self._build_status_headers()) if restutil.request_failed(response): - raise HttpError("HostGAPlugin: Put BlockBlob failed: {0}".format( - restutil.read_response_error(response))) + error_response = restutil.read_response_error(response) + is_healthy = not restutil.request_failed_at_hostplugin(response) + self.report_status_health(is_healthy=is_healthy, response=error_response) + raise HttpError("HostGAPlugin: Put BlockBlob failed: {0}" + .format(error_response)) else: + self.report_status_health(is_healthy=True) logger.verbose("HostGAPlugin: Put BlockBlob status succeeded") def _put_page_blob_status(self, sas_url, status_blob): @@ -178,16 +261,19 @@ def _put_page_blob_status(self, sas_url, status_blob): # First, initialize an empty blob response = restutil.http_put(url, - data=self._build_status_data( - sas_url, - status_blob.get_page_blob_create_headers(status_size)), - headers=self._build_status_headers()) + data=self._build_status_data( + sas_url, + status_blob.get_page_blob_create_headers(status_size)), + headers=self._build_status_headers()) if restutil.request_failed(response): - raise HttpError( - "HostGAPlugin: Failed 
PageBlob clean-up: {0}".format( - restutil.read_response_error(response))) + error_response = restutil.read_response_error(response) + is_healthy = not restutil.request_failed_at_hostplugin(response) + self.report_status_health(is_healthy=is_healthy, response=error_response) + raise HttpError("HostGAPlugin: Failed PageBlob clean-up: {0}" + .format(error_response)) else: + self.report_status_health(is_healthy=True) logger.verbose("HostGAPlugin: PageBlob clean-up succeeded") # Then, upload the blob in pages @@ -207,17 +293,19 @@ def _put_page_blob_status(self, sas_url, status_blob): # Send the page response = restutil.http_put(url, - data=self._build_status_data( - sas_url, - status_blob.get_page_blob_page_headers(start, end), - buf), - headers=self._build_status_headers()) + data=self._build_status_data( + sas_url, + status_blob.get_page_blob_page_headers(start, end), + buf), + headers=self._build_status_headers()) if restutil.request_failed(response): + error_response = restutil.read_response_error(response) + is_healthy = not restutil.request_failed_at_hostplugin(response) + self.report_status_health(is_healthy=is_healthy, response=error_response) raise HttpError( - "HostGAPlugin Error: Put PageBlob bytes [{0},{1}]: " \ - "{2}".format( - start, end, restutil.read_response_error(response))) + "HostGAPlugin Error: Put PageBlob bytes " + "[{0},{1}]: {2}".format(start, end, error_response)) # Advance to the next page (if any) start = end diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py index fe3b62ffed..7bcdd70cc0 100644 --- a/azurelinuxagent/common/protocol/restapi.py +++ b/azurelinuxagent/common/protocol/restapi.py @@ -333,12 +333,14 @@ def get_artifacts_profile(self): raise NotImplementedError() def download_ext_handler_pkg(self, uri, headers=None, use_proxy=True): + pkg = None try: resp = restutil.http_get(uri, headers=headers, use_proxy=use_proxy) if restutil.request_succeeded(resp): - return resp.read() + pkg = 
resp.read() except Exception as e: logger.warn("Failed to download from: {0}".format(uri), e) + return pkg def report_provision_status(self, provision_status): raise NotImplementedError() diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py index f6383cc6f6..a3e3176fad 100644 --- a/azurelinuxagent/common/protocol/util.py +++ b/azurelinuxagent/common/protocol/util.py @@ -54,6 +54,7 @@ def get_protocol_util(): class ProtocolUtil(object): + """ ProtocolUtil handles initialization for protocol instance. 2 protocol types are invoked, wire protocol and metadata protocols. diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py index 5e5604a3ad..267be74e1f 100644 --- a/azurelinuxagent/common/protocol/wire.py +++ b/azurelinuxagent/common/protocol/wire.py @@ -32,7 +32,8 @@ from azurelinuxagent.common.exception import ProtocolNotFoundError, \ ResourceGoneError from azurelinuxagent.common.future import httpclient, bytebuffer -from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol +from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol, URI_FORMAT_GET_EXTENSION_ARTIFACT, \ + HOST_PLUGIN_PORT from azurelinuxagent.common.protocol.restapi import * from azurelinuxagent.common.utils.archive import StateFlusher from azurelinuxagent.common.utils.cryptutil import CryptUtil @@ -159,16 +160,15 @@ def get_artifacts_profile(self): logger.verbose("Get In-VM Artifacts Profile") return self.client.get_artifacts_profile() - def download_ext_handler_pkg(self, uri, headers=None): - package = super(WireProtocol, self).download_ext_handler_pkg(uri) + def download_ext_handler_pkg(self, uri, headers=None, use_proxy=True): + package = self.client.fetch(uri, headers=headers, use_proxy=use_proxy, decode=False) - if package is not None: - return package - else: + if package is None: logger.verbose("Download did not succeed, falling back to host plugin") host = 
self.client.get_host_plugin() uri, headers = host.get_artifact_request(uri, host.manifest_uri) - package = super(WireProtocol, self).download_ext_handler_pkg(uri, headers=headers, use_proxy=False) + package = self.client.fetch(uri, headers=headers, use_proxy=False, decode=False) + return package def report_provision_status(self, provision_status): @@ -252,12 +252,10 @@ def _build_health_report(incarnation, container_id, role_instance_id, return xml -""" -Convert VMStatus object to status blob format -""" - - def ga_status_to_guest_info(ga_status): + """ + Convert VMStatus object to status blob format + """ v1_ga_guest_info = { "computerName" : ga_status.hostname, "osName" : ga_status.osname, @@ -279,6 +277,7 @@ def ga_status_to_v1(ga_status): } return v1_ga_status + def ext_substatus_to_v1(sub_status_list): status_list = [] for substatus in sub_status_list: @@ -616,7 +615,7 @@ def fetch_manifest(self, version_uris): logger.verbose("Using host plugin as default channel") else: logger.verbose("Failed to download manifest, " - "switching to host plugin") + "switching to host plugin") try: host = self.get_host_plugin() @@ -640,7 +639,8 @@ def fetch_manifest(self, version_uris): raise ProtocolError("Failed to fetch manifest from all sources") - def fetch(self, uri, headers=None, use_proxy=None): + def fetch(self, uri, headers=None, use_proxy=None, decode=True): + content = None logger.verbose("Fetch [{0}] with headers [{1}]", uri, headers) try: resp = self.call_storage_service( @@ -650,21 +650,27 @@ def fetch(self, uri, headers=None, use_proxy=None): use_proxy=use_proxy) if restutil.request_failed(resp): - msg = "[Storage Failed] URI {0} ".format(uri) - if resp is not None: - msg += restutil.read_response_error(resp) + error_response = restutil.read_response_error(resp) + msg = "Fetch failed from [{0}]: {1}".format(uri, error_response) logger.warn(msg) + if self.host_plugin is not None: + self.host_plugin.report_fetch_health(uri, + is_healthy=not 
restutil.request_failed_at_hostplugin(resp), + source='WireClient', + response=error_response) raise ProtocolError(msg) + else: + response_content = resp.read() + content = self.decode_config(response_content) if decode else response_content + if self.host_plugin is not None: + self.host_plugin.report_fetch_health(uri, source='WireClient') - return self.decode_config(resp.read()) - - except (HttpError, ProtocolError) as e: + except (HttpError, ProtocolError, IOError) as e: logger.verbose("Fetch failed from [{0}]: {1}", uri, e) - if isinstance(e, ResourceGoneError): raise - return None + return content def update_hosting_env(self, goal_state): if goal_state.hosting_env_uri is None: @@ -1153,7 +1159,7 @@ def get_artifacts_profile(self): logger.verbose("Using host plugin as default channel") else: logger.verbose("Failed to download artifacts profile, " - "switching to host plugin") + "switching to host plugin") host = self.get_host_plugin() uri, headers = host.get_artifact_request(blob) @@ -1176,6 +1182,7 @@ def get_artifacts_profile(self): return None + class VersionInfo(object): def __init__(self, xml_text): """ diff --git a/azurelinuxagent/common/utils/restutil.py b/azurelinuxagent/common/utils/restutil.py index 5ceb4c949d..b64c546f61 100644 --- a/azurelinuxagent/common/utils/restutil.py +++ b/azurelinuxagent/common/utils/restutil.py @@ -52,7 +52,6 @@ ] RESOURCE_GONE_CODES = [ - httpclient.BAD_REQUEST, httpclient.GONE ] @@ -62,6 +61,10 @@ httpclient.ACCEPTED ] +HOSTPLUGIN_UPSTREAM_FAILURE_CODES = [ + 502 +] + THROTTLE_CODES = [ httpclient.FORBIDDEN, httpclient.SERVICE_UNAVAILABLE, @@ -385,12 +388,22 @@ def http_delete(url, headers=None, use_proxy=False, retry_codes=retry_codes, retry_delay=retry_delay) + def request_failed(resp, ok_codes=OK_CODES): return not request_succeeded(resp, ok_codes=ok_codes) + def request_succeeded(resp, ok_codes=OK_CODES): return resp is not None and resp.status in ok_codes + +def request_failed_at_hostplugin(resp, 
upstream_failure_codes=HOSTPLUGIN_UPSTREAM_FAILURE_CODES): + """ + Host plugin will return 502 for any upstream issue, so a failure is any 5xx except 502 + """ + return resp is not None and resp.status >= 500 and resp.status not in upstream_failure_codes + + def read_response_error(resp): result = '' if resp is not None: diff --git a/azurelinuxagent/ga/exthandlers.py b/azurelinuxagent/ga/exthandlers.py index a802e1927d..1d49187d44 100644 --- a/azurelinuxagent/ga/exthandlers.py +++ b/azurelinuxagent/ga/exthandlers.py @@ -236,10 +236,6 @@ def run(self): message=msg) return - def run_status(self): - self.report_ext_handlers_status() - return - def get_upgrade_guid(self, name): return self.last_upgrade_guids.get(name, (None, False))[0] @@ -721,10 +717,10 @@ def copy_status_files(self, old_ext_handler_i): def set_operation(self, op): self.operation = op - def report_event(self, message="", is_success=True, duration=0): + def report_event(self, message="", is_success=True, duration=0, log_event=True): ext_handler_version = self.ext_handler.properties.version add_event(name=self.ext_handler.name, version=ext_handler_version, message=message, - op=self.operation, is_success=is_success, duration=duration) + op=self.operation, is_success=is_success, duration=duration, log_event=log_event) def download(self): begin_utc = datetime.datetime.utcnow() @@ -1008,7 +1004,7 @@ def launch_command(self, cmd, timeout=300): raise ExtensionError("Non-zero exit code: {0}, {1}\n{2}".format(ret, cmd, msg)) duration = elapsed_milliseconds(begin_utc) - self.report_event(message="{0}\n{1}".format(cmd, msg), duration=duration) + self.report_event(message="{0}\n{1}".format(cmd, msg), duration=duration, log_event=False) def load_manifest(self): man_file = self.get_manifest_file() @@ -1115,9 +1111,8 @@ def set_handler_status(self, status="NotReady", message="", code=0): self.ext_handler.name, self.ext_handler.properties.version)) except (IOError, ValueError, ProtocolError) as e: - 
fileutil.clean_ioerror(e, - paths=[status_file]) - self.logger.error("Failed to save handler status: {0}", traceback.format_exc()) + fileutil.clean_ioerror(e, paths=[status_file]) + self.logger.error("Failed to save handler status: {0}, {1}", ustr(e), traceback.format_exc()) def get_handler_status(self): state_dir = self.get_conf_dir() diff --git a/azurelinuxagent/ga/monitor.py b/azurelinuxagent/ga/monitor.py index e1bfd2a6f3..2e435e0d32 100644 --- a/azurelinuxagent/ga/monitor.py +++ b/azurelinuxagent/ga/monitor.py @@ -26,12 +26,14 @@ import azurelinuxagent.common.conf as conf import azurelinuxagent.common.utils.fileutil as fileutil import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.errorstate import ErrorState from azurelinuxagent.common.event import add_event, WALAEventOperation from azurelinuxagent.common.exception import EventError, ProtocolError, OSUtilError, HttpError from azurelinuxagent.common.future import ustr from azurelinuxagent.common.osutil import get_osutil from azurelinuxagent.common.protocol import get_protocol_util +from azurelinuxagent.common.protocol.healthservice import HealthService from azurelinuxagent.common.protocol.imds import get_imds_client from azurelinuxagent.common.protocol.restapi import TelemetryEventParam, \ TelemetryEventList, \ @@ -91,19 +93,46 @@ def get_monitor_handler(): class MonitorHandler(object): + + EVENT_COLLECTION_PERIOD = datetime.timedelta(minutes=1) + TELEMETRY_HEARTBEAT_PERIOD = datetime.timedelta(minutes=30) + HOST_PLUGIN_HEARTBEAT_PERIOD = datetime.timedelta(minutes=1) + HOST_PLUGIN_HEALTH_PERIOD = datetime.timedelta(minutes=5) + def __init__(self): self.osutil = get_osutil() self.protocol_util = get_protocol_util() self.imds_client = get_imds_client() - self.sysinfo = [] + self.event_thread = None + self.last_event_collection = None + self.last_telemetry_heartbeat = None + self.last_host_plugin_heartbeat = None + self.protocol = None + self.health_service = None + + self.counter = 0 + 
self.sysinfo = [] + self.should_run = True + self.heartbeat_id = str(uuid.uuid4()).upper() + self.host_plugin_errorstate = ErrorState(min_timedelta=MonitorHandler.HOST_PLUGIN_HEALTH_PERIOD) def run(self): + self.init_protocols() self.init_sysinfo() self.start() + def stop(self): + self.should_run = False + if self.is_alive(): + self.event_thread.join() + + def init_protocols(self): + self.protocol = self.protocol_util.get_protocol() + self.health_service = HealthService(self.protocol.endpoint) + def is_alive(self): - return self.event_thread.is_alive() + return self.event_thread is not None and self.event_thread.is_alive() def start(self): self.event_thread = threading.Thread(target=self.daemon) @@ -129,8 +158,7 @@ def init_sysinfo(self): logger.warn("Failed to get system info: {0}", e) try: - protocol = self.protocol_util.get_protocol() - vminfo = protocol.get_vminfo() + vminfo = self.protocol.get_vminfo() self.sysinfo.append(TelemetryEventParam("VMName", vminfo.vmName)) self.sysinfo.append(TelemetryEventParam("TenantName", @@ -173,53 +201,110 @@ def collect_event(self, evt_file_name): raise EventError(msg) def collect_and_send_events(self): - event_list = TelemetryEventList() - event_dir = os.path.join(conf.get_lib_dir(), "events") - event_files = os.listdir(event_dir) - for event_file in event_files: - if not event_file.endswith(".tld"): - continue - event_file_path = os.path.join(event_dir, event_file) + if self.last_event_collection is None: + self.last_event_collection = datetime.datetime.utcnow() - MonitorHandler.EVENT_COLLECTION_PERIOD + + if datetime.datetime.utcnow() >= (self.last_event_collection + MonitorHandler.EVENT_COLLECTION_PERIOD): try: - data_str = self.collect_event(event_file_path) - except EventError as e: - logger.error("{0}", e) - continue + event_list = TelemetryEventList() + event_dir = os.path.join(conf.get_lib_dir(), "events") + event_files = os.listdir(event_dir) + for event_file in event_files: + if not event_file.endswith(".tld"): + 
continue + event_file_path = os.path.join(event_dir, event_file) + try: + data_str = self.collect_event(event_file_path) + except EventError as e: + logger.error("{0}", e) + continue + + try: + event = parse_event(data_str) + self.add_sysinfo(event) + event_list.events.append(event) + except (ValueError, ProtocolError) as e: + logger.warn("Failed to decode event file: {0}", e) + continue + + if len(event_list.events) == 0: + return + + try: + self.protocol.report_event(event_list) + except ProtocolError as e: + logger.error("{0}", e) + except Exception as e: + logger.warn("Failed to send events: {0}", e) + + self.last_event_collection = datetime.datetime.utcnow() + + def daemon(self): + min_delta = min(MonitorHandler.TELEMETRY_HEARTBEAT_PERIOD, + MonitorHandler.EVENT_COLLECTION_PERIOD, + MonitorHandler.HOST_PLUGIN_HEARTBEAT_PERIOD).seconds + while self.should_run: + self.send_telemetry_heartbeat() + self.collect_and_send_events() + self.send_host_plugin_heartbeat() + time.sleep(min_delta) + def add_sysinfo(self, event): + sysinfo_names = [v.name for v in self.sysinfo] + for param in event.parameters: + if param.name in sysinfo_names: + logger.verbose("Remove existing event parameter: [{0}:{1}]", + param.name, + param.value) + event.parameters.remove(param) + event.parameters.extend(self.sysinfo) + + def send_host_plugin_heartbeat(self): + """ + Send a health signal every HOST_PLUGIN_HEARTBEAT_PERIOD. The signal is 'Healthy' when we have been able to + communicate with HostGAPlugin at least once in the last HOST_PLUGIN_HEALTH_PERIOD. 
+ """ + if self.last_host_plugin_heartbeat is None: + self.last_host_plugin_heartbeat = datetime.datetime.utcnow() - MonitorHandler.HOST_PLUGIN_HEARTBEAT_PERIOD + + if datetime.datetime.utcnow() >= (self.last_host_plugin_heartbeat + MonitorHandler.HOST_PLUGIN_HEARTBEAT_PERIOD): try: - event = parse_event(data_str) - self.add_sysinfo(event) - event_list.events.append(event) - except (ValueError, ProtocolError) as e: - logger.warn("Failed to decode event file: {0}", e) - continue + host_plugin = self.protocol.client.get_host_plugin() + host_plugin.ensure_initialized() + is_currently_healthy = host_plugin.get_health() - if len(event_list.events) == 0: - return + if is_currently_healthy: + self.host_plugin_errorstate.reset() + else: + self.host_plugin_errorstate.incr() - try: - protocol = self.protocol_util.get_protocol() - protocol.report_event(event_list) - except ProtocolError as e: - logger.error("{0}", e) + is_healthy = self.host_plugin_errorstate.is_triggered() is False + logger.verbose("HostGAPlugin health: {0}", is_healthy) - def daemon(self): - period = datetime.timedelta(minutes=30) - protocol = self.protocol_util.get_protocol() - last_heartbeat = datetime.datetime.utcnow() - period - - # Create a new identifier on each restart and reset the counter - heartbeat_id = str(uuid.uuid4()).upper() - counter = 0 - while True: - if datetime.datetime.utcnow() >= (last_heartbeat + period): - last_heartbeat = datetime.datetime.utcnow() - incarnation = protocol.get_incarnation() - dropped_packets = self.osutil.get_firewall_dropped_packets( - protocol.endpoint) - - msg = "{0};{1};{2};{3}".format( - incarnation, counter, heartbeat_id, dropped_packets) + self.health_service.report_host_plugin_heartbeat(is_healthy) + + except Exception as e: + msg = "Exception sending host plugin heartbeat: {0}".format(ustr(e)) + add_event( + name=AGENT_NAME, + version=CURRENT_VERSION, + op=WALAEventOperation.HostPluginHeartbeat, + is_success=False, + message=msg, + log_event=False) + + 
self.last_host_plugin_heartbeat = datetime.datetime.utcnow() + + def send_telemetry_heartbeat(self): + + if self.last_telemetry_heartbeat is None: + self.last_telemetry_heartbeat = datetime.datetime.utcnow() - MonitorHandler.TELEMETRY_HEARTBEAT_PERIOD + + if datetime.datetime.utcnow() >= (self.last_telemetry_heartbeat + MonitorHandler.TELEMETRY_HEARTBEAT_PERIOD): + try: + incarnation = self.protocol.get_incarnation() + dropped_packets = self.osutil.get_firewall_dropped_packets(self.protocol.endpoint) + msg = "{0};{1};{2};{3}".format(incarnation, self.counter, self.heartbeat_id, dropped_packets) add_event( name=AGENT_NAME, @@ -229,21 +314,17 @@ def daemon(self): message=msg, log_event=False) - counter += 1 + self.counter += 1 io_errors = IOErrorCounter.get_and_reset() hostplugin_errors = io_errors.get("hostplugin") protocol_errors = io_errors.get("protocol") other_errors = io_errors.get("other") - if hostplugin_errors > 0 \ - or protocol_errors > 0 \ - or other_errors > 0: - - msg = "hostplugin:{0};protocol:{1};other:{2}"\ - .format(hostplugin_errors, - protocol_errors, - other_errors) + if hostplugin_errors > 0 or protocol_errors > 0 or other_errors > 0: + msg = "hostplugin:{0};protocol:{1};other:{2}".format(hostplugin_errors, + protocol_errors, + other_errors) add_event( name=AGENT_NAME, version=CURRENT_VERSION, @@ -251,19 +332,7 @@ def daemon(self): is_success=True, message=msg, log_event=False) - - try: - self.collect_and_send_events() except Exception as e: - logger.warn("Failed to send events: {0}", e) - time.sleep(60) + logger.warn("Failed to send heartbeat: {0}", e) - def add_sysinfo(self, event): - sysinfo_names = [v.name for v in self.sysinfo] - for param in event.parameters: - if param.name in sysinfo_names: - logger.verbose("Remove existing event parameter: [{0}:{1}]", - param.name, - param.value) - event.parameters.remove(param) - event.parameters.extend(self.sysinfo) + self.last_telemetry_heartbeat = datetime.datetime.utcnow() diff --git 
a/azurelinuxagent/ga/update.py b/azurelinuxagent/ga/update.py index ec9329f440..7d0cee06ee 100644 --- a/azurelinuxagent/ga/update.py +++ b/azurelinuxagent/ga/update.py @@ -85,6 +85,7 @@ "ovf-env.xml" ] + def get_update_handler(): return UpdateHandler() @@ -190,7 +191,8 @@ def run_latest(self, child_args=None): version=agent_version, op=WALAEventOperation.Enable, is_success=True, - message=msg) + message=msg, + log_event=False) if ret is None: ret = self.child_process.wait() @@ -540,7 +542,7 @@ def _purge_agents(self): known_versions = [agent.version for agent in self.agents] if CURRENT_VERSION not in known_versions: - logger.info( + logger.verbose( u"Running Agent {0} was not found in the agent manifest - adding to list", CURRENT_VERSION) known_versions.append(CURRENT_VERSION) @@ -874,6 +876,8 @@ def _download(self): def _fetch(self, uri, headers=None, use_proxy=True): package = None try: + is_healthy = True + error_response = '' resp = restutil.http_get(uri, use_proxy=use_proxy, headers=headers) if restutil.request_succeeded(resp): package = resp.read() @@ -882,8 +886,13 @@ def _fetch(self, uri, headers=None, use_proxy=True): asbin=True) logger.verbose(u"Agent {0} downloaded from {1}", self.name, uri) else: - logger.verbose("Fetch was unsuccessful [{0}]", - restutil.read_response_error(resp)) + error_response = restutil.read_response_error(resp) + logger.verbose("Fetch was unsuccessful [{0}]", error_response) + is_healthy = not restutil.request_failed_at_hostplugin(resp) + + if self.host is not None: + self.host.report_fetch_health(uri, is_healthy, source='GuestAgent', response=error_response) + except restutil.HttpError as http_error: if isinstance(http_error, ResourceGoneError): raise diff --git a/tests/ga/test_monitor.py b/tests/ga/test_monitor.py index 59d066db03..5608396211 100644 --- a/tests/ga/test_monitor.py +++ b/tests/ga/test_monitor.py @@ -14,22 +14,26 @@ # # Requires Python 2.6+ and Openssl 1.0+ # +from datetime import timedelta from tests.tools import 
* from azurelinuxagent.ga.monitor import * +@patch('azurelinuxagent.common.event.EventLogger.add_event') +@patch('azurelinuxagent.common.osutil.get_osutil') +@patch('azurelinuxagent.common.protocol.get_protocol_util') +@patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol') +@patch("azurelinuxagent.common.protocol.healthservice.HealthService._report") class TestMonitor(AgentTestCase): - def test_parse_xml_event(self): + def test_parse_xml_event(self, *args): data_str = load_data('ext/event.xml') event = parse_xml_event(data_str) self.assertNotEquals(None, event) self.assertNotEquals(0, event.parameters) self.assertNotEquals(None, event.parameters[0]) - @patch('azurelinuxagent.common.osutil.get_osutil') - @patch('azurelinuxagent.common.protocol.get_protocol_util') - def test_add_sysinfo(self, _, __): + def test_add_sysinfo(self, *args): data_str = load_data('ext/event.xml') event = parse_xml_event(data_str) monitor_handler = get_monitor_handler() @@ -76,3 +80,94 @@ def test_add_sysinfo(self, _, __): counter += 1 self.assertEquals(5, counter) + + @patch("azurelinuxagent.ga.monitor.MonitorHandler.send_telemetry_heartbeat") + @patch("azurelinuxagent.ga.monitor.MonitorHandler.collect_and_send_events") + @patch("azurelinuxagent.ga.monitor.MonitorHandler.send_host_plugin_heartbeat") + def test_heartbeats(self, patch_hostplugin_heartbeat, patch_send_events, patch_telemetry_heartbeat, *args): + monitor_handler = get_monitor_handler() + + self.assertEqual(0, patch_hostplugin_heartbeat.call_count) + self.assertEqual(0, patch_send_events.call_count) + self.assertEqual(0, patch_telemetry_heartbeat.call_count) + + monitor_handler.start() + time.sleep(1) + self.assertTrue(monitor_handler.is_alive()) + + self.assertNotEqual(0, patch_hostplugin_heartbeat.call_count) + self.assertNotEqual(0, patch_send_events.call_count) + self.assertNotEqual(0, patch_telemetry_heartbeat.call_count) + + monitor_handler.stop() + + def test_heartbeat_timings_updates_after_window(self, 
*args): + monitor_handler = get_monitor_handler() + + MonitorHandler.TELEMETRY_HEARTBEAT_PERIOD = timedelta(milliseconds=100) + MonitorHandler.EVENT_COLLECTION_PERIOD = timedelta(milliseconds=100) + MonitorHandler.HOST_PLUGIN_HEARTBEAT_PERIOD = timedelta(milliseconds=100) + + self.assertEqual(None, monitor_handler.last_host_plugin_heartbeat) + self.assertEqual(None, monitor_handler.last_event_collection) + self.assertEqual(None, monitor_handler.last_telemetry_heartbeat) + + monitor_handler.start() + time.sleep(0.2) + self.assertTrue(monitor_handler.is_alive()) + + self.assertNotEqual(None, monitor_handler.last_host_plugin_heartbeat) + self.assertNotEqual(None, monitor_handler.last_event_collection) + self.assertNotEqual(None, monitor_handler.last_telemetry_heartbeat) + + heartbeat_hostplugin = monitor_handler.last_host_plugin_heartbeat + heartbeat_telemetry = monitor_handler.last_telemetry_heartbeat + events_collection = monitor_handler.last_event_collection + + time.sleep(0.5) + + self.assertNotEqual(heartbeat_hostplugin, monitor_handler.last_host_plugin_heartbeat) + self.assertNotEqual(events_collection, monitor_handler.last_event_collection) + self.assertNotEqual(heartbeat_telemetry, monitor_handler.last_telemetry_heartbeat) + + monitor_handler.stop() + + def test_heartbeat_timings_no_updates_within_window(self, *args): + monitor_handler = get_monitor_handler() + + MonitorHandler.TELEMETRY_HEARTBEAT_PERIOD = timedelta(seconds=1) + MonitorHandler.EVENT_COLLECTION_PERIOD = timedelta(seconds=1) + MonitorHandler.HOST_PLUGIN_HEARTBEAT_PERIOD = timedelta(seconds=1) + + self.assertEqual(None, monitor_handler.last_host_plugin_heartbeat) + self.assertEqual(None, monitor_handler.last_event_collection) + self.assertEqual(None, monitor_handler.last_telemetry_heartbeat) + + monitor_handler.start() + time.sleep(0.2) + self.assertTrue(monitor_handler.is_alive()) + + self.assertNotEqual(None, monitor_handler.last_host_plugin_heartbeat) + self.assertNotEqual(None, 
monitor_handler.last_event_collection) + self.assertNotEqual(None, monitor_handler.last_telemetry_heartbeat) + + heartbeat_hostplugin = monitor_handler.last_host_plugin_heartbeat + heartbeat_telemetry = monitor_handler.last_telemetry_heartbeat + events_collection = monitor_handler.last_event_collection + + time.sleep(0.5) + + self.assertEqual(heartbeat_hostplugin, monitor_handler.last_host_plugin_heartbeat) + self.assertEqual(events_collection, monitor_handler.last_event_collection) + self.assertEqual(heartbeat_telemetry, monitor_handler.last_telemetry_heartbeat) + + monitor_handler.stop() + + @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_heartbeat") + def test_heartbeat_creates_signal(self, patch_report_heartbeat, *args): + monitor_handler = get_monitor_handler() + monitor_handler.init_protocols() + monitor_handler.last_host_plugin_heartbeat = datetime.datetime.utcnow() - timedelta(hours=1) + monitor_handler.send_host_plugin_heartbeat() + self.assertEqual(1, patch_report_heartbeat.call_count) + monitor_handler.stop() diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py index 0a267965f4..d53bd88060 100644 --- a/tests/ga/test_update.py +++ b/tests/ga/test_update.py @@ -514,7 +514,8 @@ def test_download_fail(self, mock_http_get, mock_loaded, mock_downloaded): @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") @patch("azurelinuxagent.ga.update.restutil.http_get") - def test_download_fallback(self, mock_http_get, mock_loaded, mock_downloaded): + @patch("azurelinuxagent.ga.update.restutil.http_post") + def test_download_fallback(self, mock_http_post, mock_http_get, mock_loaded, mock_downloaded): self.remove_agents() self.assertFalse(os.path.isdir(self.agent_path)) diff --git a/tests/protocol/test_healthservice.py b/tests/protocol/test_healthservice.py new file mode 100644 index 0000000000..e9646c01a0 --- /dev/null +++ 
b/tests/protocol/test_healthservice.py @@ -0,0 +1,213 @@ +# Copyright 2018 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.6+ and Openssl 1.0+ +import json + +from azurelinuxagent.common.exception import HttpError +from azurelinuxagent.common.protocol.healthservice import Observation, HealthService +from azurelinuxagent.common.utils import restutil +from tests.protocol.test_hostplugin import MockResponse +from tests.tools import * + + +class TestHealthService(AgentTestCase): + + def assert_status_code(self, status_code, expected_healthy): + response = MockResponse('response', status_code) + is_healthy = not restutil.request_failed_at_hostplugin(response) + self.assertEqual(expected_healthy, is_healthy) + + def assert_observation(self, call_args, name, is_healthy, value, description): + endpoint = call_args[0][0] + content = call_args[0][1] + + jo = json.loads(content) + api = jo['Api'] + source = jo['Source'] + version = jo['Version'] + obs = jo['Observations'] + fo = obs[0] + obs_name = fo['ObservationName'] + obs_healthy = fo['IsHealthy'] + obs_value = fo['Value'] + obs_description = fo['Description'] + + self.assertEqual('application/json', call_args[1]['headers']['Content-Type']) + self.assertEqual('http://endpoint:80/HealthService', endpoint) + self.assertEqual('reporttargethealth', api) + self.assertEqual('WALinuxAgent', source) + self.assertEqual('1.0', version) + + self.assertEqual(name, obs_name) + 
self.assertEqual(value, obs_value) + self.assertEqual(is_healthy, obs_healthy) + self.assertEqual(description, obs_description) + + def test_observation_validity(self): + try: + Observation(name=None, is_healthy=True) + self.fail('Empty observation name should raise ValueError') + except ValueError: + pass + + try: + Observation(name='Name', is_healthy=None) + self.fail('Empty measurement should raise ValueError') + except ValueError: + pass + + o = Observation(name='Name', is_healthy=True, value=None, description=None) + self.assertEqual('', o.value) + self.assertEqual('', o.description) + + long_str = 's' * 200 + o = Observation(name=long_str, is_healthy=True, value=long_str, description=long_str) + self.assertEqual(200, len(o.name)) + self.assertEqual(200, len(o.value)) + self.assertEqual(200, len(o.description)) + + self.assertEqual(64, len(o.as_obj['ObservationName'])) + self.assertEqual(128, len(o.as_obj['Value'])) + self.assertEqual(128, len(o.as_obj['Description'])) + + def test_observation_json(self): + health_service = HealthService('endpoint') + health_service.observations.append(Observation(name='name', + is_healthy=True, + value='value', + description='description')) + expected_json = '{"Source": "WALinuxAgent", ' \ + '"Api": "reporttargethealth", ' \ + '"Version": "1.0", ' \ + '"Observations": [{' \ + '"Value": "value", ' \ + '"ObservationName": "name", ' \ + '"Description": "description", ' \ + '"IsHealthy": true' \ + '}]}' + expected = sorted(json.loads(expected_json).items()) + actual = sorted(json.loads(health_service.as_json).items()) + self.assertEqual(expected, actual) + + @patch("azurelinuxagent.common.utils.restutil.http_post") + def test_reporting(self, patch_post): + health_service = HealthService('endpoint') + health_service.report_host_plugin_status(is_healthy=True, response='response') + self.assertEqual(1, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + 
name=HealthService.HOST_PLUGIN_STATUS_OBSERVATION_NAME, + is_healthy=True, + value='response', + description='') + self.assertEqual(0, len(health_service.observations)) + + health_service.report_host_plugin_status(is_healthy=False, response='error') + self.assertEqual(2, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + name=HealthService.HOST_PLUGIN_STATUS_OBSERVATION_NAME, + is_healthy=False, + value='error', + description='') + self.assertEqual(0, len(health_service.observations)) + + health_service.report_host_plugin_extension_artifact(is_healthy=True, source='source', response='response') + self.assertEqual(3, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + name=HealthService.HOST_PLUGIN_ARTIFACT_OBSERVATION_NAME, + is_healthy=True, + value='response', + description='source') + self.assertEqual(0, len(health_service.observations)) + + health_service.report_host_plugin_extension_artifact(is_healthy=False, source='source', response='response') + self.assertEqual(4, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + name=HealthService.HOST_PLUGIN_ARTIFACT_OBSERVATION_NAME, + is_healthy=False, + value='response', + description='source') + self.assertEqual(0, len(health_service.observations)) + + health_service.report_host_plugin_heartbeat(is_healthy=True) + self.assertEqual(5, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + name=HealthService.HOST_PLUGIN_HEARTBEAT_OBSERVATION_NAME, + is_healthy=True, + value='', + description='') + self.assertEqual(0, len(health_service.observations)) + + health_service.report_host_plugin_heartbeat(is_healthy=False) + self.assertEqual(6, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + name=HealthService.HOST_PLUGIN_HEARTBEAT_OBSERVATION_NAME, + is_healthy=False, + value='', + description='') + self.assertEqual(0, len(health_service.observations)) + + 
health_service.report_host_plugin_versions(is_healthy=True, response='response') + self.assertEqual(7, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + name=HealthService.HOST_PLUGIN_VERSIONS_OBSERVATION_NAME, + is_healthy=True, + value='response', + description='') + self.assertEqual(0, len(health_service.observations)) + + health_service.report_host_plugin_versions(is_healthy=False, response='response') + self.assertEqual(8, patch_post.call_count) + self.assert_observation(call_args=patch_post.call_args, + name=HealthService.HOST_PLUGIN_VERSIONS_OBSERVATION_NAME, + is_healthy=False, + value='response', + description='') + self.assertEqual(0, len(health_service.observations)) + + patch_post.side_effect = HttpError() + health_service.report_host_plugin_versions(is_healthy=True, response='') + + self.assertEqual(9, patch_post.call_count) + self.assertEqual(0, len(health_service.observations)) + + def test_observation_length(self): + health_service = HealthService('endpoint') + + # make 100 observations + for i in range(0, 100): + health_service._observe(is_healthy=True, name='{0}'.format(i)) + + # ensure we keep only 10 + self.assertEqual(10, len(health_service.observations)) + + # ensure we keep the most recent 10 + self.assertEqual('90', health_service.observations[0].name) + self.assertEqual('99', health_service.observations[9].name) + + def test_status_codes(self): + # healthy + self.assert_status_code(status_code=200, expected_healthy=True) + self.assert_status_code(status_code=201, expected_healthy=True) + self.assert_status_code(status_code=302, expected_healthy=True) + self.assert_status_code(status_code=416, expected_healthy=True) + self.assert_status_code(status_code=419, expected_healthy=True) + self.assert_status_code(status_code=429, expected_healthy=True) + self.assert_status_code(status_code=502, expected_healthy=True) + + # unhealthy + self.assert_status_code(status_code=500, expected_healthy=False) + 
self.assert_status_code(status_code=501, expected_healthy=False) + self.assert_status_code(status_code=503, expected_healthy=False) + self.assert_status_code(status_code=504, expected_healthy=False) diff --git a/tests/protocol/test_hostplugin.py b/tests/protocol/test_hostplugin.py index 32729d542c..ad4a308b4d 100644 --- a/tests/protocol/test_hostplugin.py +++ b/tests/protocol/test_hostplugin.py @@ -18,29 +18,31 @@ import base64 import json import sys - -from azurelinuxagent.common.future import ustr - -if sys.version_info[0] == 3: - import http.client as httpclient - bytebuffer = memoryview -elif sys.version_info[0] == 2: - import httplib as httpclient - bytebuffer = buffer +import datetime import azurelinuxagent.common.protocol.restapi as restapi import azurelinuxagent.common.protocol.wire as wire import azurelinuxagent.common.protocol.hostplugin as hostplugin +from azurelinuxagent.common.errorstate import ErrorState -from azurelinuxagent.common import event -from azurelinuxagent.common.exception import ProtocolError, HttpError +from azurelinuxagent.common.exception import HttpError from azurelinuxagent.common.protocol.hostplugin import API_VERSION from azurelinuxagent.common.utils import restutil - from tests.protocol.mockwiredata import WireProtocolData, DATA_FILE +from tests.protocol.test_wire import MockResponse from tests.tools import * +if sys.version_info[0] == 3: + import http.client as httpclient + bytebuffer = memoryview +elif sys.version_info[0] == 2: + import httplib as httpclient + bytebuffer = buffer + + hostplugin_status_url = "http://168.63.129.16:32526/status" +hostplugin_versions_url = "http://168.63.129.16:32526/versions" +health_service_url = 'http://168.63.129.16:80/HealthService' sas_url = "http://sas_url" wireserver_url = "168.63.129.16" @@ -55,15 +57,30 @@ if PY_VERSION_MAJOR > 2: faux_status_b64 = faux_status_b64.decode('utf-8') + class TestHostPlugin(AgentTestCase): + def _init_host(self): + test_goal_state = 
wire.GoalState(WireProtocolData(DATA_FILE).goal_state) + host_plugin = wire.HostPluginProtocol(wireserver_url, + test_goal_state.container_id, + test_goal_state.role_config_name) + self.assertTrue(host_plugin.health_service is not None) + return host_plugin + + def _init_status_blob(self): + wire_protocol_client = wire.WireProtocol(wireserver_url).client + status_blob = wire_protocol_client.status_blob + status_blob.data = faux_status + status_blob.vm_status = restapi.VMStatus(message="Ready", status="Ready") + return status_blob + def _compare_data(self, actual, expected): for k in iter(expected.keys()): if k == 'content' or k == 'requestUri': if actual[k] != expected[k]: - print("Mismatch: Actual '{0}'='{1}', " \ - "Expected '{0}'='{3}'".format( - k, actual[k], expected[k])) + print("Mismatch: Actual '{0}'='{1}', " + "Expected '{0}'='{2}'".format(k, actual[k], expected[k])) return False elif k == 'headers': for h in expected['headers']: @@ -93,7 +110,7 @@ def _hostplugin_data(self, blob_headers, content=None): s = s.decode('utf-8') data['content'] = s return data - + def _hostplugin_headers(self, goal_state): return { 'x-ms-version': '2015-09-01', @@ -101,7 +118,7 @@ def _hostplugin_headers(self, goal_state): 'x-ms-containerid': goal_state.container_id, 'x-ms-host-config-name': goal_state.role_config_name } - + def _validate_hostplugin_args(self, args, goal_state, exp_method, exp_url, exp_data): args, kwargs = args self.assertEqual(exp_method, args[0]) @@ -169,7 +186,8 @@ def test_fallback_failure(self): self.assertEqual(1, patch_upload.call_count) self.assertFalse(wire.HostPluginProtocol.is_default_channel()) - def test_put_status_error_reporting(self): + @patch("azurelinuxagent.common.event.add_event") + def test_put_status_error_reporting(self, patch_add_event): """ Validate the telemetry when uploading status fails """ @@ -186,24 +204,22 @@ def test_put_status_error_reporting(self): wire_protocol_client.ext_conf.status_upload_blob = sas_url 
wire_protocol_client.status_blob.set_vm_status(status) put_error = wire.HttpError("put status http error") - with patch.object(event, - "add_event") as patch_add_event: - with patch.object(restutil, - "http_put", - side_effect=put_error) as patch_http_put: - with patch.object(wire.HostPluginProtocol, - "ensure_initialized", return_value=True): - wire_protocol_client.upload_status_blob() - - # The agent tries to upload via HostPlugin and that fails due to - # http_put having a side effect of "put_error" - # - # The agent tries to upload using a direct connection, and that succeeds. - self.assertEqual(1, wire_protocol_client.status_blob.upload.call_count) - # The agent never touches the default protocol is this code path, so no change. - self.assertFalse(wire.HostPluginProtocol.is_default_channel()) - # The agent never logs a telemetry event for a bad HTTP call - self.assertEqual(patch_add_event.call_count, 0) + with patch.object(restutil, + "http_put", + side_effect=put_error) as patch_http_put: + with patch.object(wire.HostPluginProtocol, + "ensure_initialized", return_value=True): + wire_protocol_client.upload_status_blob() + + # The agent tries to upload via HostPlugin and that fails due to + # http_put having a side effect of "put_error" + # + # The agent tries to upload using a direct connection, and that succeeds. + self.assertEqual(1, wire_protocol_client.status_blob.upload.call_count) + # The agent never touches the default protocol in this code path, so no change. 
+ self.assertFalse(wire.HostPluginProtocol.is_default_channel()) + # The agent never logs a telemetry event for a bad HTTP call + self.assertEqual(patch_add_event.call_count, 0) def test_validate_http_request(self): """Validate correct set of data is sent to HostGAPlugin when reporting VM status""" @@ -218,8 +234,8 @@ def test_validate_http_request(self): exp_method = 'PUT' exp_url = hostplugin_status_url exp_data = self._hostplugin_data( - status_blob.get_block_blob_headers(len(faux_status)), - bytearray(faux_status, encoding='utf-8')) + status_blob.get_block_blob_headers(len(faux_status)), + bytearray(faux_status, encoding='utf-8')) with patch.object(restutil, "http_request") as patch_http: patch_http.return_value = Mock(status=httpclient.OK) @@ -231,12 +247,18 @@ def test_validate_http_request(self): patch_api.return_value = API_VERSION plugin.put_vm_status(status_blob, sas_url, block_blob_type) - self.assertTrue(patch_http.call_count == 1) + self.assertTrue(patch_http.call_count == 2) + + # first call is to host plugin self._validate_hostplugin_args( patch_http.call_args_list[0], test_goal_state, exp_method, exp_url, exp_data) + # second call is to health service + self.assertEqual('POST', patch_http.call_args_list[1][0][0]) + self.assertEqual(health_service_url, patch_http.call_args_list[1][0][1]) + def test_no_fallback(self): """ Validate fallback to upload status using HostGAPlugin is not happening @@ -263,6 +285,7 @@ def test_validate_block_blob(self): test_goal_state.role_config_name) self.assertFalse(host_client.is_initialized) self.assertTrue(host_client.api_versions is None) + self.assertTrue(host_client.health_service is not None) status_blob = wire_protocol_client.status_blob status_blob.data = faux_status @@ -272,23 +295,29 @@ def test_validate_block_blob(self): exp_method = 'PUT' exp_url = hostplugin_status_url exp_data = self._hostplugin_data( - status_blob.get_block_blob_headers(len(faux_status)), - bytearray(faux_status, encoding='utf-8')) + 
status_blob.get_block_blob_headers(len(faux_status)), + bytearray(faux_status, encoding='utf-8')) with patch.object(restutil, "http_request") as patch_http: patch_http.return_value = Mock(status=httpclient.OK) with patch.object(wire.HostPluginProtocol, - "get_api_versions") as patch_get: + "get_api_versions") as patch_get: patch_get.return_value = api_versions host_client.put_vm_status(status_blob, sas_url) - self.assertTrue(patch_http.call_count == 1) + self.assertTrue(patch_http.call_count == 2) + + # first call is to host plugin self._validate_hostplugin_args( patch_http.call_args_list[0], test_goal_state, exp_method, exp_url, exp_data) - + + # second call is to health service + self.assertEqual('POST', patch_http.call_args_list[1][0][0]) + self.assertEqual(health_service_url, patch_http.call_args_list[1][0][1]) + def test_validate_page_blobs(self): """Validate correct set of data is sent for page blobs""" wire_protocol_client = wire.WireProtocol(wireserver_url).client @@ -317,29 +346,35 @@ def test_validate_page_blobs(self): mock_response = MockResponse('', httpclient.OK) with patch.object(restutil, "http_request", - return_value=mock_response) as patch_http: + return_value=mock_response) as patch_http: with patch.object(wire.HostPluginProtocol, - "get_api_versions") as patch_get: + "get_api_versions") as patch_get: patch_get.return_value = api_versions host_client.put_vm_status(status_blob, sas_url) - self.assertTrue(patch_http.call_count == 2) + self.assertTrue(patch_http.call_count == 3) + # first call is to host plugin exp_data = self._hostplugin_data( - status_blob.get_page_blob_create_headers( - page_size)) + status_blob.get_page_blob_create_headers( + page_size)) self._validate_hostplugin_args( patch_http.call_args_list[0], test_goal_state, exp_method, exp_url, exp_data) + # second call is to health service + self.assertEqual('POST', patch_http.call_args_list[1][0][0]) + self.assertEqual(health_service_url, patch_http.call_args_list[1][0][1]) + + # last 
call is to host plugin exp_data = self._hostplugin_data( - status_blob.get_page_blob_page_headers( - 0, page_size), - page) - exp_data['requestUri'] += "?comp=page" + status_blob.get_page_blob_page_headers( + 0, page_size), + page) + exp_data['requestUri'] += "?comp=page" self._validate_hostplugin_args( - patch_http.call_args_list[1], + patch_http.call_args_list[2], test_goal_state, exp_method, exp_url, exp_data) @@ -356,6 +391,7 @@ def test_validate_get_extension_artifacts(self): test_goal_state.role_config_name) self.assertFalse(host_client.is_initialized) self.assertTrue(host_client.api_versions is None) + self.assertTrue(host_client.health_service is not None) with patch.object(wire.HostPluginProtocol, "get_api_versions", return_value=api_versions) as patch_get: actual_url, actual_headers = host_client.get_artifact_request(sas_url) @@ -365,7 +401,230 @@ def test_validate_get_extension_artifacts(self): for k in expected_headers: self.assertTrue(k in actual_headers) self.assertEqual(expected_headers[k], actual_headers[k]) - + + @patch("azurelinuxagent.common.utils.restutil.http_get") + def test_health(self, patch_http_get): + host_plugin = self._init_host() + + patch_http_get.return_value = MockResponse('', 200) + result = host_plugin.get_health() + self.assertEqual(1, patch_http_get.call_count) + self.assertTrue(result) + + patch_http_get.return_value = MockResponse('', 500) + result = host_plugin.get_health() + self.assertFalse(result) + + @patch("azurelinuxagent.common.utils.restutil.http_get") + @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_versions") + def test_ensure_health_service_called(self, patch_http_get, patch_report_versions): + host_plugin = self._init_host() + + host_plugin.get_api_versions() + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(1, patch_report_versions.call_count) + + @patch("azurelinuxagent.common.utils.restutil.http_get") + 
@patch("azurelinuxagent.common.utils.restutil.http_post") + @patch("azurelinuxagent.common.utils.restutil.http_put") + def test_put_status_healthy_signal(self, patch_http_put, patch_http_post, patch_http_get): + host_plugin = self._init_host() + status_blob = self._init_status_blob() + # get_api_versions + patch_http_get.return_value = MockResponse(api_versions, 200) + # put status blob + patch_http_put.return_value = MockResponse(None, 201) + + host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) + + self.assertEqual(2, patch_http_put.call_count) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args_list[0][0][0]) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args_list[1][0][0]) + + self.assertEqual(2, patch_http_post.call_count) + + # signal for /versions + self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) + jstr = patch_http_post.call_args_list[0][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) + + # signal for /status + self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) + jstr = patch_http_post.call_args_list[1][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) + + @patch("azurelinuxagent.common.utils.restutil.http_get") + @patch("azurelinuxagent.common.utils.restutil.http_post") + @patch("azurelinuxagent.common.utils.restutil.http_put") + def test_put_status_unhealthy_signal_transient(self, patch_http_put, patch_http_post, patch_http_get): + host_plugin = self._init_host() + status_blob = 
self._init_status_blob() + # get_api_versions + patch_http_get.return_value = MockResponse(api_versions, 200) + # put status blob + patch_http_put.return_value = MockResponse(None, 500) + + if sys.version_info < (2, 7): + self.assertRaises(HttpError, host_plugin.put_vm_status, status_blob, sas_url) + else: + with self.assertRaises(HttpError): + host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) + + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) + + self.assertEqual(1, patch_http_put.call_count) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0]) + + self.assertEqual(2, patch_http_post.call_count) + + # signal for /versions + self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) + jstr = patch_http_post.call_args_list[0][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) + + # signal for /status + self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) + jstr = patch_http_post.call_args_list[1][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) + + @patch("azurelinuxagent.common.utils.restutil.http_get") + @patch("azurelinuxagent.common.utils.restutil.http_post") + @patch("azurelinuxagent.common.utils.restutil.http_put") + def test_put_status_unhealthy_signal_permanent(self, patch_http_put, patch_http_post, patch_http_get): + host_plugin = self._init_host() + status_blob = self._init_status_blob() + # get_api_versions + patch_http_get.return_value = MockResponse(api_versions, 200) + # put status blob + patch_http_put.return_value = MockResponse(None, 500) + + 
host_plugin.status_error_state.is_triggered = Mock(return_value=True) + + if sys.version_info < (2, 7): + self.assertRaises(HttpError, host_plugin.put_vm_status, status_blob, sas_url) + else: + with self.assertRaises(HttpError): + host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url) + + self.assertEqual(1, patch_http_get.call_count) + self.assertEqual(hostplugin_versions_url, patch_http_get.call_args[0][0]) + + self.assertEqual(1, patch_http_put.call_count) + self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0]) + + self.assertEqual(2, patch_http_post.call_count) + + # signal for /versions + self.assertEqual(health_service_url, patch_http_post.call_args_list[0][0][0]) + jstr = patch_http_post.call_args_list[0][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertTrue(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginVersions', obj['Observations'][0]['ObservationName']) + + # signal for /status + self.assertEqual(health_service_url, patch_http_post.call_args_list[1][0][0]) + jstr = patch_http_post.call_args_list[1][0][1] + obj = json.loads(jstr) + self.assertEqual(1, len(obj['Observations'])) + self.assertFalse(obj['Observations'][0]['IsHealthy']) + self.assertEqual('GuestAgentPluginStatus', obj['Observations'][0]['ObservationName']) + + @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.should_report", return_value=True) + @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_extension_artifact") + def test_report_fetch_health(self, patch_report_artifact, patch_should_report): + host_plugin = self._init_host() + host_plugin.report_fetch_health(uri='', is_healthy=True) + self.assertEqual(0, patch_should_report.call_count) + + host_plugin.report_fetch_health(uri='http://169.254.169.254/extensionArtifact', is_healthy=True) + self.assertEqual(0, patch_should_report.call_count) + + 
host_plugin.report_fetch_health(uri='http://168.63.129.16:32526/status', is_healthy=True) + self.assertEqual(0, patch_should_report.call_count) + + self.assertEqual(None, host_plugin.fetch_last_timestamp) + host_plugin.report_fetch_health(uri='http://168.63.129.16:32526/extensionArtifact', is_healthy=True) + self.assertNotEqual(None, host_plugin.fetch_last_timestamp) + self.assertEqual(1, patch_should_report.call_count) + self.assertEqual(1, patch_report_artifact.call_count) + + @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.should_report", return_value=True) + @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_status") + def test_report_status_health(self, patch_report_status, patch_should_report): + host_plugin = self._init_host() + self.assertEqual(None, host_plugin.status_last_timestamp) + host_plugin.report_status_health(is_healthy=True) + self.assertNotEqual(None, host_plugin.status_last_timestamp) + self.assertEqual(1, patch_should_report.call_count) + self.assertEqual(1, patch_report_status.call_count) + + def test_should_report(self): + host_plugin = self._init_host() + error_state = ErrorState(min_timedelta=datetime.timedelta(minutes=5)) + period = datetime.timedelta(minutes=1) + last_timestamp = None + + # first measurement at 0s, should report + is_healthy = True + actual = host_plugin.should_report(is_healthy, + error_state, + last_timestamp, + period) + self.assertEqual(True, actual) + + # second measurement at 30s, should not report + last_timestamp = datetime.datetime.utcnow() - datetime.timedelta(seconds=30) + actual = host_plugin.should_report(is_healthy, + error_state, + last_timestamp, + period) + self.assertEqual(False, actual) + + # third measurement at 60s, should report + last_timestamp = datetime.datetime.utcnow() - datetime.timedelta(seconds=60) + actual = host_plugin.should_report(is_healthy, + error_state, + last_timestamp, + period) + self.assertEqual(True, actual) + + # fourth 
measurement unhealthy, should report and increment counter + is_healthy = False + self.assertEqual(0, error_state.count) + actual = host_plugin.should_report(is_healthy, + error_state, + last_timestamp, + period) + self.assertEqual(1, error_state.count) + self.assertEqual(True, actual) + + # fifth measurement, should not report and reset counter + is_healthy = True + last_timestamp = datetime.datetime.utcnow() - datetime.timedelta(seconds=30) + self.assertEqual(1, error_state.count) + actual = host_plugin.should_report(is_healthy, + error_state, + last_timestamp, + period) + self.assertEqual(0, error_state.count) + self.assertEqual(False, actual) + class MockResponse: def __init__(self, body, status_code): @@ -373,7 +632,8 @@ def __init__(self, body, status_code): self.status = status_code def read(self): - return self.body + return self.body if sys.version_info[0] == 2 else bytes(self.body, encoding='utf-8') + if __name__ == '__main__': unittest.main() diff --git a/tests/protocol/test_wire.py b/tests/protocol/test_wire.py index 4a69d105c2..91de36206d 100644 --- a/tests/protocol/test_wire.py +++ b/tests/protocol/test_wire.py @@ -28,13 +28,14 @@ @patch("time.sleep") @patch("azurelinuxagent.common.protocol.wire.CryptUtil") +@patch("azurelinuxagent.common.protocol.healthservice.HealthService._report") class TestWireProtocol(AgentTestCase): def setUp(self): super(TestWireProtocol, self).setUp() HostPluginProtocol.set_default_channel(False) - def _test_getters(self, test_data, MockCryptUtil, _): + def _test_getters(self, test_data, __, MockCryptUtil, _): MockCryptUtil.side_effect = test_data.mock_crypt_util with patch.object(restutil, 'http_get', test_data.mock_http_get): @@ -79,7 +80,8 @@ def test_getters_ext_no_public(self, *args): test_data = WireProtocolData(DATA_FILE_EXT_NO_PUBLIC) self._test_getters(test_data, *args) - def test_getters_with_stale_goal_state(self, *args): + 
@patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_extension_artifact") + def test_getters_with_stale_goal_state(self, patch_report, *args): test_data = WireProtocolData(DATA_FILE) test_data.emulate_stale_goal_state = True @@ -92,10 +94,9 @@ def test_getters_with_stale_goal_state(self, *args): # fetched often; however, the dependent documents, such as the # HostingEnvironmentConfig, will be retrieved the expected number self.assertEqual(2, test_data.call_counts["hostingenvuri"]) + self.assertEqual(1, patch_report.call_count) - def test_call_storage_kwargs(self, - mock_cryptutil, - mock_sleep): + def test_call_storage_kwargs(self, *args): from azurelinuxagent.common.utils import restutil with patch.object(restutil, 'http_get') as http_patch: http_req = restutil.http_get @@ -155,29 +156,21 @@ def test_get_host_ga_plugin(self, *args): self.assertEqual(goal_state.role_config_name, host_plugin.role_config_name) self.assertEqual(1, patch_get_goal_state.call_count) - def test_download_ext_handler_pkg_fallback(self, *args): + @patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=IOError) + @patch("azurelinuxagent.common.protocol.wire.WireClient.get_host_plugin") + @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.get_artifact_request") + def test_download_ext_handler_pkg_fallback(self, patch_request, patch_get_host, patch_http, *args): ext_uri = 'extension_uri' host_uri = 'host_uri' - mock_host = HostPluginProtocol(host_uri, 'container_id', 'role_config') - with patch.object(restutil, - "http_request", - side_effect=IOError) as patch_http: - with patch.object(WireClient, - "get_host_plugin", - return_value=mock_host): - with patch.object(HostPluginProtocol, - "get_artifact_request", - return_value=[host_uri, {}]) as patch_request: - - WireProtocol(wireserver_url).download_ext_handler_pkg(ext_uri) + patch_get_host.return_value = HostPluginProtocol(host_uri, 'container_id', 'role_config') + 
patch_request.return_value = [host_uri, {}] - self.assertEqual(patch_http.call_count, 2) - self.assertEqual(patch_request.call_count, 1) + WireProtocol(wireserver_url).download_ext_handler_pkg(ext_uri) - self.assertEqual(patch_http.call_args_list[0][0][1], - ext_uri) - self.assertEqual(patch_http.call_args_list[1][0][1], - host_uri) + self.assertEqual(patch_http.call_count, 2) + self.assertEqual(patch_request.call_count, 1) + self.assertEqual(patch_http.call_args_list[0][0][1], ext_uri) + self.assertEqual(patch_http.call_args_list[1][0][1], host_uri) def test_upload_status_blob_default(self, *args): """ @@ -228,7 +221,8 @@ def test_upload_status_blob_host_ga_plugin(self, *args): patch_http.assert_called_once_with(testurl, wire_protocol_client.status_blob) self.assertFalse(HostPluginProtocol.is_default_channel()) - def test_upload_status_blob_unknown_type_assumes_block(self, *args): + @patch("azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized") + def test_upload_status_blob_unknown_type_assumes_block(self, _, *args): vmstatus = VMStatus(message="Ready", status="Ready") wire_protocol_client = WireProtocol(wireserver_url).client wire_protocol_client.ext_conf = ExtensionsConfig(None) @@ -302,7 +296,6 @@ def test_get_in_vm_artifacts_profile_response_body_not_valid(self, *args): host_plugin_get_artifact_url_and_headers.assert_called_with(testurl) - def test_get_in_vm_artifacts_profile_default(self, *args): wire_protocol_client = WireProtocol(wireserver_url).client wire_protocol_client.ext_conf = ExtensionsConfig(None) @@ -315,8 +308,7 @@ def test_get_in_vm_artifacts_profile_default(self, *args): self.assertEqual(dict(onHold='true'), in_vm_artifacts_profile.__dict__) self.assertTrue(in_vm_artifacts_profile.is_on_hold()) - @patch("time.sleep") - def test_fetch_manifest_fallback(self, patch_sleep, *args): + def test_fetch_manifest_fallback(self, *args): uri1 = ExtHandlerVersionUri() uri1.uri = 'ext_uri' uris = 
DataContractList(ExtHandlerVersionUri) diff --git a/tests/test_import.py b/tests/test_import.py index 39a48abd75..c5fd31c062 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -11,6 +11,7 @@ import azurelinuxagent.ga.monitor as monitor import azurelinuxagent.ga.update as update + class TestImportHandler(AgentTestCase): def test_get_handler(self): osutil.get_osutil() diff --git a/tests/utils/test_rest_util.py b/tests/utils/test_rest_util.py index adeb814185..05911f8229 100644 --- a/tests/utils/test_rest_util.py +++ b/tests/utils/test_rest_util.py @@ -387,16 +387,6 @@ def test_http_request_retries_for_safe_minimum_number_when_throttled(self, _http [call(1) for i in range(restutil.THROTTLE_RETRIES-1)], _sleep.call_args_list) - @patch("time.sleep") - @patch("azurelinuxagent.common.utils.restutil._http_request") - def test_http_request_raises_for_bad_request(self, _http_request, _sleep): - _http_request.side_effect = [ - Mock(status=httpclient.BAD_REQUEST) - ] - - self.assertRaises(ResourceGoneError, restutil.http_get, "https://foo.bar") - self.assertEqual(1, _http_request.call_count) - @patch("time.sleep") @patch("azurelinuxagent.common.utils.restutil._http_request") def test_http_request_raises_for_resource_gone(self, _http_request, _sleep):