diff --git a/azurelinuxagent/ga/agent_update.py b/azurelinuxagent/ga/agent_update.py index 3728e57ed8..85c9ce53a6 100644 --- a/azurelinuxagent/ga/agent_update.py +++ b/azurelinuxagent/ga/agent_update.py @@ -20,6 +20,27 @@ def get_agent_update_handler(protocol): return AgentUpdateHandler(protocol) +class AgentUpgradeType(object): + """ + Enum for different modes of Agent Upgrade + """ + Hotfix = "Hotfix" + Normal = "Normal" + + +class AgentUpdateHandlerUpdateState(object): + """ + This class is primarily used to maintain the in-memory persistent state for the agent updates. + This state will be persisted throughout the current service run. + """ + def __init__(self): + self.last_attempted_requested_version_update_time = datetime.datetime.min + self.last_attempted_hotfix_update_time = datetime.datetime.min + self.last_attempted_normal_update_time = datetime.datetime.min + self.last_warning = "" + self.last_warning_time = datetime.datetime.min + + class AgentUpdateHandler(object): def __init__(self, protocol): @@ -27,27 +48,73 @@ def __init__(self, protocol): self._ga_family = conf.get_autoupdate_gafamily() self._autoupdate_enabled = conf.get_autoupdate_enabled() self._gs_id = self._protocol.get_goal_state().extensions_goal_state.id - self._last_attempted_update_time = datetime.datetime.min - self._last_attempted_update_version = FlexibleVersion("0.0.0.0") - self._last_warning = "" - self._last_warning_time = datetime.datetime.min + self._is_requested_version_update = True # This is to track the current update type(requested version or self update) + self.persistent_data = AgentUpdateHandlerUpdateState() def __should_update_agent(self, requested_version): """ - check to see if update is allowed once per (as specified in the conf.get_autoupdate_frequency()) - return false when we don't allow updates. + requested version update: + update is allowed once per (as specified in the conf.get_autoupdate_frequency()) + return false when we don't allow updates. 
+ largest version update(self-update): + update is allowed once per (as specified in the conf.get_hotfix_upgrade_frequency() or conf.get_normal_upgrade_frequency()) + return false when we don't allow updates. """ now = datetime.datetime.now() - if self._last_attempted_update_time != datetime.datetime.min and self._last_attempted_update_version == requested_version: - next_attempt_time = self._last_attempted_update_time + datetime.timedelta(seconds=conf.get_autoupdate_frequency()) + if self._is_requested_version_update: + if self.persistent_data.last_attempted_requested_version_update_time != datetime.datetime.min: + next_attempt_time = self.persistent_data.last_attempted_requested_version_update_time + datetime.timedelta(seconds=conf.get_autoupdate_frequency()) + else: + next_attempt_time = now + + if next_attempt_time > now: + return False + # The time limit elapsed for us to allow updates. + return True else: - next_attempt_time = now + next_hotfix_time, next_normal_time = self.__get_next_upgrade_times(now) + upgrade_type = self.__get_agent_upgrade_type(requested_version) - if next_attempt_time > now: + if next_hotfix_time > now and next_normal_time > now: + return False + + if (upgrade_type == AgentUpgradeType.Hotfix and next_hotfix_time <= now) or ( + upgrade_type == AgentUpgradeType.Normal and next_normal_time <= now): + return True return False - # The time limit elapsed for us to allow updates. - return True + + def __update_last_attempt_update_times(self): + now = datetime.datetime.now() + if self._is_requested_version_update: + self.persistent_data.last_attempted_requested_version_update_time = now + else: + self.persistent_data.last_attempted_normal_update_time = now + self.persistent_data.last_attempted_hotfix_update_time = now + + @staticmethod + def __get_agent_upgrade_type(requested_version): + # We follow semantic versioning for the agent: if Major.Minor.Patch is the same as the current version, then only the fourth (build) component has changed. + # In this case, we consider it as a Hotfix upgrade. 
Else we consider it a Normal upgrade. + if requested_version.major == CURRENT_VERSION.major and requested_version.minor == CURRENT_VERSION.minor and requested_version.patch == CURRENT_VERSION.patch: + return AgentUpgradeType.Hotfix + return AgentUpgradeType.Normal + + def __get_next_upgrade_times(self, now): + """ + Get the next upgrade times + return: Next Hotfix Upgrade Time, Next Normal Upgrade Time + """ + + def get_next_process_time(last_val, frequency): + return now if last_val == datetime.datetime.min else last_val + datetime.timedelta(seconds=frequency) + + next_hotfix_time = get_next_process_time(self.persistent_data.last_attempted_hotfix_update_time, + conf.get_hotfix_upgrade_frequency()) + next_normal_time = get_next_process_time(self.persistent_data.last_attempted_normal_update_time, + conf.get_normal_upgrade_frequency()) + + return next_hotfix_time, next_normal_time def __get_agent_family_from_last_gs(self, goal_state): """ @@ -180,16 +247,17 @@ def __log_event(self, level, msg_, success_=True): msg_ += "[NOTE: Will not log the same error for the next 6 hours]" # Incarnation may change if we get new goal state that would make whole string unique every time. 
So comparing only the substring until Incarnation if Incarnation included in msg # Example msg "Unable to update Agent: No manifest links found for agent family: Prod for incarnation: incarnation_1, skipping agent update" + now = datetime.datetime.now() prefix_msg = msg_.split("incarnation", 1)[0] - prefix_last_warning_msg = self._last_warning.split("incarnation", 1)[0] - if prefix_msg != prefix_last_warning_msg or self._last_warning_time == datetime.datetime.min or datetime.datetime.now() >= self._last_warning_time + datetime.timedelta(hours=6): + prefix_last_warning_msg = self.persistent_data.last_warning.split("incarnation", 1)[0] + if prefix_msg != prefix_last_warning_msg or self.persistent_data.last_warning_time == datetime.datetime.min or now >= self.persistent_data.last_warning_time + datetime.timedelta(hours=6): if level == LogLevel.WARNING: logger.warn(msg_) elif level == LogLevel.ERROR: logger.error(msg_) add_event(op=WALAEventOperation.AgentUpgrade, is_success=success_, message=msg_, log_event=False) - self._last_warning_time = datetime.datetime.now() - self._last_warning = msg_ + self.persistent_data.last_warning_time = now + self.persistent_data.last_warning = msg_ def run(self, goal_state): try: @@ -209,7 +277,9 @@ def run(self, goal_state): GAUpdateReportState.report_error_msg = warn_msg agent_manifest = goal_state.fetch_agent_manifest(agent_family.name, agent_family.uris) requested_version = self.__get_largest_version(agent_manifest) + self._is_requested_version_update = False else: + self._is_requested_version_update = True # Save the requested version to report back GAUpdateReportState.report_expected_version = requested_version # Remove the missing requested version warning once requested version becomes available @@ -241,8 +311,7 @@ def run(self, goal_state): self.__proceed_with_update(requested_version) finally: - self._last_attempted_update_time = datetime.datetime.now() - self._last_attempted_update_version = requested_version + 
self.__update_last_attempt_update_times() except Exception as err: if isinstance(err, AgentUpgradeExitException): diff --git a/tests/data/wire/ga_manifest_no_uris.xml b/tests/data/wire/ga_manifest_no_uris.xml new file mode 100644 index 0000000000..89573ad63b --- /dev/null +++ b/tests/data/wire/ga_manifest_no_uris.xml @@ -0,0 +1,39 @@ + + + + + 1.0.0 + + http://mock-goal-state/ga-manifests/OSTCExtensions.WALinuxAgent__1.0.0 + + + + 1.1.0 + + http://mock-goal-state/ga-manifests/OSTCExtensions.WALinuxAgent__1.1.0 + + + + 1.2.0 + + http://mock-goal-state/ga-manifests/OSTCExtensions.WALinuxAgent__1.2.0 + + + + 2.0.0http://mock-goal-state/ga-manifests/OSTCExtensions.WALinuxAgent__2.0.0 + + + 2.1.0http://mock-goal-state/ga-manifests/OSTCExtensions.WALinuxAgent__2.1.0 + + + 9.9.9.10 + + http://mock-goal-state/ga-manifests/OSTCExtensions.WALinuxAgent__99999.0.0.0 + + + + 99999.0.0.0 + + + + diff --git a/tests/ga/test_agent_update.py b/tests/ga/test_agent_update.py index 73339d7c3c..dbdf8dab5a 100644 --- a/tests/ga/test_agent_update.py +++ b/tests/ga/test_agent_update.py @@ -108,6 +108,41 @@ def test_it_should_update_to_largest_version_if_ga_versioning_disabled(self): self.__assert_agent_directories_exist_and_others_dont_exist(versions=[str(CURRENT_VERSION), "99999.0.0.0"]) self.assertIn("Agent update found, Exiting current process", ustr(context.exception.reason)) + def test_it_should_update_to_largest_version_if_time_window_not_elapsed(self): + self.prepare_agents(count=1) + + data_file = DATA_FILE.copy() + data_file["ga_manifest"] = "wire/ga_manifest_no_uris.xml" + with self.__get_agent_update_handler(test_data=data_file) as (agent_update_handler, _): + agent_update_handler.run(agent_update_handler._protocol.get_goal_state()) + self.assertFalse(os.path.exists(self.agent_dir("99999.0.0.0")), + "New agent directory should not be found") + agent_update_handler._protocol.mock_wire_data.set_ga_manifest("wire/ga_manifest.xml") + 
agent_update_handler._protocol.mock_wire_data.set_incarnation(2) + agent_update_handler._protocol.client.update_goal_state() + agent_update_handler.run(agent_update_handler._protocol.get_goal_state()) + self.assertFalse(os.path.exists(self.agent_dir("99999.0.0.0")), + "New agent directory should not be found") + + def test_it_should_update_to_largest_version_if_time_window_elapsed(self): + self.prepare_agents(count=1) + + data_file = DATA_FILE.copy() + data_file["ga_manifest"] = "wire/ga_manifest_no_uris.xml" + with patch("azurelinuxagent.common.conf.get_hotfix_upgrade_frequency", return_value=0.001): + with patch("azurelinuxagent.common.conf.get_normal_upgrade_frequency", return_value=0.001): + with self.__get_agent_update_handler(test_data=data_file) as (agent_update_handler, mock_telemetry): + with self.assertRaises(AgentUpgradeExitException) as context: + agent_update_handler.run(agent_update_handler._protocol.get_goal_state()) + self.assertFalse(os.path.exists(self.agent_dir("99999.0.0.0")), + "New agent directory should not be found") + agent_update_handler._protocol.mock_wire_data.set_ga_manifest("wire/ga_manifest.xml") + agent_update_handler._protocol.mock_wire_data.set_incarnation(2) + agent_update_handler._protocol.client.update_goal_state() + agent_update_handler.run(agent_update_handler._protocol.get_goal_state()) + self.__assert_agent_requested_version_in_goal_state(mock_telemetry, inc=2, version="99999.0.0.0") + self.__assert_agent_directories_exist_and_others_dont_exist(versions=[str(CURRENT_VERSION), "99999.0.0.0"]) + self.assertIn("Agent update found, Exiting current process", ustr(context.exception.reason)) def test_it_should_not_agent_update_if_last_attempted_update_time_not_elapsed(self): self.prepare_agents(count=1) diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py index 22ec8dbef3..42df9f6472 100644 --- a/tests/ga/test_update.py +++ b/tests/ga/test_update.py @@ -1428,16 +1428,18 @@ def test_run_emits_restart_event(self): class 
TestAgentUpgrade(UpdateTestCase): @contextlib.contextmanager - def create_conf_mocks(self, autoupdate_frequency): + def create_conf_mocks(self, autoupdate_frequency, hotfix_frequency, normal_frequency): # Disabling extension processing to speed up tests as this class deals with testing agent upgrades with patch("azurelinuxagent.common.conf.get_extensions_enabled", return_value=False): with patch("azurelinuxagent.common.conf.get_autoupdate_frequency", return_value=autoupdate_frequency): - with patch("azurelinuxagent.common.conf.get_autoupdate_gafamily", return_value="Prod"): - yield + with patch("azurelinuxagent.common.conf.get_hotfix_upgrade_frequency", return_value=hotfix_frequency): + with patch("azurelinuxagent.common.conf.get_normal_upgrade_frequency", return_value=normal_frequency): + with patch("azurelinuxagent.common.conf.get_autoupdate_gafamily", return_value="Prod"): + yield @contextlib.contextmanager def __get_update_handler(self, iterations=1, test_data=None, - reload_conf=None, autoupdate_frequency=0.001): + reload_conf=None, autoupdate_frequency=0.001, hotfix_frequency=1.0, normal_frequency=2.0): test_data = DATA_FILE if test_data is None else test_data @@ -1463,7 +1465,7 @@ def put_handler(url, *args, **_): return MockHttpResponse(status=201) protocol.set_http_handlers(http_get_handler=get_handler, http_put_handler=put_handler) - with self.create_conf_mocks(autoupdate_frequency): + with self.create_conf_mocks(autoupdate_frequency, hotfix_frequency, normal_frequency): with patch("azurelinuxagent.common.event.EventLogger.add_event") as mock_telemetry: update_handler._protocol = protocol yield update_handler, mock_telemetry @@ -1680,6 +1682,85 @@ def reload_conf(url, protocol): self.__assert_upgrade_telemetry_emitted(mock_telemetry, version="99999.0.0.0") self.__assert_agent_directories_exist_and_others_dont_exist(versions=["99999.0.0.0", str(CURRENT_VERSION)]) + def test_it_should_not_update_largest_version_if_time_window_not_elapsed(self): + 
no_of_iterations = 20 + + # Set the test environment by adding 20 random agents to the agent directory + self.prepare_agents() + self.assertEqual(20, self.agent_count(), "Agent directories not set properly") + + def reload_conf(url, protocol): + mock_wire_data = protocol.mock_wire_data + + # This function reloads the conf mid-run to mimic an actual customer scenario + if HttpRequestPredicates.is_goal_state_request(url) and mock_wire_data.call_counts[ + "goalstate"] >= 5: + reload_conf.call_count += 1 + + self.__assert_agent_directories_available(versions=[str(CURRENT_VERSION)]) + + # Update the ga_manifest and incarnation to send largest version manifest + mock_wire_data.data_files["ga_manifest"] = "wire/ga_manifest.xml" + mock_wire_data.reload() + self._add_write_permission_to_goal_state_files() + reload_conf.incarnation += 1 + mock_wire_data.set_incarnation(reload_conf.incarnation) + + reload_conf.call_count = 0 + reload_conf.incarnation = 2 + + data_file = mockwiredata.DATA_FILE.copy() + # This is to fail the agent update at first attempt so that agent doesn't go through update + data_file["ga_manifest"] = "wire/ga_manifest_no_uris.xml" + with self.__get_update_handler(iterations=no_of_iterations, test_data=data_file, reload_conf=reload_conf, + hotfix_frequency=10, normal_frequency=10) as (update_handler, _): + update_handler._protocol.mock_wire_data.set_incarnation(2) + update_handler.run(debug=True) + + self.assertGreater(reload_conf.call_count, 0, "Reload conf not updated") + self.__assert_exit_code_successful(update_handler) + self.assertFalse(os.path.exists(self.agent_dir("99999.0.0.0")), + "New agent directory should not be found") + + def test_it_should_update_largest_version_if_time_window_elapsed(self): + no_of_iterations = 20 + + # Set the test environment by adding 20 random agents to the agent directory + self.prepare_agents() + self.assertEqual(20, self.agent_count(), "Agent directories not set properly") + + def reload_conf(url, protocol): + 
mock_wire_data = protocol.mock_wire_data + + # This function reloads the conf mid-run to mimic an actual customer scenario + if HttpRequestPredicates.is_goal_state_request(url) and mock_wire_data.call_counts[ + "goalstate"] >= 5: + reload_conf.call_count += 1 + + self.__assert_agent_directories_available(versions=[str(CURRENT_VERSION)]) + + # Update the ga_manifest and incarnation to send largest version manifest + mock_wire_data.data_files["ga_manifest"] = "wire/ga_manifest.xml" + mock_wire_data.reload() + self._add_write_permission_to_goal_state_files() + reload_conf.incarnation += 1 + mock_wire_data.set_incarnation(reload_conf.incarnation) + + reload_conf.call_count = 0 + reload_conf.incarnation = 2 + + data_file = mockwiredata.DATA_FILE.copy() + data_file["ga_manifest"] = "wire/ga_manifest_no_uris.xml" + with self.__get_update_handler(iterations=no_of_iterations, test_data=data_file, reload_conf=reload_conf, + hotfix_frequency=0.001, normal_frequency=0.001) as (update_handler, mock_telemetry): + update_handler._protocol.mock_wire_data.set_incarnation(2) + update_handler.run(debug=True) + + self.assertGreater(reload_conf.call_count, 0, "Reload conf not updated") + self.__assert_exit_code_successful(update_handler) + self.__assert_upgrade_telemetry_emitted(mock_telemetry, version="99999.0.0.0") + self.__assert_agent_directories_exist_and_others_dont_exist(versions=["99999.0.0.0", str(CURRENT_VERSION)]) + def test_it_should_not_download_anything_if_requested_version_is_current_version(self): data_file = mockwiredata.DATA_FILE.copy() data_file["ext_conf"] = "wire/ext_conf_requested_version.xml" diff --git a/tests/protocol/mockwiredata.py b/tests/protocol/mockwiredata.py index 936533e97b..c3beabf566 100644 --- a/tests/protocol/mockwiredata.py +++ b/tests/protocol/mockwiredata.py @@ -460,6 +460,9 @@ def set_manifest_version(self, version): def set_extension_config(self, ext_conf_file): self.ext_conf = load_data(ext_conf_file) + def set_ga_manifest(self, ga_manifest): 
+ self.ga_manifest = load_data(ga_manifest) + def set_extension_config_requested_version(self, version): self.ext_conf = WireProtocolData.replace_xml_element_value(self.ext_conf, "Version", version)