diff --git a/pkgpanda/exceptions.py b/pkgpanda/exceptions.py index 4f7c951205..78a506f3bd 100644 --- a/pkgpanda/exceptions.py +++ b/pkgpanda/exceptions.py @@ -16,6 +16,22 @@ def __str__(self): return msg +class IncompleteDownloadError(Exception): + + def __init__(self, url, total_bytes_read, content_length): + self.url = url + self.total_bytes_read = total_bytes_read + self.content_length = content_length + + def __str__(self): + msg = "Problem fetching {} - bytes read {} does not match content-length {}".format( + self.url, + self.total_bytes_read, + self.content_length) + + return msg + + class InstallError(Exception): pass diff --git a/pkgpanda/test_util.py b/pkgpanda/test_util.py index 3656389f48..97cdf6c41f 100644 --- a/pkgpanda/test_util.py +++ b/pkgpanda/test_util.py @@ -1,8 +1,11 @@ import os import tempfile +from http.server import BaseHTTPRequestHandler, HTTPServer from subprocess import CalledProcessError +from threading import Thread import pytest +import requests import pkgpanda.util from pkgpanda import UserManagement @@ -412,3 +415,49 @@ def test_write_string(tmpdir): st_mode = os.stat(filename).st_mode expected_permission = 0o777 assert (st_mode & 0o777) == expected_permission + + +class MockDownloadServerRequestHandler(BaseHTTPRequestHandler): + def do_GET(self): # noqa: N802 + body = b'foobar' + + self.send_response(requests.codes.ok) + self.send_header('Content-Length', '6') + self.send_header('Content-Type', 'text/plain') + self.end_headers() + + if self.server.requests_received == 0: + # Don't send the last byte of the response body. + self.wfile.write(body[:len(body) - 1]) + else: + self.wfile.write(body) + self.server.requests_received += 1 + + return + + +class MockHTTPDownloadServer(HTTPServer): + requests_received = 0 + + +def test_stream_remote_file_with_retries(tmpdir): + mock_server = MockHTTPDownloadServer(('localhost', 0), MockDownloadServerRequestHandler) + mock_server_port = mock_server.server_port + + mock_server_thread = Thread( + target=mock_server.serve_forever, + daemon=True) + mock_server_thread.start() + + url = 'http://localhost:{port}/foobar.txt'.format(port=mock_server_port) + + out_file = os.path.join(str(tmpdir), 'foobar.txt') + response = pkgpanda.util._download_remote_file(out_file, url) + + response_is_ok = response.ok + assert response_is_ok + + assert mock_server.requests_received == 2 + + with open(out_file, 'rb') as f: + assert f.read() == b'foobar' diff --git a/pkgpanda/util.py b/pkgpanda/util.py index bd0c77102b..49a411de33 100644 --- a/pkgpanda/util.py +++ b/pkgpanda/util.py @@ -19,13 +19,14 @@ from typing import List import requests +import retrying import teamcity import yaml from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry from teamcity.messages import TeamcityServiceMessages -from pkgpanda.exceptions import FetchError, ValidationError +from pkgpanda.exceptions import FetchError, IncompleteDownloadError, ValidationError is_windows = platform.system() == "Windows" @@ -150,6 +151,33 @@ def get_requests_retry_session(max_retries=4, backoff_factor=1, status_forcelist return session +def _is_incomplete_download_error(exception): + return isinstance(exception, IncompleteDownloadError) + + +@retrying.retry( + stop_max_attempt_number=3, + wait_random_min=1000, + wait_random_max=2000, + retry_on_exception=_is_incomplete_download_error) +def _download_remote_file(out_filename, url, retries=4): + with open(out_filename, "wb") as f: + r = get_requests_retry_session().get(url, stream=True) + r.raise_for_status() + + content_length = int(r.headers['content-length']) + + total_bytes_read = 0 + for chunk in r.iter_content(chunk_size=4096): + f.write(chunk) + total_bytes_read += len(chunk) + + if total_bytes_read != content_length: + raise IncompleteDownloadError(url, total_bytes_read, content_length) + + return r + + def download(out_filename, url, work_dir, rm_on_error=True): assert os.path.isabs(out_filename) assert os.path.isabs(work_dir) @@ -167,14 +195,7 @@ def download(out_filename, url, work_dir, rm_on_error=True): src_filename = work_dir + '/' + src_filename shutil.copyfile(src_filename, out_filename) else: - # Download the file. - with open(out_filename, "w+b") as f: - r = get_requests_retry_session().get(url, stream=True) - if r.status_code == 301: - raise Exception("got a 301") - r.raise_for_status() - for chunk in r.iter_content(chunk_size=4096): - f.write(chunk) + _download_remote_file(out_filename, url) except Exception as fetch_exception: if rm_on_error: rm_passed = False