Skip to content

Commit

Permalink
Improve resilience of extension downloads (#1463)
Browse files Browse the repository at this point in the history
* Improve resilience of extension downloads

* Fix log messages

* Test improvements
  • Loading branch information
narrieta authored Feb 14, 2019
1 parent 743e0d1 commit 0f7f375
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 29 deletions.
90 changes: 61 additions & 29 deletions azurelinuxagent/ga/exthandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

AGENT_STATUS_FILE = "waagent_status.json"

NUMBER_OF_DOWNLOAD_RETRIES = 5

def get_traceback(e):
if sys.version_info[0] == 3:
Expand Down Expand Up @@ -450,6 +451,7 @@ def handle_enable(self, ext_handler_i):
if handler_state == ExtHandlerState.NotInstalled:
ext_handler_i.set_handler_state(ExtHandlerState.NotInstalled)
ext_handler_i.download()
ext_handler_i.initialize()
ext_handler_i.update_settings()
if old_ext_handler_i is None:
ext_handler_i.install()
Expand Down Expand Up @@ -743,54 +745,84 @@ def report_event(self, message="", is_success=True, duration=0, log_event=True):
add_event(name=self.ext_handler.name, version=ext_handler_version, message=message,
op=self.operation, is_success=is_success, duration=duration, log_event=log_event)

def _download_extension_package(self, source_uri, target_file):
self.logger.info("Downloading extension package: {0}", source_uri)
try:
if not self.protocol.download_ext_handler_pkg(source_uri, target_file):
raise Exception("Failed to download extension package - no error information is available")
except Exception as exception:
self.logger.info("Error downloading extension package: {0}", ustr(exception))
if os.path.exists(target_file):
os.remove(target_file)
return False
return True

def _unzip_extension_package(self, source_file, target_directory):
self.logger.info("Unzipping extension package: {0}", source_file)
try:
zipfile.ZipFile(source_file).extractall(target_directory)
except Exception as exception:
logger.info("Error while unzipping extension package: {0}", ustr(exception))
os.remove(source_file)
if os.path.exists(target_directory):
shutil.rmtree(target_directory)
return False
return True

def download(self):
begin_utc = datetime.datetime.utcnow()
self.logger.verbose("Download extension package")
self.set_operation(WALAEventOperation.Download)

if self.pkg is None:
if self.pkg is None or self.pkg.uris is None or len(self.pkg.uris) == 0:
raise ExtensionDownloadError("No package uri found")

uris_shuffled = self.pkg.uris
random.shuffle(uris_shuffled)
file_downloaded = False
destination = os.path.join(conf.get_lib_dir(), os.path.basename(self.pkg.uris[0].uri) + ".zip")

for uri in uris_shuffled:
try:
destination = os.path.join(conf.get_lib_dir(), os.path.basename(uri.uri) + ".zip")
package_exists = False
if os.path.exists(destination):
self.logger.info("Using existing extension package: {0}", destination)
if self._unzip_extension_package(destination, self.get_base_dir()):
package_exists = True
else:
self.logger.info("The existing extension package is invalid, will ignore it.")

if os.path.exists(destination):
file_downloaded = True
self.pkg_file = destination
break
else:
file_downloaded = self.protocol.download_ext_handler_pkg(uri.uri, destination)
if not package_exists:
downloaded = False
i = 0
while i < NUMBER_OF_DOWNLOAD_RETRIES:
uris_shuffled = self.pkg.uris
random.shuffle(uris_shuffled)

for uri in uris_shuffled:
if not self._download_extension_package(uri.uri, destination):
continue

if file_downloaded and os.path.exists(destination):
self.pkg_file = destination
if self._unzip_extension_package(destination, self.get_base_dir()):
downloaded = True
break

except Exception as e:
logger.warn("Error while downloading extension: {0}", ustr(e))
if downloaded:
break

if not file_downloaded:
raise ExtensionDownloadError("Failed to download extension", code=1001)
self.logger.info("Failed to download the extension package from all uris, will retry after a minute")
time.sleep(60)
i += 1

self.logger.verbose("Unzip extension package")
try:
zipfile.ZipFile(self.pkg_file).extractall(self.get_base_dir())
except IOError as e:
fileutil.clean_ioerror(e, paths=[self.get_base_dir(), self.pkg_file])
raise ExtensionError(u"Failed to unzip extension package", e, code=1001)
if not downloaded:
raise ExtensionDownloadError("Failed to download extension", code=1001)

# Add user execute permission to all files under the base dir
for file in fileutil.get_all_files(self.get_base_dir()):
fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR)
self.pkg_file = destination

duration = elapsed_milliseconds(begin_utc)
self.report_event(message="Download succeeded", duration=duration)

def initialize(self):
self.logger.info("Initialize extension directory")

# Add user execute permission to all files under the base dir
for file in fileutil.get_all_files(self.get_base_dir()):
fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR)

# Save HandlerManifest.json
man_file = fileutil.search_file(self.get_base_dir(), 'HandlerManifest.json')

Expand Down
210 changes: 210 additions & 0 deletions tests/ga/test_exthandlers_download_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the Apache License.

import zipfile, time

from azurelinuxagent.common.protocol.restapi import ExtHandler, ExtHandlerProperties, ExtHandlerPackage, ExtHandlerVersionUri
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.ga.exthandlers import ExtHandlerInstance, NUMBER_OF_DOWNLOAD_RETRIES
from azurelinuxagent.common.exception import ExtensionDownloadError
from tests.tools import *

class DownloadExtensionTestCase(AgentTestCase):
"""
Test cases for launch_command
"""
@classmethod
def setUpClass(cls):
AgentTestCase.setUpClass()
cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups")
cls.mock_cgroups.start()

cls.mock_cgroups_telemetry = patch("azurelinuxagent.ga.exthandlers.CGroupsTelemetry")
cls.mock_cgroups_telemetry.start()

@classmethod
def tearDownClass(cls):
cls.mock_cgroups_telemetry.stop()
cls.mock_cgroups.stop()

AgentTestCase.tearDownClass()

def setUp(self):
AgentTestCase.setUp(self)

ext_handler_properties = ExtHandlerProperties()
ext_handler_properties.version = "1.0.0"
ext_handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
ext_handler.properties = ext_handler_properties

protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

self.pkg = ExtHandlerPackage()
self.pkg.uris = [ ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri() ]
self.pkg.uris[0].uri = 'https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[1].uri = 'https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[2].uri = 'https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[3].uri = 'https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[4].uri = 'https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'

self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=protocol)
self.ext_handler_instance.pkg = self.pkg

self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", return_value=self.extension_dir)
self.mock_get_base_dir.start()

self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", return_value=self.tmp_dir)
self.mock_get_log_dir.start()

self.agent_dir = self.tmp_dir
self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir", return_value=self.agent_dir)
self.mock_get_lib_dir.start()

def tearDown(self):
self.mock_get_lib_dir.stop()
self.mock_get_log_dir.stop()
self.mock_get_base_dir.stop()

AgentTestCase.tearDown(self)

_extension_command = "RunCommandLinux.sh"

@staticmethod
def _create_zip_file(filename):
file = None
try:
file = zipfile.ZipFile(filename, "w")
info = zipfile.ZipInfo(DownloadExtensionTestCase._extension_command)
info.date_time = time.localtime(time.time())[:6]
info.compress_type = zipfile.ZIP_DEFLATED
file.writestr(info, "#!/bin/sh\necho 'RunCommandLinux executed successfully'\n")
finally:
if file is not None:
file.close()

@staticmethod
def _create_invalid_zip_file(filename):
with open(filename, "w") as file:
file.write("An invalid ZIP file\n")

def _get_extension_package_file(self):
return os.path.join(self.agent_dir, os.path.basename(self.pkg.uris[0].uri) + ".zip")

def _get_extension_command_file(self):
return os.path.join(self.extension_dir, DownloadExtensionTestCase._extension_command)

def _assert_download_and_expand_succeeded(self):
self.assertTrue(os.path.exists(self._get_extension_package_file()), "The extension package was not downloaded to the expected location")
self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location")

def test_it_should_download_and_expand_extension_package(self):
def download_ext_handler_pkg(_uri, destination):
DownloadExtensionTestCase._create_zip_file(destination)
return True

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

# first download attempt should succeed
mock_download_ext_handler_pkg.assert_called_once()

self._assert_download_and_expand_succeeded()

def test_it_should_use_existing_extension_package_when_already_downloaded(self):
DownloadExtensionTestCase._create_zip_file(self._get_extension_package_file())

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg") as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

mock_download_ext_handler_pkg.assert_not_called()

self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location")

def test_it_should_ignore_existing_extension_package_when_it_is_invalid(self):
def download_ext_handler_pkg(_uri, destination):
DownloadExtensionTestCase._create_zip_file(destination)
return True

DownloadExtensionTestCase._create_invalid_zip_file(self._get_extension_package_file())

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

mock_download_ext_handler_pkg.assert_called_once()

self._assert_download_and_expand_succeeded()

def test_it_should_use_alternate_uris_when_download_fails(self):
self.download_failures = 0

def download_ext_handler_pkg(_uri, destination):
# fail a few times, then succeed
if self.download_failures < 3:
self.download_failures += 1
return False
DownloadExtensionTestCase._create_zip_file(destination)
return True

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

self._assert_download_and_expand_succeeded()

def test_it_should_use_alternate_uris_when_download_raises_an_exception(self):
self.download_failures = 0

def download_ext_handler_pkg(_uri, destination):
# fail a few times, then succeed
if self.download_failures < 3:
self.download_failures += 1
raise Exception("Download failed")
DownloadExtensionTestCase._create_zip_file(destination)
return True

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

self._assert_download_and_expand_succeeded()

def test_it_should_use_alternate_uris_when_it_downloads_an_invalid_package(self):
self.download_failures = 0

def download_ext_handler_pkg(_uri, destination):
# fail a few times, then succeed
if self.download_failures < 3:
self.download_failures += 1
DownloadExtensionTestCase._create_invalid_zip_file(destination)
else:
DownloadExtensionTestCase._create_zip_file(destination)
return True

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

self._assert_download_and_expand_succeeded()

def test_it_should_raise_an_exception_when_all_downloads_fail(self):
def download_ext_handler_pkg(_uri, _destination):
DownloadExtensionTestCase._create_invalid_zip_file(self._get_extension_package_file())
return True

with patch("time.sleep", lambda *_: None):
with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
with self.assertRaises(ExtensionDownloadError) as context_manager:
self.ext_handler_instance.download()

self.assertEquals(mock_download_ext_handler_pkg.call_count, NUMBER_OF_DOWNLOAD_RETRIES * len(self.pkg.uris))

self.assertRegex(str(context_manager.exception), "Failed to download extension")
self.assertEquals(context_manager.exception.code, 1001)

self.assertFalse(os.path.exists(self.extension_dir), "The extension directory was not removed")
self.assertFalse(os.path.exists(self._get_extension_package_file()), "The extension package was not removed")

0 comments on commit 0f7f375

Please sign in to comment.