diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py index 412f4d709b97..b52d3d4079b1 100644 --- a/cloudinit/sources/DataSourceEc2.py +++ b/cloudinit/sources/DataSourceEc2.py @@ -14,7 +14,7 @@ import time import uuid from contextlib import suppress -from typing import Dict, List +from typing import Dict, List, Literal from cloudinit import dmi, net, sources from cloudinit import url_helper as uhelp @@ -310,7 +310,7 @@ def _maybe_fetch_api_token(self, mdurls): timeout=url_params.timeout_seconds, status_cb=LOG.warning, headers_cb=self._get_headers, - exception_cb=self._imds_exception_cb, + exception_cb=self._token_exception_cb, request_method=request_method, headers_redact=self.imdsv2_token_redact, connect_synchronously=False, @@ -622,25 +622,27 @@ def _refresh_api_token(self, seconds=None): return None return response.contents - def _skip_or_refresh_stale_aws_token_cb(self, msg, exception): + def _skip_or_refresh_stale_aws_token_cb( + self, exception: uhelp.UrlError + ) -> bool: """Callback will not retry on SKIP_USERDATA_CODES or if no token is available.""" - retry = ec2.skip_retry_on_codes( - ec2.SKIP_USERDATA_CODES, msg, exception - ) + retry = ec2.skip_retry_on_codes(ec2.SKIP_USERDATA_CODES, exception) if not retry: return False # False raises exception - return self._refresh_stale_aws_token_cb(msg, exception) + return self._refresh_stale_aws_token_cb(exception) - def _refresh_stale_aws_token_cb(self, msg, exception): + def _refresh_stale_aws_token_cb( + self, exception: uhelp.UrlError + ) -> Literal[True]: """Exception handler for Ec2 to refresh token if token is stale.""" - if isinstance(exception, uhelp.UrlError) and exception.code == 401: + if exception.code == 401: # With _api_token as None, _get_headers will _refresh_api_token. 
LOG.debug("Clearing cached Ec2 API token due to expiry") self._api_token = None return True # always retry - def _imds_exception_cb(self, msg, exception=None): + def _token_exception_cb(self, exception: uhelp.UrlError) -> bool: """Fail quickly on proper AWS if IMDSv2 rejects API token request Guidance from Amazon is that if IMDSv2 had disabled token requests @@ -653,26 +655,23 @@ def _imds_exception_cb(self, msg, exception=None): temporarily unroutable or unavailable will still retry due to the callsite wait_for_url. """ - if isinstance(exception, uhelp.UrlError): + if exception.code: # requests.ConnectionError will have exception.code == None - if exception.code: - if exception.code == 403: - LOG.warning( - "Ec2 IMDS endpoint returned a 403 error. " - "HTTP endpoint is disabled. Aborting." - ) - raise exception - elif exception.code == 503: - LOG.warning( - "Ec2 IMDS endpoint returned a 503 error. " - "HTTP endpoint is overloaded. Retrying." - ) - return - elif exception.code >= 400: - LOG.warning( - "Fatal error while requesting Ec2 IMDSv2 API tokens" - ) - raise exception + if exception.code == 403: + LOG.warning( + "Ec2 IMDS endpoint returned a 403 error. " + "HTTP endpoint is disabled. Aborting." + ) + return False + elif exception.code == 503: + # Let the global handler deal with it + return False + elif exception.code >= 400: + LOG.warning( + "Fatal error while requesting Ec2 IMDSv2 API tokens" + ) + return False + return True def _get_headers(self, url=""): """Return a dict of headers for accessing a url. diff --git a/cloudinit/sources/DataSourceScaleway.py b/cloudinit/sources/DataSourceScaleway.py index 589ef6151413..43b4b8bd6fa3 100644 --- a/cloudinit/sources/DataSourceScaleway.py +++ b/cloudinit/sources/DataSourceScaleway.py @@ -84,7 +84,7 @@ def query_data_api_once(api_address, timeout, requests_session): session=requests_session, # If the error is a HTTP/404 or a ConnectionError, go into raise # block below and don't bother retrying. 
- exception_cb=lambda _, exc: exc.code != 404 + exception_cb=lambda exc: exc.code != 404 and ( not isinstance(exc.cause, requests.exceptions.ConnectionError) ), diff --git a/cloudinit/sources/azure/imds.py b/cloudinit/sources/azure/imds.py index 4f9ec2339355..cb14189e2b89 100644 --- a/cloudinit/sources/azure/imds.py +++ b/cloudinit/sources/azure/imds.py @@ -55,7 +55,7 @@ def __init__( self._request_count = 0 self._last_error: Union[None, Type, int] = None - def exception_callback(self, req_args, exception) -> bool: + def exception_callback(self, exception) -> bool: self._request_count += 1 if not isinstance(exception, UrlError): report_diagnostic_event( diff --git a/cloudinit/sources/helpers/ec2.py b/cloudinit/sources/helpers/ec2.py index a3590a6e4b2d..55c6205721c2 100644 --- a/cloudinit/sources/helpers/ec2.py +++ b/cloudinit/sources/helpers/ec2.py @@ -135,7 +135,7 @@ def _materialize(self, blob, base_url): return joined -def skip_retry_on_codes(status_codes, _request_args, cause): +def skip_retry_on_codes(status_codes, cause): """Returns False if cause.code is in status_codes.""" return cause.code not in status_codes @@ -143,6 +143,7 @@ def skip_retry_on_codes(status_codes, _request_args, cause): def get_instance_userdata( api_version="latest", metadata_address="http://169.254.169.254", + *, ssl_details=None, timeout=5, retries=5, diff --git a/cloudinit/sources/helpers/openstack.py b/cloudinit/sources/helpers/openstack.py index 97ec18faf983..629185995520 100644 --- a/cloudinit/sources/helpers/openstack.py +++ b/cloudinit/sources/helpers/openstack.py @@ -493,7 +493,7 @@ def _fetch_available_versions(self): return self._versions def _path_read(self, path, decode=False): - def should_retry_cb(_request_args, cause): + def should_retry_cb(cause): try: code = int(cause.code) if code >= 400: diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index e8577c93776f..3bf7370d0846 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -22,7 +22,17 @@ 
from http.client import NOT_FOUND from itertools import count from ssl import create_default_context -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Iterator, + List, + MutableMapping, + NamedTuple, + Optional, + Tuple, + Union, +) from urllib.parse import quote, urlparse, urlsplit, urlunparse import requests @@ -33,6 +43,7 @@ LOG = logging.getLogger(__name__) REDACTED = "REDACTED" +ExceptionCallback = Optional[Callable[["UrlError"], bool]] def _cleanurl(url): @@ -334,13 +345,17 @@ class UrlError(IOError): - def __init__(self, cause, code=None, headers=None, url=None): + def __init__( + self, + cause: Any, # This SHOULD be an exception to wrap, but can be anything + code: Optional[int] = None, + headers: Optional[MutableMapping] = None, + url: Optional[str] = None, + ): IOError.__init__(self, str(cause)) self.cause = cause self.code = code - self.headers = headers - if self.headers is None: - self.headers = {} + self.headers: MutableMapping = {} if headers is None else headers self.url = url @@ -362,6 +377,76 @@ def _get_ssl_args(url, ssl_details): return ssl_args +def _get_retry_after(retry_after: str) -> float: + """Parse a Retry-After header value into a float. + + :param retry_after: The value of the Retry-After header. + https://www.rfc-editor.org/rfc/rfc9110.html#section-10.2.3 + https://www.rfc-editor.org/rfc/rfc2616#section-3.3 + :return: The number of seconds to wait before retrying the request. + """ + try: + to_wait = float(retry_after) + except ValueError: + # Translate a date such as "Fri, 31 Dec 1999 23:59:59 GMT" + # into seconds to wait + try: + to_wait = float( + time.mktime( + time.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z") + ) + - time.time() + ) + except ValueError: + LOG.info( + "Failed to parse Retry-After header value: %s. 
" + "Waiting 1 second instead.", + retry_after, + ) + to_wait = 1 + if to_wait < 0: + LOG.info( + "Retry-After header value is in the past. " + "Waiting 1 second instead." + ) + to_wait = 1 + return to_wait + + +def _handle_error( + error: UrlError, + *, + exception_cb: ExceptionCallback = None, +) -> Optional[float]: + """Handle exceptions raised during request processing. + + If we have no exception callback or the callback handled the error or we + got a 503, return with an optional timeout so the request can be retried. + Otherwise, raise the error. + + :param error: The exception raised during the request. + :param response: The response object. + :param exception_cb: Callable to handle the exception. + + :return: Optional time to wait before retrying the request. + """ + if exception_cb and exception_cb(error): + return None + if error.code and error.code == 503: + LOG.warning( + "Ec2 IMDS endpoint returned a 503 error. " + "HTTP endpoint is overloaded. Retrying." + ) + if error.headers: + return _get_retry_after(error.headers.get("Retry-After", "1")) + LOG.info("Unable to introspect response header. Waiting 1 second.") + return 1 + if not exception_cb: + return None + # If exception_cb returned False and there's no 503 + raise error + + def readurl( url, *, @@ -375,7 +460,7 @@ def readurl( ssl_details=None, check_status=True, allow_redirects=True, - exception_cb=None, + exception_cb: ExceptionCallback = None, session=None, infinite=False, log_req_resp=True, @@ -404,8 +489,8 @@ def readurl( occurs. Default: True. :param allow_redirects: Optional boolean passed straight to Session.request as 'allow_redirects'. Default: True. - :param exception_cb: Optional callable which accepts the params - msg and exception and returns a boolean True if retries are permitted. + :param exception_cb: Optional callable to handle exception and returns + True if retries are permitted. :param session: Optional exiting requests.Session instance to reuse. 
:param infinite: Bool, set True to retry indefinitely. Default: False. :param log_req_resp: Set False to turn off verbose debug messages. @@ -472,6 +557,7 @@ def readurl( filtered_req_args[k][key] = REDACTED else: filtered_req_args[k] = v + raised_exception: Exception try: if log_req_resp: LOG.debug( @@ -482,59 +568,57 @@ def readurl( filtered_req_args, ) - r = session.request(**req_args) + response = session.request(**req_args) if check_status: - r.raise_for_status() + response.raise_for_status() LOG.debug( "Read from %s (%s, %sb) after %s attempts", url, - r.status_code, - len(r.content), + response.status_code, + len(response.content), (i + 1), ) # Doesn't seem like we can make it use a different # subclass for responses, so add our own backward-compat # attrs - return UrlResponse(r) + return UrlResponse(response) except exceptions.SSLError as e: # ssl exceptions are not going to get fixed by waiting a # few seconds raise UrlError(e, url=url) from e + except exceptions.HTTPError as e: + url_error = UrlError( + e, + code=e.response.status_code, + headers=e.response.headers, + url=url, + ) + raised_exception = e except exceptions.RequestException as e: - if ( - isinstance(e, (exceptions.HTTPError)) - and hasattr(e, "response") - and hasattr( # This appeared in v 0.10.8 - e.response, "status_code" - ) - ): - url_error = UrlError( - e, - code=e.response.status_code, - headers=e.response.headers, - url=url, - ) - else: - url_error = UrlError(e, url=url) - - if exception_cb and not exception_cb(req_args.copy(), url_error): - # if an exception callback was given, it should return True - # to continue retrying and False to break and re-raise the - # exception - raise url_error from e + url_error = UrlError(e, url=url) + raised_exception = e + response = None + response_sleep_time = _handle_error( + url_error, + exception_cb=exception_cb, + ) + # If our response tells us to wait, then wait even if we're + # past the max tries + if not response_sleep_time: will_retry = 
infinite or (i + 1 < manual_tries) if not will_retry: - raise url_error from e + raise url_error from raised_exception + sleep_time = response_sleep_time or sec_between - if sec_between > 0: - if log_req_resp: - LOG.debug( - "Please wait %s seconds while we wait to try again", - sec_between, - ) - time.sleep(sec_between) + if sec_between > 0: + if log_req_resp: + LOG.debug( + "Please wait %s seconds while we wait to try again", + sec_between, + ) + time.sleep(sleep_time) raise RuntimeError("This path should be unreachable...") @@ -562,7 +646,7 @@ def dual_stack( addresses: List[str], stagger_delay: float = 0.150, timeout: int = 10, -) -> Tuple: +) -> Tuple[Optional[str], Optional[UrlResponse]]: """execute multiple callbacks in parallel Run blocking func against two different addresses staggered with a @@ -640,6 +724,15 @@ def dual_stack( return (returned_address, return_result) +class HandledResponse(NamedTuple): + # Set when we have a response to return + url: Optional[str] + response: Optional[UrlResponse] + + # Possibly set if we need to try again + wait_time: Optional[float] + + def wait_for_url( urls, *, @@ -649,53 +742,42 @@ def wait_for_url( headers_cb: Optional[Callable] = None, headers_redact=None, sleep_time: Optional[float] = None, - exception_cb: Optional[Callable] = None, + exception_cb: ExceptionCallback = None, sleep_time_cb: Optional[Callable[[Any, float], float]] = None, request_method: str = "", connect_synchronously: bool = True, async_delay: float = 0.150, ): - """ - urls: a list of urls to try - max_wait: roughly the maximum time to wait before giving up - The max time is *actually* len(urls)*timeout as each url will - be tried once and given the timeout provided. - a number <= 0 will always result in only one try - timeout: the timeout provided to urlopen - status_cb: call method with string message when a url is not available - headers_cb: call method with single argument of url to get headers - for request. 
- headers_redact: a list of header names to redact from the log - sleep_time: Amount of time to sleep between retries. If this and - sleep_time_cb are None, the default sleep time - defaults to 1 second and increases by 1 seconds every 5 - tries. Cannot be specified along with `sleep_time_cb`. - exception_cb: call method with 2 arguments 'msg' (per status_cb) and - 'exception', the exception that occurred. - sleep_time_cb: call method with 2 arguments (response, loop_n) that - generates the next sleep time. Cannot be specified - along with 'sleep_time`. - request_method: indicate the type of HTTP request, GET, PUT, or POST - connect_synchronously: if false, enables executing requests in parallel - async_delay: delay before parallel metadata requests, see RFC 6555 - returns: tuple of (url, response contents), on failure, (False, None) - - the idea of this routine is to wait for the EC2 metadata service to - come up. On both Eucalyptus and EC2 we have seen the case where - the instance hit the MD before the MD service was up. EC2 seems - to have permanently fixed this, though. - - In openstack, the metadata service might be painfully slow, and - unable to avoid hitting a timeout of even up to 10 seconds or more - (LP: #894279) for a simple GET. - - Offset those needs with the need to not hang forever (and block boot) - on a system where cloud-init is configured to look for EC2 Metadata - service but is not going to find one. It is possible that the instance - data host (169.254.169.254) may be firewalled off Entirely for a system, - meaning that the connection will block forever unless a timeout is set. - - The default value for max_wait will retry indefinitely. + """Wait for a response from one of the urls provided. + + :param urls: List of urls to try + :param max_wait: Roughly the maximum time to wait before giving up + The max time is *actually* len(urls)*timeout as each url will + be tried once and given the timeout provided. 
+ a number <= 0 will always result in only one try + :param timeout: Timeout provided to urlopen + :param status_cb: Callable with string message when a url is not available + :param headers_cb: Callable with single argument of url to get headers + for request. + :param headers_redact: List of header names to redact from the log + :param sleep_time: Amount of time to sleep between retries. If this and + sleep_time_cb are None, the default sleep time defaults to 1 second + and increases by 1 seconds every 5 tries. Cannot be specified along + with `sleep_time_cb`. + :param exception_cb: Callable to handle exception and returns True if + retries are permitted. + :param sleep_time_cb: Callable with 2 arguments (response, loop_n) that + generates the next sleep time. Cannot be specified + along with 'sleep_time`. + :param request_method: Indicates the type of HTTP request: + GET, PUT, or POST + :param connect_synchronously: If false, enables executing requests + in parallel + :param async_delay: Delay before parallel metadata requests, see RFC 6555 + + :return: tuple of (url, response contents), on failure, (False, None) + + :raises: UrlError on unrecoverable error """ def default_sleep_time(_, loop_number: int) -> float: @@ -709,8 +791,28 @@ def timeup(max_wait: float, start_time: float, sleep_time: float = 0): time.monotonic() - start_time + sleep_time > max_wait ) - def handle_url_response(response, url) -> Tuple[Optional[Exception], str]: + def handle_url_response( + response: Optional[UrlResponse], url: Optional[str] + ) -> Tuple[Optional[UrlError], str]: """Map requests response code/contents to internal "UrlError" type""" + reason = "" + url_exc = None + if not (response and url): + reason = "Request timed out" + url_exc = UrlError(ValueError(reason)) + return url_exc, reason + try: + # Do this first because it can provide more context for the + # exception than what comes later + response._response.raise_for_status() + except requests.exceptions.HTTPError as e: 
+ url_exc = UrlError( + e, + code=e.response.status_code, + headers=e.response.headers, + url=url, + ) + return url_exc, str(e) if not response.contents: reason = "empty response [%s]" % (response.code) url_exc = UrlError( @@ -720,6 +822,7 @@ def handle_url_response(response, url) -> Tuple[Optional[Exception], str]: url=url, ) elif not response.ok(): + # 3xx "errors" wouldn't be covered by the raise_for_status above reason = "bad status code [%s]" % (response.code) url_exc = UrlError( ValueError(reason), @@ -727,22 +830,26 @@ def handle_url_response(response, url) -> Tuple[Optional[Exception], str]: headers=response.headers, url=url, ) - else: - reason = "" - url_exc = None return (url_exc, reason) def read_url_handle_exceptions( - url_reader_cb, urls, start_time, exc_cb, log_cb - ) -> Tuple[str, Union[Exception, UrlResponse]]: + url_reader_cb: Callable[ + [Any], Tuple[Optional[str], Optional[UrlResponse]] + ], + urls: Union[str, List[str]], + start_time: int, + exc_cb: ExceptionCallback, + log_cb: Callable, + ) -> HandledResponse: """Execute request, handle response, optionally log exception""" reason = "" - url = "" + url = None + url_exc: Optional[Exception] try: url, response = url_reader_cb(urls) url_exc, reason = handle_url_response(response, url) if not url_exc: - return (url, response) + return HandledResponse(url, response, wait_time=None) except UrlError as e: reason = "request error [%s]" % e url_exc = e @@ -758,12 +865,16 @@ def read_url_handle_exceptions( reason, ) log_cb(status_msg) - if exc_cb: - # This can be used to alter the headers that will be sent - # in the future, for example this is what the MAAS datasource - # does. 
- exc_cb(msg=status_msg, exception=url_exc) - return url, url_exc + + return HandledResponse( + url=None, + response=None, + wait_time=( + _handle_error(url_exc, exception_cb=exc_cb) + if isinstance(url_exc, UrlError) + else None + ), + ) def read_url_cb(url: str, timeout: int) -> UrlResponse: return readurl( @@ -777,7 +888,7 @@ def read_url_cb(url: str, timeout: int) -> UrlResponse: def read_url_serial( start_time, timeout, exc_cb, log_cb - ) -> Optional[Tuple[str, Union[Exception, UrlResponse]]]: + ) -> HandledResponse: """iterate over list of urls, request each one and handle responses and thrown exceptions individually per url """ @@ -785,11 +896,14 @@ def read_url_serial( def url_reader_serial(url: str): return (url, read_url_cb(url, timeout)) + wait_times = [] for url in urls: now = time.monotonic() if loop_n != 0: if timeup(max_wait, start_time): - return None + return HandledResponse( + url=None, response=None, wait_time=None + ) if ( max_wait is not None and timeout @@ -801,13 +915,16 @@ def url_reader_serial(url: str): out = read_url_handle_exceptions( url_reader_serial, url, start_time, exc_cb, log_cb ) - if out: + if out.response: return out - return None + elif out.wait_time: + wait_times.append(out.wait_time) + wait_time = max(wait_times) if wait_times else None + return HandledResponse(url=None, response=None, wait_time=wait_time) def read_url_parallel( start_time, timeout, exc_cb, log_cb - ) -> Optional[Tuple[str, Union[Exception, UrlResponse]]]: + ) -> HandledResponse: """pass list of urls to dual_stack which sends requests in parallel handle response and exceptions of the first endpoint to respond """ @@ -817,11 +934,9 @@ def read_url_parallel( stagger_delay=async_delay, timeout=timeout, ) - out = read_url_handle_exceptions( + return read_url_handle_exceptions( url_reader_parallel, urls, start_time, exc_cb, log_cb ) - if out: - return out start_time = time.monotonic() if sleep_time and sleep_time_cb: @@ -841,14 +956,15 @@ def read_url_parallel( 
loop_n: int = 0 response = None while True: - current_sleep_time = calculate_sleep_time(response, loop_n) - - url = do_read_url(start_time, timeout, exception_cb, status_cb) - if url: - address, response = url - if isinstance(response, UrlResponse): - return (address, response.contents) + resp = do_read_url(start_time, timeout, exception_cb, status_cb) + if resp.response: + return resp.url, resp.response.contents + elif resp.wait_time: + time.sleep(resp.wait_time) + loop_n = loop_n + 1 + continue + current_sleep_time = calculate_sleep_time(response, loop_n) if timeup(max_wait, start_time, current_sleep_time): break diff --git a/tests/unittests/sources/helpers/test_ec2.py b/tests/unittests/sources/helpers/test_ec2.py index 2e6ec4cc2f3a..aa250893a2a6 100644 --- a/tests/unittests/sources/helpers/test_ec2.py +++ b/tests/unittests/sources/helpers/test_ec2.py @@ -276,7 +276,7 @@ def test_metadata_no_security_credentials(self): def test_metadata_children_with_invalid_character(self): def _skip_tags(exception): if isinstance(exception, uh.UrlError) and exception.code == 404: - if "meta-data/tags/" in exception.url: + if exception.url and "meta-data/tags/" in exception.url: print(exception.url) return True return False diff --git a/tests/unittests/sources/test_ec2.py b/tests/unittests/sources/test_ec2.py index b28afc52fe0e..12cbd28fb9bc 100644 --- a/tests/unittests/sources/test_ec2.py +++ b/tests/unittests/sources/test_ec2.py @@ -755,6 +755,33 @@ def test_aws_token_403_fails_without_retries(self, caplog, mocker, tmpdir): for log in expected_logs: assert log in caplog.record_tuples + @responses.activate + def test_aws_token_503_success_after_retries(self, mocker, tmpdir): + """Verify that 503s fetching AWS tokens are retried. + + GH-5577: Cloud-init fails on AWS if IMDSv2 returns a 503 error. 
+ """ + ds = self._setup_ds( + platform_data=self.valid_platform_data, + sys_cfg={ + "datasource": { + "Ec2": { + "strict_id": False, + } + } + }, + md=None, + mocker=mocker, + tmpdir=tmpdir, + ) + + token_url = self.data_url("latest", data_item="api/token") + responses.add(responses.PUT, token_url, status=503) + responses.add(responses.PUT, token_url, status=503) + responses.add(responses.PUT, token_url, status=200, body="response") + assert ds.wait_for_metadata_service() is True + assert 3 == len(responses.calls) + @responses.activate def test_aws_token_redacted(self, caplog, mocker, tmpdir): """Verify that aws tokens are redacted when logged.""" diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 7621c5f6c806..fa1aedf7d851 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -14,7 +14,6 @@ from unittest import mock import pytest -import requests import responses from cloudinit import handlers @@ -742,9 +741,7 @@ def test_include_bad_url_no_fail( responses.add( responses.GET, bad_url, - body=requests.HTTPError( - f"403 Client Error: Forbidden for url: {bad_url}" - ), + body="forbidden", status=403, ) diff --git a/tests/unittests/test_url_helper.py b/tests/unittests/test_url_helper.py index b05f544b13d8..a4e5b669ff16 100644 --- a/tests/unittests/test_url_helper.py +++ b/tests/unittests/test_url_helper.py @@ -17,6 +17,7 @@ REDACTED, UrlError, UrlResponse, + _handle_error, dual_stack, oauth_headers, read_file_or_url, @@ -156,7 +157,7 @@ def test_wb_read_url_defaults_honored_by_read_file_or_url_callers(self): class FakeSessionRaisesHttpError(requests.Session): @classmethod def request(cls, **kwargs): - raise requests.exceptions.HTTPError("broke") + raise requests.exceptions.RequestException("broke") class FakeSession(requests.Session): @classmethod @@ -347,6 +348,63 @@ def request(cls, **kwargs): assert response._response == m_response + def test_error_no_cb(self, mocker): + response = requests.Response() + 
response.status_code = 500 + m_request = mocker.patch("requests.Session.request", autospec=True) + m_request.return_value = response + + with pytest.raises(UrlError) as e: + readurl("http://some/path") + assert e.value.code == 500 + + def test_error_cb_true(self, mocker): + mocker.patch("time.sleep") + + bad_response = requests.Response() + bad_response.status_code = 500 + bad_response._content = b"oh noes!" + good_response = requests.Response() + good_response.status_code = 200 + good_response._content = b"yay" + + m_request = mocker.patch("requests.Session.request", autospec=True) + m_request.side_effect = (bad_response, good_response) + + readurl("http://some/path", retries=1, exception_cb=lambda _: True) + assert m_request.call_count == 2 + + def test_error_cb_false(self, mocker): + mocker.patch("time.sleep") + + bad_response = requests.Response() + bad_response.status_code = 500 + bad_response._content = b"oh noes!" + + m_request = mocker.patch("requests.Session.request", autospec=True) + m_request.return_value = bad_response + + with pytest.raises(UrlError): + readurl( + "http://some/path", retries=1, exception_cb=lambda _: False + ) + assert m_request.call_count == 1 + + def test_exception_503(self, mocker): + mocker.patch("time.sleep") + + retry_response = requests.Response() + retry_response.status_code = 503 + retry_response._content = b"try again" + good_response = requests.Response() + good_response.status_code = 200 + good_response._content = b"good" + m_request = mocker.patch("requests.Session.request", autospec=True) + m_request.side_effect = (retry_response, retry_response, good_response) + + readurl("http://some/path") + assert m_request.call_count == 3 + event = Event() @@ -593,7 +651,7 @@ def identity_of_first_arg(x, _): SLEEP2 = "https://sleep2/" -class TestUrlHelper: +class TestWaitForUrl: success = "SUCCESS" fail = "FAIL" event = Event() @@ -774,3 +832,69 @@ def readurl_side_effect(self, *args, **kwargs): if "timeout" in kwargs: 
self.mock_time_value += kwargs["timeout"] + 0.0000001 raise UrlError("test") + + +class TestHandleError: + def test_handle_error_no_cb(self): + """Test no callback.""" + assert _handle_error(UrlError("test")) is None + + def test_handle_error_cb_false(self): + """Test callback returning False.""" + with pytest.raises(UrlError) as e: + _handle_error(UrlError("test"), exception_cb=lambda _: False) + assert str(e.value) == "test" + + def test_handle_error_cb_true(self): + """Test callback returning True.""" + assert ( + _handle_error(UrlError("test"), exception_cb=lambda _: True) + ) is None + + def test_handle_503(self, caplog): + """Test 503 with no callback.""" + assert _handle_error(UrlError("test", code=503)) == 1 + assert "Unable to introspect response header" in caplog.text + + def test_handle_503_with_retry_header(self): + """Test 503 with a retry integer value.""" + assert ( + _handle_error( + UrlError("test", code=503, headers={"Retry-After": 5}) + ) + == 5 + ) + + def test_handle_503_with_retry_header_in_past(self, caplog): + """Test 503 with date in the past.""" + assert ( + _handle_error( + UrlError( + "test", + code=503, + headers={"Retry-After": "Fri, 31 Dec 1999 23:59:59 GMT"}, + ) + ) + == 1 + ) + assert "Retry-After header value is in the past" in caplog.text + + def test_handle_503_cb_true(self): + """Test 503 with a callback returning True.""" + assert ( + _handle_error( + UrlError("test", code=503), + exception_cb=lambda _: True, + ) + is None + ) + + def test_handle_503_cb_false(self): + """Test 503 with a callback returning False.""" + assert ( + _handle_error( + UrlError("test", code=503), + exception_cb=lambda _: False, + ) + == 1 + )