diff --git a/azurelinuxagent/ga/exthandlers.py b/azurelinuxagent/ga/exthandlers.py index 0c85e6240f..ac95f1e1e7 100644 --- a/azurelinuxagent/ga/exthandlers.py +++ b/azurelinuxagent/ga/exthandlers.py @@ -75,6 +75,7 @@ AGENT_STATUS_FILE = "waagent_status.json" +NUMBER_OF_DOWNLOAD_RETRIES = 5 def get_traceback(e): if sys.version_info[0] == 3: @@ -450,6 +451,7 @@ def handle_enable(self, ext_handler_i): if handler_state == ExtHandlerState.NotInstalled: ext_handler_i.set_handler_state(ExtHandlerState.NotInstalled) ext_handler_i.download() + ext_handler_i.initialize() ext_handler_i.update_settings() if old_ext_handler_i is None: ext_handler_i.install() @@ -743,54 +745,84 @@ def report_event(self, message="", is_success=True, duration=0, log_event=True): add_event(name=self.ext_handler.name, version=ext_handler_version, message=message, op=self.operation, is_success=is_success, duration=duration, log_event=log_event) + def _download_extension_package(self, source_uri, target_file): + self.logger.info("Downloading extension package: {0}", source_uri) + try: + if not self.protocol.download_ext_handler_pkg(source_uri, target_file): + raise Exception("Failed to download extension package - no error information is available") + except Exception as exception: + self.logger.info("Error downloading extension package: {0}", ustr(exception)) + if os.path.exists(target_file): + os.remove(target_file) + return False + return True + + def _unzip_extension_package(self, source_file, target_directory): + self.logger.info("Unzipping extension package: {0}", source_file) + try: + zipfile.ZipFile(source_file).extractall(target_directory) + except Exception as exception: + logger.info("Error while unzipping extension package: {0}", ustr(exception)) + os.remove(source_file) + if os.path.exists(target_directory): + shutil.rmtree(target_directory) + return False + return True + def download(self): begin_utc = datetime.datetime.utcnow() - self.logger.verbose("Download extension package") self.set_operation(WALAEventOperation.Download) - if self.pkg is None: + if self.pkg is None or self.pkg.uris is None or len(self.pkg.uris) == 0: raise ExtensionDownloadError("No package uri found") - uris_shuffled = self.pkg.uris - random.shuffle(uris_shuffled) - file_downloaded = False + destination = os.path.join(conf.get_lib_dir(), os.path.basename(self.pkg.uris[0].uri) + ".zip") - for uri in uris_shuffled: - try: - destination = os.path.join(conf.get_lib_dir(), os.path.basename(uri.uri) + ".zip") + package_exists = False + if os.path.exists(destination): + self.logger.info("Using existing extension package: {0}", destination) + if self._unzip_extension_package(destination, self.get_base_dir()): + package_exists = True + else: + self.logger.info("The existing extension package is invalid, will ignore it.") - if os.path.exists(destination): - file_downloaded = True - self.pkg_file = destination - break - else: - file_downloaded = self.protocol.download_ext_handler_pkg(uri.uri, destination) + if not package_exists: + downloaded = False + i = 0 + while i < NUMBER_OF_DOWNLOAD_RETRIES: + uris_shuffled = self.pkg.uris + random.shuffle(uris_shuffled) + + for uri in uris_shuffled: + if not self._download_extension_package(uri.uri, destination): + continue - if file_downloaded and os.path.exists(destination): - self.pkg_file = destination + if self._unzip_extension_package(destination, self.get_base_dir()): + downloaded = True break - except Exception as e: - logger.warn("Error while downloading extension: {0}", ustr(e)) + if downloaded: + break - if not file_downloaded: - raise ExtensionDownloadError("Failed to download extension", code=1001) + self.logger.info("Failed to download the extension package from all uris, will retry after a minute") + time.sleep(60) + i += 1 - self.logger.verbose("Unzip extension package") - try: - zipfile.ZipFile(self.pkg_file).extractall(self.get_base_dir()) - except IOError as e: - fileutil.clean_ioerror(e, paths=[self.get_base_dir(), self.pkg_file]) - raise ExtensionError(u"Failed to unzip extension package", e, code=1001) + if not downloaded: + raise ExtensionDownloadError("Failed to download extension", code=1001) - # Add user execute permission to all files under the base dir - for file in fileutil.get_all_files(self.get_base_dir()): - fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR) + self.pkg_file = destination duration = elapsed_milliseconds(begin_utc) self.report_event(message="Download succeeded", duration=duration) + def initialize(self): self.logger.info("Initialize extension directory") + + # Add user execute permission to all files under the base dir + for file in fileutil.get_all_files(self.get_base_dir()): + fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR) + # Save HandlerManifest.json man_file = fileutil.search_file(self.get_base_dir(), 'HandlerManifest.json') diff --git a/tests/ga/test_exthandlers_download_extension.py b/tests/ga/test_exthandlers_download_extension.py new file mode 100644 index 0000000000..725649aee0 --- /dev/null +++ b/tests/ga/test_exthandlers_download_extension.py @@ -0,0 +1,210 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the Apache License. + +import zipfile, time + +from azurelinuxagent.common.protocol.restapi import ExtHandler, ExtHandlerProperties, ExtHandlerPackage, ExtHandlerVersionUri +from azurelinuxagent.common.protocol.wire import WireProtocol +from azurelinuxagent.ga.exthandlers import ExtHandlerInstance, NUMBER_OF_DOWNLOAD_RETRIES +from azurelinuxagent.common.exception import ExtensionDownloadError +from tests.tools import * + +class DownloadExtensionTestCase(AgentTestCase): + """ + Test cases for launch_command + """ + @classmethod + def setUpClass(cls): + AgentTestCase.setUpClass() + cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups") + cls.mock_cgroups.start() + + cls.mock_cgroups_telemetry = patch("azurelinuxagent.ga.exthandlers.CGroupsTelemetry") + cls.mock_cgroups_telemetry.start() + + @classmethod + def tearDownClass(cls): + cls.mock_cgroups_telemetry.stop() + cls.mock_cgroups.stop() + + AgentTestCase.tearDownClass() + + def setUp(self): + AgentTestCase.setUp(self) + + ext_handler_properties = ExtHandlerProperties() + ext_handler_properties.version = "1.0.0" + ext_handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux') + ext_handler.properties = ext_handler_properties + + protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar") + + self.pkg = ExtHandlerPackage() + self.pkg.uris = [ ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri() ] + self.pkg.uris[0].uri = 'https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' + self.pkg.uris[1].uri = 'https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' + self.pkg.uris[2].uri = 'https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' + self.pkg.uris[3].uri = 'https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' + self.pkg.uris[4].uri = 'https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' + + self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=protocol) + self.ext_handler_instance.pkg = self.pkg + + self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0") + self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", return_value=self.extension_dir) + self.mock_get_base_dir.start() + + self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", return_value=self.tmp_dir) + self.mock_get_log_dir.start() + + self.agent_dir = self.tmp_dir + self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir", return_value=self.agent_dir) + self.mock_get_lib_dir.start() + + def tearDown(self): + self.mock_get_lib_dir.stop() + self.mock_get_log_dir.stop() + self.mock_get_base_dir.stop() + + AgentTestCase.tearDown(self) + + _extension_command = "RunCommandLinux.sh" + + @staticmethod + def _create_zip_file(filename): + file = None + try: + file = zipfile.ZipFile(filename, "w") + info = zipfile.ZipInfo(DownloadExtensionTestCase._extension_command) + info.date_time = time.localtime(time.time())[:6] + info.compress_type = zipfile.ZIP_DEFLATED + file.writestr(info, "#!/bin/sh\necho 'RunCommandLinux executed successfully'\n") + finally: + if file is not None: + file.close() + + @staticmethod + def _create_invalid_zip_file(filename): + with open(filename, "w") as file: + file.write("An invalid ZIP file\n") + + def _get_extension_package_file(self): + return os.path.join(self.agent_dir, os.path.basename(self.pkg.uris[0].uri) + ".zip") + + def _get_extension_command_file(self): + return os.path.join(self.extension_dir, DownloadExtensionTestCase._extension_command) + + def _assert_download_and_expand_succeeded(self): + self.assertTrue(os.path.exists(self._get_extension_package_file()), "The extension package was not downloaded to the expected location") + self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location") + + def test_it_should_download_and_expand_extension_package(self): + def download_ext_handler_pkg(_uri, destination): + DownloadExtensionTestCase._create_zip_file(destination) + return True + + with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: + self.ext_handler_instance.download() + + # first download attempt should succeed + mock_download_ext_handler_pkg.assert_called_once() + + self._assert_download_and_expand_succeeded() + + def test_it_should_use_existing_extension_package_when_already_downloaded(self): + DownloadExtensionTestCase._create_zip_file(self._get_extension_package_file()) + + with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg") as mock_download_ext_handler_pkg: + self.ext_handler_instance.download() + + mock_download_ext_handler_pkg.assert_not_called() + + self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location") + + def test_it_should_ignore_existing_extension_package_when_it_is_invalid(self): + def download_ext_handler_pkg(_uri, destination): + DownloadExtensionTestCase._create_zip_file(destination) + return True + + DownloadExtensionTestCase._create_invalid_zip_file(self._get_extension_package_file()) + + with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: + self.ext_handler_instance.download() + + mock_download_ext_handler_pkg.assert_called_once() + + self._assert_download_and_expand_succeeded() + + def test_it_should_use_alternate_uris_when_download_fails(self): + self.download_failures = 0 + + def download_ext_handler_pkg(_uri, destination): + # fail a few times, then succeed + if self.download_failures < 3: + self.download_failures += 1 + return False + DownloadExtensionTestCase._create_zip_file(destination) + return True + + with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: + self.ext_handler_instance.download() + + self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1) + + self._assert_download_and_expand_succeeded() + + def test_it_should_use_alternate_uris_when_download_raises_an_exception(self): + self.download_failures = 0 + + def download_ext_handler_pkg(_uri, destination): + # fail a few times, then succeed + if self.download_failures < 3: + self.download_failures += 1 + raise Exception("Download failed") + DownloadExtensionTestCase._create_zip_file(destination) + return True + + with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: + self.ext_handler_instance.download() + + self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1) + + self._assert_download_and_expand_succeeded() + + def test_it_should_use_alternate_uris_when_it_downloads_an_invalid_package(self): + self.download_failures = 0 + + def download_ext_handler_pkg(_uri, destination): + # fail a few times, then succeed + if self.download_failures < 3: + self.download_failures += 1 + DownloadExtensionTestCase._create_invalid_zip_file(destination) + else: + DownloadExtensionTestCase._create_zip_file(destination) + return True + + with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: + self.ext_handler_instance.download() + + self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1) + + self._assert_download_and_expand_succeeded() + + def test_it_should_raise_an_exception_when_all_downloads_fail(self): + def download_ext_handler_pkg(_uri, _destination): + DownloadExtensionTestCase._create_invalid_zip_file(self._get_extension_package_file()) + return True + + with patch("time.sleep", lambda *_: None): + with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: + with self.assertRaises(ExtensionDownloadError) as context_manager: + self.ext_handler_instance.download() + + self.assertEquals(mock_download_ext_handler_pkg.call_count, NUMBER_OF_DOWNLOAD_RETRIES * len(self.pkg.uris)) + + self.assertRegex(str(context_manager.exception), "Failed to download extension") + self.assertEquals(context_manager.exception.code, 1001) + + self.assertFalse(os.path.exists(self.extension_dir), "The extension directory was not removed") + self.assertFalse(os.path.exists(self._get_extension_package_file()), "The extension package was not removed") +