-
Notifications
You must be signed in to change notification settings - Fork 372
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve resilience of extension downloads #1463
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,6 +75,7 @@ | |
|
||
AGENT_STATUS_FILE = "waagent_status.json" | ||
|
||
NUMBER_OF_DOWNLOAD_RETRIES = 5 | ||
|
||
def get_traceback(e): | ||
if sys.version_info[0] == 3: | ||
|
@@ -450,6 +451,7 @@ def handle_enable(self, ext_handler_i): | |
if handler_state == ExtHandlerState.NotInstalled: | ||
ext_handler_i.set_handler_state(ExtHandlerState.NotInstalled) | ||
ext_handler_i.download() | ||
ext_handler_i.initialize() | ||
ext_handler_i.update_settings() | ||
if old_ext_handler_i is None: | ||
ext_handler_i.install() | ||
|
@@ -743,54 +745,84 @@ def report_event(self, message="", is_success=True, duration=0, log_event=True): | |
add_event(name=self.ext_handler.name, version=ext_handler_version, message=message, | ||
op=self.operation, is_success=is_success, duration=duration, log_event=log_event) | ||
|
||
def _download_extension_package(self, source_uri, target_file): | ||
self.logger.info("Downloading extension package: {0}", source_uri) | ||
try: | ||
if not self.protocol.download_ext_handler_pkg(source_uri, target_file): | ||
raise Exception("Failed to download extension package - no error information is available") | ||
except Exception as exception: | ||
self.logger.info("Error downloading extension package: {0}", ustr(exception)) | ||
if os.path.exists(target_file): | ||
os.remove(target_file) | ||
return False | ||
return True | ||
|
||
def _unzip_extension_package(self, source_file, target_directory): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just a clarification - would the corrupt-zip issue arise for the Agent as well, so that we need similar logic there? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yup, and I intend to fix that as well. However, this is a hotfix release, so I do not want to add any other changes to it. Thanks for pointing it out. |
||
self.logger.info("Unzipping extension package: {0}", source_file) | ||
try: | ||
zipfile.ZipFile(source_file).extractall(target_directory) | ||
except Exception as exception: | ||
logger.info("Error while unzipping extension package: {0}", ustr(exception)) | ||
os.remove(source_file) | ||
if os.path.exists(target_directory): | ||
shutil.rmtree(target_directory) | ||
return False | ||
return True | ||
|
||
def download(self): | ||
begin_utc = datetime.datetime.utcnow() | ||
self.logger.verbose("Download extension package") | ||
self.set_operation(WALAEventOperation.Download) | ||
|
||
if self.pkg is None: | ||
if self.pkg is None or self.pkg.uris is None or len(self.pkg.uris) == 0: | ||
raise ExtensionDownloadError("No package uri found") | ||
|
||
uris_shuffled = self.pkg.uris | ||
random.shuffle(uris_shuffled) | ||
file_downloaded = False | ||
destination = os.path.join(conf.get_lib_dir(), os.path.basename(self.pkg.uris[0].uri) + ".zip") | ||
|
||
for uri in uris_shuffled: | ||
try: | ||
destination = os.path.join(conf.get_lib_dir(), os.path.basename(uri.uri) + ".zip") | ||
package_exists = False | ||
if os.path.exists(destination): | ||
self.logger.info("Using existing extension package: {0}", destination) | ||
if self._unzip_extension_package(destination, self.get_base_dir()): | ||
package_exists = True | ||
else: | ||
self.logger.info("The existing extension package is invalid, will ignore it.") | ||
|
||
if os.path.exists(destination): | ||
file_downloaded = True | ||
self.pkg_file = destination | ||
break | ||
else: | ||
file_downloaded = self.protocol.download_ext_handler_pkg(uri.uri, destination) | ||
if not package_exists: | ||
downloaded = False | ||
i = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: a better name for this variable might be "retry_count"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. The loop is short and very simple; I think in this case the longer variable name wouldn't add much. |
||
while i < NUMBER_OF_DOWNLOAD_RETRIES: | ||
uris_shuffled = self.pkg.uris | ||
random.shuffle(uris_shuffled) | ||
|
||
for uri in uris_shuffled: | ||
if not self._download_extension_package(uri.uri, destination): | ||
continue | ||
|
||
if file_downloaded and os.path.exists(destination): | ||
self.pkg_file = destination | ||
if self._unzip_extension_package(destination, self.get_base_dir()): | ||
downloaded = True | ||
break | ||
|
||
except Exception as e: | ||
logger.warn("Error while downloading extension: {0}", ustr(e)) | ||
if downloaded: | ||
break | ||
|
||
if not file_downloaded: | ||
raise ExtensionDownloadError("Failed to download extension", code=1001) | ||
self.logger.info("Failed to download the extension package from all uris, will retry after a minute") | ||
time.sleep(60) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good call! 👍 |
||
i += 1 | ||
|
||
self.logger.verbose("Unzip extension package") | ||
try: | ||
zipfile.ZipFile(self.pkg_file).extractall(self.get_base_dir()) | ||
except IOError as e: | ||
fileutil.clean_ioerror(e, paths=[self.get_base_dir(), self.pkg_file]) | ||
raise ExtensionError(u"Failed to unzip extension package", e, code=1001) | ||
if not downloaded: | ||
raise ExtensionDownloadError("Failed to download extension", code=1001) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll open a task. We need to document these magic numbers (such as 1001). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok |
||
|
||
# Add user execute permission to all files under the base dir | ||
for file in fileutil.get_all_files(self.get_base_dir()): | ||
fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR) | ||
self.pkg_file = destination | ||
|
||
duration = elapsed_milliseconds(begin_utc) | ||
self.report_event(message="Download succeeded", duration=duration) | ||
|
||
def initialize(self): | ||
self.logger.info("Initialize extension directory") | ||
|
||
# Add user execute permission to all files under the base dir | ||
for file in fileutil.get_all_files(self.get_base_dir()): | ||
fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR) | ||
|
||
# Save HandlerManifest.json | ||
man_file = fileutil.search_file(self.get_base_dir(), 'HandlerManifest.json') | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the Apache License. | ||
|
||
import zipfile, time | ||
|
||
from azurelinuxagent.common.protocol.restapi import ExtHandler, ExtHandlerProperties, ExtHandlerPackage, ExtHandlerVersionUri | ||
from azurelinuxagent.common.protocol.wire import WireProtocol | ||
from azurelinuxagent.ga.exthandlers import ExtHandlerInstance, NUMBER_OF_DOWNLOAD_RETRIES | ||
from azurelinuxagent.common.exception import ExtensionDownloadError | ||
from tests.tools import * | ||
|
||
class DownloadExtensionTestCase(AgentTestCase): | ||
""" | ||
Test cases for launch_command | ||
""" | ||
@classmethod | ||
def setUpClass(cls): | ||
AgentTestCase.setUpClass() | ||
cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups") | ||
cls.mock_cgroups.start() | ||
|
||
cls.mock_cgroups_telemetry = patch("azurelinuxagent.ga.exthandlers.CGroupsTelemetry") | ||
cls.mock_cgroups_telemetry.start() | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
cls.mock_cgroups_telemetry.stop() | ||
cls.mock_cgroups.stop() | ||
|
||
AgentTestCase.tearDownClass() | ||
|
||
def setUp(self): | ||
AgentTestCase.setUp(self) | ||
|
||
ext_handler_properties = ExtHandlerProperties() | ||
ext_handler_properties.version = "1.0.0" | ||
ext_handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux') | ||
ext_handler.properties = ext_handler_properties | ||
|
||
protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar") | ||
|
||
self.pkg = ExtHandlerPackage() | ||
self.pkg.uris = [ ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri() ] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ExtHandlerVersionUri is meant to be used to deserialize from xml, etc. The class has other members other than uri. I think it would not be worth it to add all of those parameters to init just for the sake of this test. |
||
self.pkg.uris[0].uri = 'https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' | ||
self.pkg.uris[1].uri = 'https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' | ||
self.pkg.uris[2].uri = 'https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' | ||
self.pkg.uris[3].uri = 'https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' | ||
self.pkg.uris[4].uri = 'https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0' | ||
|
||
self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=protocol) | ||
self.ext_handler_instance.pkg = self.pkg | ||
|
||
self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0") | ||
self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", return_value=self.extension_dir) | ||
self.mock_get_base_dir.start() | ||
|
||
self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", return_value=self.tmp_dir) | ||
self.mock_get_log_dir.start() | ||
|
||
self.agent_dir = self.tmp_dir | ||
self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir", return_value=self.agent_dir) | ||
self.mock_get_lib_dir.start() | ||
|
||
def tearDown(self): | ||
self.mock_get_lib_dir.stop() | ||
self.mock_get_log_dir.stop() | ||
self.mock_get_base_dir.stop() | ||
|
||
AgentTestCase.tearDown(self) | ||
|
||
_extension_command = "RunCommandLinux.sh" | ||
|
||
@staticmethod | ||
def _create_zip_file(filename): | ||
file = None | ||
try: | ||
file = zipfile.ZipFile(filename, "w") | ||
info = zipfile.ZipInfo(DownloadExtensionTestCase._extension_command) | ||
info.date_time = time.localtime(time.time())[:6] | ||
info.compress_type = zipfile.ZIP_DEFLATED | ||
file.writestr(info, "#!/bin/sh\necho 'RunCommandLinux executed successfully'\n") | ||
finally: | ||
if file is not None: | ||
file.close() | ||
|
||
@staticmethod | ||
def _create_invalid_zip_file(filename): | ||
with open(filename, "w") as file: | ||
file.write("An invalid ZIP file\n") | ||
|
||
def _get_extension_package_file(self): | ||
return os.path.join(self.agent_dir, os.path.basename(self.pkg.uris[0].uri) + ".zip") | ||
|
||
def _get_extension_command_file(self): | ||
return os.path.join(self.extension_dir, DownloadExtensionTestCase._extension_command) | ||
|
||
def _assert_download_and_expand_succeeded(self): | ||
self.assertTrue(os.path.exists(self._get_extension_package_file()), "The extension package was not downloaded to the expected location") | ||
self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location") | ||
|
||
def test_it_should_download_and_expand_extension_package(self): | ||
def download_ext_handler_pkg(_uri, destination): | ||
DownloadExtensionTestCase._create_zip_file(destination) | ||
return True | ||
|
||
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: | ||
self.ext_handler_instance.download() | ||
|
||
# first download attempt should succeed | ||
mock_download_ext_handler_pkg.assert_called_once() | ||
|
||
self._assert_download_and_expand_succeeded() | ||
|
||
def test_it_should_use_existing_extension_package_when_already_downloaded(self): | ||
DownloadExtensionTestCase._create_zip_file(self._get_extension_package_file()) | ||
|
||
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg") as mock_download_ext_handler_pkg: | ||
self.ext_handler_instance.download() | ||
|
||
mock_download_ext_handler_pkg.assert_not_called() | ||
|
||
self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location") | ||
|
||
def test_it_should_ignore_existing_extension_package_when_it_is_invalid(self): | ||
def download_ext_handler_pkg(_uri, destination): | ||
DownloadExtensionTestCase._create_zip_file(destination) | ||
return True | ||
|
||
DownloadExtensionTestCase._create_invalid_zip_file(self._get_extension_package_file()) | ||
|
||
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: | ||
self.ext_handler_instance.download() | ||
|
||
mock_download_ext_handler_pkg.assert_called_once() | ||
|
||
self._assert_download_and_expand_succeeded() | ||
|
||
def test_it_should_use_alternate_uris_when_download_fails(self): | ||
self.download_failures = 0 | ||
|
||
def download_ext_handler_pkg(_uri, destination): | ||
if self.download_failures < 3: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use NUMBER_OF_DOWNLOAD_RETRIES - 2 instead of 3? It seems like an arbitrary number otherwise. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or is this related to the number of uris you initialize in the test case set up method? In that case, I would also refer to len(self.pkg.uris). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is an arbitrary number :) - Read it as "fail a few times, then succeed". There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added comment to code |
||
self.download_failures += 1 | ||
return False | ||
DownloadExtensionTestCase._create_zip_file(destination) | ||
return True | ||
|
||
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: | ||
self.ext_handler_instance.download() | ||
|
||
self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1) | ||
|
||
self._assert_download_and_expand_succeeded() | ||
|
||
def test_it_should_use_alternate_uris_when_download_raises_an_exception(self): | ||
self.download_failures = 0 | ||
|
||
def download_ext_handler_pkg(_uri, destination): | ||
if self.download_failures < 3: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. idem |
||
self.download_failures += 1 | ||
raise Exception("Download failed") | ||
DownloadExtensionTestCase._create_zip_file(destination) | ||
return True | ||
|
||
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: | ||
self.ext_handler_instance.download() | ||
|
||
self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1) | ||
|
||
self._assert_download_and_expand_succeeded() | ||
|
||
def test_it_should_use_alternate_uris_when_it_downloads_an_invalid_package(self): | ||
self.download_failures = 0 | ||
|
||
def download_ext_handler_pkg(_uri, destination): | ||
if self.download_failures < 3: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. idem |
||
self.download_failures += 1 | ||
DownloadExtensionTestCase._create_invalid_zip_file(destination) | ||
else: | ||
DownloadExtensionTestCase._create_zip_file(destination) | ||
return True | ||
|
||
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: | ||
self.ext_handler_instance.download() | ||
|
||
self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1) | ||
|
||
self._assert_download_and_expand_succeeded() | ||
|
||
def test_it_should_raise_an_exception_when_all_downloads_fail(self): | ||
def download_ext_handler_pkg(_uri, destination): | ||
return False | ||
|
||
with patch("time.sleep", lambda *_: None): | ||
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg: | ||
with self.assertRaises(ExtensionDownloadError) as context_manager: | ||
self.ext_handler_instance.download() | ||
|
||
self.assertEquals(mock_download_ext_handler_pkg.call_count, NUMBER_OF_DOWNLOAD_RETRIES * len(self.pkg.uris)) | ||
|
||
self.assertRegex(str(context_manager.exception), "Failed to download extension") | ||
self.assertEquals(context_manager.exception.code, 1001) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
fileutil.clean_ioerror
already does this partially. We should change (overload) it into a utility and use that in all of these places (756, 766 and 768). There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That function seems to be meant to handle very specific errors and cannot handle this one in particular - it has a check for the exception type. Since that function is used in many places in the code, I decided to leave it alone for the moment.