Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve resilience of extension downloads #1463

Merged
merged 3 commits into from
Feb 14, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 61 additions & 29 deletions azurelinuxagent/ga/exthandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

AGENT_STATUS_FILE = "waagent_status.json"

NUMBER_OF_DOWNLOAD_RETRIES = 5

def get_traceback(e):
if sys.version_info[0] == 3:
Expand Down Expand Up @@ -450,6 +451,7 @@ def handle_enable(self, ext_handler_i):
if handler_state == ExtHandlerState.NotInstalled:
ext_handler_i.set_handler_state(ExtHandlerState.NotInstalled)
ext_handler_i.download()
ext_handler_i.initialize()
ext_handler_i.update_settings()
if old_ext_handler_i is None:
ext_handler_i.install()
Expand Down Expand Up @@ -743,54 +745,84 @@ def report_event(self, message="", is_success=True, duration=0, log_event=True):
add_event(name=self.ext_handler.name, version=ext_handler_version, message=message,
op=self.operation, is_success=is_success, duration=duration, log_event=log_event)

def _download_extension_package(self, source_uri, target_file):
self.logger.info("Downloading extension package: {0}", source_uri)
try:
if not self.protocol.download_ext_handler_pkg(source_uri, target_file):
raise Exception("Failed to download extension package - no error information is available")
except Exception as exception:
self.logger.info("Error downloading extension package: {0}", ustr(exception))
if os.path.exists(target_file):
os.remove(target_file)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fileutil.clean_ioerror already does this partially. We should change (or overload) it into a general utility and use that in all of these places (756, 766 and 768).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That function seems to be meant to handle very specific errors and cannot handle this one in particular - it has a check for the exception type. Since that function is used in many places in the code, I decided to leave it alone for the moment.

return False
return True

def _unzip_extension_package(self, source_file, target_directory):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a clarification: would the issue of a corrupt zip arise for the Agent as well? If so, do we need similar logic in azurelinuxagent.ga.update.GuestAgent#_unpack?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, and I intend to fix that as well. However, this is a hotfix release, so I do not want to add any other changes to it. Thanks for pointing it out.

self.logger.info("Unzipping extension package: {0}", source_file)
try:
zipfile.ZipFile(source_file).extractall(target_directory)
except Exception as exception:
logger.info("Error while unzipping extension package: {0}", ustr(exception))
os.remove(source_file)
if os.path.exists(target_directory):
shutil.rmtree(target_directory)
return False
return True

def download(self):
begin_utc = datetime.datetime.utcnow()
self.logger.verbose("Download extension package")
self.set_operation(WALAEventOperation.Download)

if self.pkg is None:
if self.pkg is None or self.pkg.uris is None or len(self.pkg.uris) == 0:
raise ExtensionDownloadError("No package uri found")

uris_shuffled = self.pkg.uris
random.shuffle(uris_shuffled)
file_downloaded = False
destination = os.path.join(conf.get_lib_dir(), os.path.basename(self.pkg.uris[0].uri) + ".zip")

for uri in uris_shuffled:
try:
destination = os.path.join(conf.get_lib_dir(), os.path.basename(uri.uri) + ".zip")
package_exists = False
if os.path.exists(destination):
self.logger.info("Using existing extension package: {0}", destination)
if self._unzip_extension_package(destination, self.get_base_dir()):
package_exists = True
else:
self.logger.info("The existing extension package is invalid, will ignore it.")

if os.path.exists(destination):
file_downloaded = True
self.pkg_file = destination
break
else:
file_downloaded = self.protocol.download_ext_handler_pkg(uri.uri, destination)
if not package_exists:
downloaded = False
i = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: a better name for this variable might be "retry_count"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. The loop is short and very simple; I think in this case the longer variable name wouldn't add much.

while i < NUMBER_OF_DOWNLOAD_RETRIES:
uris_shuffled = self.pkg.uris
random.shuffle(uris_shuffled)

for uri in uris_shuffled:
if not self._download_extension_package(uri.uri, destination):
continue

if file_downloaded and os.path.exists(destination):
self.pkg_file = destination
if self._unzip_extension_package(destination, self.get_base_dir()):
downloaded = True
break

except Exception as e:
logger.warn("Error while downloading extension: {0}", ustr(e))
if downloaded:
break

if not file_downloaded:
raise ExtensionDownloadError("Failed to download extension", code=1001)
self.logger.info("Failed to download the extension package from all uris, will retry after a minute")
time.sleep(60)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call! 👍

i += 1

self.logger.verbose("Unzip extension package")
try:
zipfile.ZipFile(self.pkg_file).extractall(self.get_base_dir())
except IOError as e:
fileutil.clean_ioerror(e, paths=[self.get_base_dir(), self.pkg_file])
raise ExtensionError(u"Failed to unzip extension package", e, code=1001)
if not downloaded:
raise ExtensionDownloadError("Failed to download extension", code=1001)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll open a task. We need to document these magic numbers (such as 1001).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok


# Add user execute permission to all files under the base dir
for file in fileutil.get_all_files(self.get_base_dir()):
fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR)
self.pkg_file = destination

duration = elapsed_milliseconds(begin_utc)
self.report_event(message="Download succeeded", duration=duration)

def initialize(self):
self.logger.info("Initialize extension directory")

# Add user execute permission to all files under the base dir
for file in fileutil.get_all_files(self.get_base_dir()):
fileutil.chmod(file, os.stat(file).st_mode | stat.S_IXUSR)

# Save HandlerManifest.json
man_file = fileutil.search_file(self.get_base_dir(), 'HandlerManifest.json')

Expand Down
203 changes: 203 additions & 0 deletions tests/ga/test_exthandlers_download_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the Apache License.

import zipfile, time

from azurelinuxagent.common.protocol.restapi import ExtHandler, ExtHandlerProperties, ExtHandlerPackage, ExtHandlerVersionUri
from azurelinuxagent.common.protocol.wire import WireProtocol
from azurelinuxagent.ga.exthandlers import ExtHandlerInstance, NUMBER_OF_DOWNLOAD_RETRIES
from azurelinuxagent.common.exception import ExtensionDownloadError
from tests.tools import *

class DownloadExtensionTestCase(AgentTestCase):
"""
Test cases for ExtHandlerInstance.download
"""
@classmethod
def setUpClass(cls):
    """Run the base class setup, then patch CGroups and CGroupsTelemetry in
    the exthandlers module for the lifetime of this test class."""
    AgentTestCase.setUpClass()

    # The patchers are stored on the class so they can be stopped later.
    cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups")
    cls.mock_cgroups_telemetry = patch("azurelinuxagent.ga.exthandlers.CGroupsTelemetry")
    for patcher in (cls.mock_cgroups, cls.mock_cgroups_telemetry):
        patcher.start()

@classmethod
def tearDownClass(cls):
    """Stop the class-level patchers (telemetry first, then cgroups) and run
    the base class teardown."""
    for patcher in (cls.mock_cgroups_telemetry, cls.mock_cgroups):
        patcher.stop()

    AgentTestCase.tearDownClass()

def setUp(self):
AgentTestCase.setUp(self)

ext_handler_properties = ExtHandlerProperties()
ext_handler_properties.version = "1.0.0"
ext_handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
ext_handler.properties = ext_handler_properties

protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

self.pkg = ExtHandlerPackage()
self.pkg.uris = [ ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri(), ExtHandlerVersionUri() ]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you change ExtHandlerVersionUri to the following, it would make both the class and the test a little cleaner:

class ExtHandlerVersionUri(DataContract):
    def __init__(self, uri=None):
        self.uri = uri
self.pkg.uris = [ 
ExtHandlerVersionUri('https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'), 
ExtHandlerVersionUri('https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'), 
ExtHandlerVersionUri('https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'), 
ExtHandlerVersionUri('https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'), 
ExtHandlerVersionUri('https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0')
]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExtHandlerVersionUri is meant to be used to deserialize from xml, etc. The class has members other than uri, and I don't think it is worth adding all of those parameters to __init__ just for the sake of this test.

self.pkg.uris[0].uri = 'https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[1].uri = 'https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[2].uri = 'https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[3].uri = 'https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
self.pkg.uris[4].uri = 'https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'

self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=protocol)
self.ext_handler_instance.pkg = self.pkg

self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", return_value=self.extension_dir)
self.mock_get_base_dir.start()

self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", return_value=self.tmp_dir)
self.mock_get_log_dir.start()

self.agent_dir = self.tmp_dir
self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir", return_value=self.agent_dir)
self.mock_get_lib_dir.start()

def tearDown(self):
    """Stop the per-test directory patchers (lib dir, log dir, base dir, in
    that order) and then run the base class teardown."""
    for patcher in (self.mock_get_lib_dir, self.mock_get_log_dir, self.mock_get_base_dir):
        patcher.stop()

    AgentTestCase.tearDown(self)

_extension_command = "RunCommandLinux.sh"

@staticmethod
def _create_zip_file(filename):
    """Create a ZIP archive at ``filename`` holding a single deflated entry,
    named after ``DownloadExtensionTestCase._extension_command``, whose
    content is a small shell script."""
    entry = zipfile.ZipInfo(DownloadExtensionTestCase._extension_command)
    entry.date_time = time.localtime(time.time())[:6]
    entry.compress_type = zipfile.ZIP_DEFLATED

    # The context manager closes the archive even if writing the entry fails.
    with zipfile.ZipFile(filename, "w") as archive:
        archive.writestr(entry, "#!/bin/sh\necho 'RunCommandLinux executed successfully'\n")

@staticmethod
def _create_invalid_zip_file(filename):
    """Write a line of plain text to ``filename`` so that the file is not a
    valid ZIP archive."""
    content = "An invalid ZIP file\n"
    with open(filename, "w") as handle:
        handle.write(content)

def _get_extension_package_file(self):
    """Return the expected package path: the agent directory joined with the
    basename of the first package uri plus a ".zip" suffix."""
    package_name = os.path.basename(self.pkg.uris[0].uri) + ".zip"
    return os.path.join(self.agent_dir, package_name)

def _get_extension_command_file(self):
    """Return the path of the extension command script inside the extension
    directory."""
    command_name = DownloadExtensionTestCase._extension_command
    return os.path.join(self.extension_dir, command_name)

def _assert_download_and_expand_succeeded(self):
    """Assert that the package file exists and that the extension command
    file exists in the extension directory."""
    package_file = self._get_extension_package_file()
    self.assertTrue(os.path.exists(package_file), "The extension package was not downloaded to the expected location")

    command_file = self._get_extension_command_file()
    self.assertTrue(os.path.exists(command_file), "The extension package was not expanded to the expected location")

def test_it_should_download_and_expand_extension_package(self):
    """A download that succeeds on the first attempt should leave the
    package downloaded and expanded."""
    def write_valid_package(_uri, destination):
        DownloadExtensionTestCase._create_zip_file(destination)
        return True

    target = "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg"
    with patch(target, side_effect=write_valid_package) as download_mock:
        self.ext_handler_instance.download()

    # first download attempt should succeed
    download_mock.assert_called_once()

    self._assert_download_and_expand_succeeded()

def test_it_should_use_existing_extension_package_when_already_downloaded(self):
    """When a valid package file is already present, download() should
    expand it without calling the protocol's download method."""
    existing_package = self._get_extension_package_file()
    DownloadExtensionTestCase._create_zip_file(existing_package)

    target = "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg"
    with patch(target) as download_mock:
        self.ext_handler_instance.download()

    download_mock.assert_not_called()

    command_file = self._get_extension_command_file()
    self.assertTrue(os.path.exists(command_file), "The extension package was not expanded to the expected location")

def test_it_should_ignore_existing_extension_package_when_it_is_invalid(self):
    """When the package file already present is not a valid ZIP, download()
    should fetch the package once and then expand it."""
    def write_valid_package(_uri, destination):
        DownloadExtensionTestCase._create_zip_file(destination)
        return True

    # Seed the package location with a file that cannot be unzipped.
    stale_package = self._get_extension_package_file()
    DownloadExtensionTestCase._create_invalid_zip_file(stale_package)

    target = "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg"
    with patch(target, side_effect=write_valid_package) as download_mock:
        self.ext_handler_instance.download()

    download_mock.assert_called_once()

    self._assert_download_and_expand_succeeded()

def test_it_should_use_alternate_uris_when_download_fails(self):
self.download_failures = 0

def download_ext_handler_pkg(_uri, destination):
if self.download_failures < 3:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use NUMBER_OF_DOWNLOAD_RETRIES - 2 instead of 3? It seems like an arbitrary number otherwise.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is this related to the number of uris you initialize in the test case set up method? In that case, I would also refer to len(self.pkg.uris).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is an arbitrary number :) - Read it as "fail a few times, then succeed".

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added comment to code

self.download_failures += 1
return False
DownloadExtensionTestCase._create_zip_file(destination)
return True

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

self._assert_download_and_expand_succeeded()

def test_it_should_use_alternate_uris_when_download_raises_an_exception(self):
self.download_failures = 0

def download_ext_handler_pkg(_uri, destination):
if self.download_failures < 3:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idem

self.download_failures += 1
raise Exception("Download failed")
DownloadExtensionTestCase._create_zip_file(destination)
return True

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

self._assert_download_and_expand_succeeded()

def test_it_should_use_alternate_uris_when_it_downloads_an_invalid_package(self):
self.download_failures = 0

def download_ext_handler_pkg(_uri, destination):
if self.download_failures < 3:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idem

self.download_failures += 1
DownloadExtensionTestCase._create_invalid_zip_file(destination)
else:
DownloadExtensionTestCase._create_zip_file(destination)
return True

with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
self.ext_handler_instance.download()

self.assertEquals(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

self._assert_download_and_expand_succeeded()

def test_it_should_raise_an_exception_when_all_downloads_fail(self):
    """When every download attempt returns False, download() should try each
    uri NUMBER_OF_DOWNLOAD_RETRIES times and then raise
    ExtensionDownloadError with code 1001."""
    def download_ext_handler_pkg(_uri, _destination):
        # Simulate a failed download that reports no error information.
        return False

    # time.sleep is replaced with a no-op so the waits between retry rounds
    # do not slow the test down.
    with patch("time.sleep", lambda *_: None):
        with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
            with self.assertRaises(ExtensionDownloadError) as context_manager:
                self.ext_handler_instance.download()

    # assertEqual is used instead of the deprecated assertEquals alias,
    # which is no longer available in Python 3.12.
    self.assertEqual(mock_download_ext_handler_pkg.call_count, NUMBER_OF_DOWNLOAD_RETRIES * len(self.pkg.uris))

    self.assertRegex(str(context_manager.exception), "Failed to download extension")
    self.assertEqual(context_manager.exception.code, 1001)