From e553ebcc37bc32cea48df69afd90e912dc6f8a1f Mon Sep 17 00:00:00 2001 From: Dan King Date: Thu, 11 May 2023 20:56:12 -0400 Subject: [PATCH] [services] reliably retry all requests (#13029) `request_retry_transient_errors` is a charlatan. It does not retry errors that occur in reading the response body. I eliminated it in favor of the tried and true `retry_transient_errors` and some new helper methods that initiate the request *and* read the response. --- batch/batch/cloud/azure/worker/worker_api.py | 29 ++++----- batch/batch/cloud/gcp/worker/worker_api.py | 13 ++-- batch/batch/driver/job.py | 2 +- batch/batch/front_end/front_end.py | 37 +++++------ batch/batch/worker/worker.py | 7 +-- batch/test/test_invariants.py | 5 +- ci/test/test_ci.py | 7 +-- gear/gear/auth.py | 5 +- .../hailtop/aiocloud/aiogoogle/credentials.py | 63 ++++++++++--------- .../python/hailtop/aiocloud/common/session.py | 4 +- hail/python/hailtop/auth/auth.py | 44 ++++++------- hail/python/hailtop/httpx.py | 24 +++++++ hail/python/hailtop/utils/__init__.py | 5 +- hail/python/hailtop/utils/utils.py | 27 -------- memory/memory/client.py | 17 +++-- monitoring/test/test_monitoring.py | 7 +-- 16 files changed, 137 insertions(+), 159 deletions(-) diff --git a/batch/batch/cloud/azure/worker/worker_api.py b/batch/batch/cloud/azure/worker/worker_api.py index b77efa8ccb8..db9967ce5a0 100644 --- a/batch/batch/cloud/azure/worker/worker_api.py +++ b/batch/batch/cloud/azure/worker/worker_api.py @@ -8,7 +8,7 @@ from gear.cloud_config import get_azure_config from hailtop import httpx from hailtop.aiocloud import aioazure -from hailtop.utils import check_exec_output, request_retry_transient_errors, time_msecs +from hailtop.utils import check_exec_output, retry_transient_errors, time_msecs from ....worker.worker_api import CloudWorkerAPI from ..instance_config import AzureSlimInstanceConfig @@ -137,18 +137,16 @@ class AadAccessToken(LazyShortLivedToken): async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]: # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http params = {'api-version': '2018-02-01', 'resource': 'https://management.azure.com/'} - async with await request_retry_transient_errors( - session, - 'GET', + resp_json = await retry_transient_errors( + session.get_read_json, 'http://169.254.169.254/metadata/identity/oauth2/token', headers={'Metadata': 'true'}, params=params, timeout=aiohttp.ClientTimeout(total=60), # type: ignore - ) as resp: - resp_json = await resp.json() - access_token: str = resp_json['access_token'] - expiration_time_ms = int(resp_json['expires_on']) * 1000 - return access_token, expiration_time_ms + ) + access_token: str = resp_json['access_token'] + expiration_time_ms = int(resp_json['expires_on']) * 1000 + return access_token, expiration_time_ms class AcrRefreshToken(LazyShortLivedToken): @@ -164,14 +162,13 @@ async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]: 'service': self.acr_url, 'access_token': await self.aad_access_token.token(session), } - async with await request_retry_transient_errors( - session, - 'POST', + resp_json = await retry_transient_errors( + session.post_read_json, f'https://{self.acr_url}/oauth2/exchange', headers={'Content-Type': 'application/x-www-form-urlencoded'}, data=data, timeout=aiohttp.ClientTimeout(total=60), # type: ignore - ) as resp: - refresh_token: str = (await resp.json())['refresh_token'] - expiration_time_ms = time_msecs() + 60 * 60 * 1000 # token expires in 3 hours so we refresh after 1 hour - return refresh_token, expiration_time_ms + ) + refresh_token: str = resp_json['refresh_token'] + expiration_time_ms = time_msecs() + 60 * 60 * 1000 # token expires in 3 hours so we refresh after 1 hour + return refresh_token, expiration_time_ms diff --git a/batch/batch/cloud/gcp/worker/worker_api.py b/batch/batch/cloud/gcp/worker/worker_api.py index 1bf04fca64c..013637ce446 100644 --- a/batch/batch/cloud/gcp/worker/worker_api.py +++ b/batch/batch/cloud/gcp/worker/worker_api.py @@ -6,7 +6,7 @@ from hailtop import httpx from hailtop.aiocloud import aiogoogle -from hailtop.utils import check_exec_output, request_retry_transient_errors +from hailtop.utils import check_exec_output, retry_transient_errors from ....worker.worker_api import CloudWorkerAPI from ..instance_config import GCPSlimInstanceConfig @@ -53,15 +53,14 @@ def user_credentials(self, credentials: Dict[str, str]) -> GCPUserCredentials: return GCPUserCredentials(credentials) async def worker_access_token(self, session: httpx.ClientSession) -> Dict[str, str]: - async with await request_retry_transient_errors( - session, - 'POST', + token_dict = await retry_transient_errors( + session.post_read_json, 'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token', headers={'Metadata-Flavor': 'Google'}, timeout=aiohttp.ClientTimeout(total=60), # type: ignore - ) as resp: - access_token = (await resp.json())['access_token'] - return {'username': 'oauth2accesstoken', 'password': access_token} + ) + access_token = token_dict['access_token'] + return {'username': 'oauth2accesstoken', 'password': access_token} def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig: return GCPSlimInstanceConfig.from_dict(config_dict) diff --git a/batch/batch/driver/job.py b/batch/batch/driver/job.py index 6bac023c30d..9ee2b913cd8 100644 --- a/batch/batch/driver/job.py +++ b/batch/batch/driver/job.py @@ -70,7 +70,7 @@ async def request(session): # only jobs from CI may use batch's TLS identity await request(client_session) else: - async with aiohttp.ClientSession(raise_for_status=True, timeout=aiohttp.ClientTimeout(total=5)) as session: + async with httpx.client_session() as session: await request(session) except asyncio.CancelledError: raise diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py index fb832dae9b4..1c79ce19c83 100644 --- a/batch/batch/front_end/front_end.py +++ b/batch/batch/front_end/front_end.py @@ -50,8 +50,8 @@ dump_all_stacktraces, humanize_timedelta_msecs, periodically_call, - request_retry_transient_errors, retry_long_running, + retry_transient_errors, run_if_changed, time_msecs, time_msecs_str, @@ -421,12 +421,10 @@ def attempt_id_from_spec(record) -> Optional[str]: async def _get_job_container_log_from_worker(client_session, batch_id, job_id, container, ip_address) -> bytes: try: - async with await request_retry_transient_errors( - client_session, - 'GET', + return await retry_transient_errors( + client_session.get_read, f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}', - ) as resp: - return await resp.read() + ) except aiohttp.ClientResponseError: log.exception(f'while getting log for {(batch_id, job_id)}') return b'ERROR: encountered a problem while fetching the log' @@ -489,12 +487,10 @@ async def _get_job_resource_usage(app, batch_id, job_id): if state == 'Running': try: - resp = await request_retry_transient_errors( - client_session, - 'GET', + data = await retry_transient_errors( + client_session.get_read_json, f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage', ) - data = await resp.json() return { task: ResourceUsageMonitor.decode_to_df(base64.b64decode(encoded_df)) for task, encoded_df in data.items() @@ -621,10 +617,10 @@ async def _get_full_job_status(app, record): ip_address = record['ip_address'] try: - resp = await request_retry_transient_errors( - client_session, 'GET', f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status' + return await retry_transient_errors( + client_session.get_read_json, + f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status', ) - return await resp.json() except aiohttp.ClientResponseError as e: if e.status == 404: return None @@ -1740,9 +1736,8 @@ async def _commit_update(app: web.Application, batch_id: int, update_id: int, us raise app['task_manager'].ensure_future( - request_retry_transient_errors( - client_session, - 'PATCH', + retry_transient_errors( + client_session.patch, deploy_config.url('batch-driver', f'/api/v1alpha/batches/{user}/{batch_id}/update'), headers=app['batch_headers'], ) @@ -2879,9 +2874,8 @@ async def index(request, userdata): # pylint: disable=unused-argument async def cancel_batch_loop_body(app): client_session: httpx.ClientSession = app['client_session'] - await request_retry_transient_errors( - client_session, - 'POST', + await retry_transient_errors( + client_session.post, deploy_config.url('batch-driver', '/api/v1alpha/batches/cancel'), headers=app['batch_headers'], ) @@ -2892,9 +2886,8 @@ async def cancel_batch_loop_body(app): async def delete_batch_loop_body(app): client_session: httpx.ClientSession = app['client_session'] - await request_retry_transient_errors( - client_session, - 'POST', + await retry_transient_errors( + client_session.post, deploy_config.url('batch-driver', '/api/v1alpha/batches/delete'), headers=app['batch_headers'], ) diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index 70be821bdda..d48754d6dcb 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -61,7 +61,6 @@ is_delayed_warning_error, parse_docker_image_reference, periodically_call, - request_retry_transient_errors, retry_transient_errors, retry_transient_errors_with_debug_string, retry_transient_errors_with_delayed_warnings, @@ -3209,14 +3208,12 @@ async def post_job_started(self, job): async def activate(self): log.info('activating') - resp = await request_retry_transient_errors( - self.client_session, - 'POST', + resp_json = await retry_transient_errors( + self.client_session.post_read_json, deploy_config.url('batch-driver', '/api/v1alpha/instances/activate'), json={'ip_address': os.environ['IP_ADDRESS']}, headers={'X-Hail-Instance-Name': NAME, 'Authorization': f'Bearer {os.environ["ACTIVATION_TOKEN"]}'}, ) - resp_json = await resp.json() self.headers = {'X-Hail-Instance-Name': NAME, 'Authorization': f'Bearer {resp_json["token"]}'} self.active = True self.last_updated = time_msecs() diff --git a/batch/test/test_invariants.py b/batch/test/test_invariants.py index def03eb6915..545319b8983 100644 --- a/batch/test/test_invariants.py +++ b/batch/test/test_invariants.py @@ -3,10 +3,10 @@ import aiohttp import pytest -from hailtop import utils from hailtop.auth import hail_credentials from hailtop.config import get_deploy_config from hailtop.httpx import client_session +from hailtop.utils import retry_transient_errors pytestmark = pytest.mark.asyncio @@ -20,8 +20,7 @@ async def test_invariants(): headers = await hail_credentials().auth_headers() async with client_session(timeout=aiohttp.ClientTimeout(total=60)) as session: - resp = await utils.request_retry_transient_errors(session, 'GET', url, headers=headers) - data = await resp.json() + data = await retry_transient_errors(session.get_read_json, url, headers=headers) assert data['check_incremental_error'] is None, data assert data['check_resource_aggregation_error'] is None, data diff --git a/ci/test/test_ci.py b/ci/test/test_ci.py index 28e91c7bcaf..112afecef64 100644 --- a/ci/test/test_ci.py +++ b/ci/test/test_ci.py @@ -4,10 +4,10 @@ import pytest -from hailtop import utils from hailtop.auth import hail_credentials from hailtop.config import get_deploy_config from hailtop.httpx import client_session +from hailtop.utils import retry_transient_errors logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) @@ -24,10 +24,9 @@ async def wait_forever(): deploy_state = None failure_information = None while deploy_state is None: - resp = await utils.request_retry_transient_errors( - session, 'GET', f'{ci_deploy_status_url}', headers=headers + deploy_statuses = await retry_transient_errors( + session.get_read_json, ci_deploy_status_url, headers=headers ) - deploy_statuses = await resp.json() log.info(f'deploy_statuses:\n{json.dumps(deploy_statuses, indent=2)}') assert len(deploy_statuses) == 1, deploy_statuses deploy_status = deploy_statuses[0] diff --git a/gear/gear/auth.py b/gear/gear/auth.py index 053d6131670..3d23706d56c 100644 --- a/gear/gear/auth.py +++ b/gear/gear/auth.py @@ -10,7 +10,7 @@ from hailtop import httpx from hailtop.config import get_deploy_config -from hailtop.utils import request_retry_transient_errors +from hailtop.utils import retry_transient_errors from .time_limited_max_size_cache import TimeLimitedMaxSizeCache @@ -129,8 +129,7 @@ async def impersonate_user_and_get_info(session_id: str, client_session: httpx.C headers = {'Authorization': f'Bearer {session_id}'} userinfo_url = deploy_config.url('auth', '/api/v1alpha/userinfo') try: - resp = await request_retry_transient_errors(client_session, 'GET', userinfo_url, headers=headers) - return await resp.json() + return await retry_transient_errors(client_session.get_read_json, userinfo_url, headers=headers) except aiohttp.ClientResponseError as err: if err.status == 401: return None diff --git a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py index 62112e9cef2..1f182c64b7d 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py @@ -6,7 +6,7 @@ import socket from urllib.parse import urlencode import jwt -from hailtop.utils import request_retry_transient_errors +from hailtop.utils import retry_transient_errors import hailtop.httpx from ..common.credentials import AnonymousCloudCredentials, CloudCredentials @@ -108,19 +108,20 @@ def __str__(self): return 'ApplicationDefaultCredentials' async def _get_access_token(self) -> GoogleExpiringAccessToken: - async with await request_retry_transient_errors( - self._http_session, 'POST', - 'https://www.googleapis.com/oauth2/v4/token', - headers={ - 'content-type': 'application/x-www-form-urlencoded' - }, - data=urlencode({ - 'grant_type': 'refresh_token', - 'client_id': self.credentials['client_id'], - 'client_secret': self.credentials['client_secret'], - 'refresh_token': self.credentials['refresh_token'] - })) as resp: - return GoogleExpiringAccessToken.from_dict(await resp.json()) + token_dict = await retry_transient_errors( + self._http_session.post_read_json, + 'https://www.googleapis.com/oauth2/v4/token', + headers={ + 'content-type': 'application/x-www-form-urlencoded' + }, + data=urlencode({ + 'grant_type': 'refresh_token', + 'client_id': self.credentials['client_id'], + 'client_secret': self.credentials['client_secret'], + 'refresh_token': self.credentials['refresh_token'] + }) + ) + return GoogleExpiringAccessToken.from_dict(token_dict) # protocol documented here: @@ -145,27 +146,29 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken: "iss": self.key['client_email'] } encoded_assertion = jwt.encode(assertion, self.key['private_key'], algorithm='RS256') - async with await request_retry_transient_errors( - self._http_session, 'POST', - 'https://www.googleapis.com/oauth2/v4/token', - headers={ - 'content-type': 'application/x-www-form-urlencoded' - }, - data=urlencode({ - 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', - 'assertion': encoded_assertion - })) as resp: - return GoogleExpiringAccessToken.from_dict(await resp.json()) + token_dict = await retry_transient_errors( + self._http_session.post_read_json, + 'https://www.googleapis.com/oauth2/v4/token', + headers={ + 'content-type': 'application/x-www-form-urlencoded' + }, + data=urlencode({ + 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', + 'assertion': encoded_assertion + }) + ) + return GoogleExpiringAccessToken.from_dict(token_dict) # https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#applications class GoogleInstanceMetadataCredentials(GoogleCredentials): async def _get_access_token(self) -> GoogleExpiringAccessToken: - async with await request_retry_transient_errors( - self._http_session, 'GET', - 'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token', - headers={'Metadata-Flavor': 'Google'}) as resp: - return GoogleExpiringAccessToken.from_dict(await resp.json()) + token_dict = await retry_transient_errors( + self._http_session.get_read_json, + 'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token', + headers={'Metadata-Flavor': 'Google'} + ) + return GoogleExpiringAccessToken.from_dict(token_dict) @staticmethod def available(): diff --git a/hail/python/hailtop/aiocloud/common/session.py b/hail/python/hailtop/aiocloud/common/session.py index a363f2556cf..dcea86811fc 100644 --- a/hail/python/hailtop/aiocloud/common/session.py +++ b/hail/python/hailtop/aiocloud/common/session.py @@ -3,7 +3,7 @@ import aiohttp import abc from hailtop import httpx -from hailtop.utils import request_retry_transient_errors, RateLimit, RateLimiter +from hailtop.utils import retry_transient_errors, RateLimit, RateLimiter from .credentials import CloudCredentials SessionType = TypeVar('SessionType', bound='BaseSession') @@ -104,7 +104,7 @@ async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientRespon # retry by default retry = kwargs.pop('retry', True) if retry: - return await request_retry_transient_errors(self._http_session, method, url, **kwargs) + return await retry_transient_errors(self._http_session.request, method, url, **kwargs) return await self._http_session.request(method, url, **kwargs) async def close(self) -> None: diff --git a/hail/python/hailtop/auth/auth.py b/hail/python/hailtop/auth/auth.py index a6ff071d738..52b617f5bec 100644 --- a/hail/python/hailtop/auth/auth.py +++ b/hail/python/hailtop/auth/auth.py @@ -2,10 +2,11 @@ import os import aiohttp +from hailtop import httpx from hailtop.aiocloud.common.credentials import CloudCredentials from hailtop.aiocloud.common import Session from hailtop.config import get_deploy_config, DeployConfig -from hailtop.utils import async_to_blocking, request_retry_transient_errors +from hailtop.utils import async_to_blocking, retry_transient_errors from .tokens import Tokens, get_tokens @@ -97,14 +98,12 @@ def copy_paste_login(copy_paste_token: str, namespace: Optional[str] = None): async def async_copy_paste_login(copy_paste_token: str, namespace: Optional[str] = None): deploy_config, headers, namespace = deploy_config_and_headers_from_namespace(namespace, authorize_target=False) - async with aiohttp.ClientSession( - raise_for_status=True, - timeout=aiohttp.ClientTimeout(total=5), - headers=headers) as session: - async with await request_retry_transient_errors( - session, 'POST', deploy_config.url('auth', '/api/v1alpha/copy-paste-login'), - params={'copy_paste_token': copy_paste_token}) as resp: - data = await resp.json() + async with httpx.client_session(headers=headers) as session: + data = await retry_transient_errors( + session.post_read_json, + deploy_config.url('auth', '/api/v1alpha/copy-paste-login'), + params={'copy_paste_token': copy_paste_token} + ) token = data['token'] username = data['username'] @@ -125,13 +124,13 @@ def get_user(username: str, namespace: Optional[str] = None) -> dict: async def async_get_user(username: str, namespace: Optional[str] = None) -> dict: deploy_config, headers, _ = deploy_config_and_headers_from_namespace(namespace) - async with aiohttp.ClientSession( - raise_for_status=True, + async with httpx.client_session( timeout=aiohttp.ClientTimeout(total=30), headers=headers) as session: - async with await request_retry_transient_errors( - session, 'GET', deploy_config.url('auth', f'/api/v1alpha/users/{username}')) as resp: - return await resp.json() + return await retry_transient_errors( + session.get_read_json, + deploy_config.url('auth', f'/api/v1alpha/users/{username}') + ) def create_user(username: str, login_id: str, is_developer: bool, is_service_account: bool, namespace: Optional[str] = None): @@ -147,12 +146,13 @@ async def async_create_user(username: str, login_id: str, is_developer: bool, is 'is_service_account': is_service_account, } - async with aiohttp.ClientSession( - raise_for_status=True, + async with httpx.client_session( timeout=aiohttp.ClientTimeout(total=30), headers=headers) as session: - await request_retry_transient_errors( - session, 'POST', deploy_config.url('auth', f'/api/v1alpha/users/{username}/create'), json=body + await retry_transient_errors( + session.post, + deploy_config.url('auth', f'/api/v1alpha/users/{username}/create'), + json=body ) @@ -162,10 +162,10 @@ def delete_user(username: str, namespace: Optional[str] = None): async def async_delete_user(username: str, namespace: Optional[str] = None): deploy_config, headers, _ = deploy_config_and_headers_from_namespace(namespace) - async with aiohttp.ClientSession( - raise_for_status=True, + async with httpx.client_session( timeout=aiohttp.ClientTimeout(total=300), headers=headers) as session: - await request_retry_transient_errors( - session, 'DELETE', deploy_config.url('auth', f'/api/v1alpha/users/{username}') + await retry_transient_errors( + session.delete, + deploy_config.url('auth', f'/api/v1alpha/users/{username}') ) diff --git a/hail/python/hailtop/httpx.py b/hail/python/hailtop/httpx.py index 15f5ad34f4d..9ede8ea6db6 100644 --- a/hail/python/hailtop/httpx.py +++ b/hail/python/hailtop/httpx.py @@ -160,6 +160,18 @@ def get( ) -> aiohttp.client._RequestContextManager: return self.request('GET', url, allow_redirects=allow_redirects, **kwargs) + async def get_read_json( + self, *args, **kwargs + ) -> Any: + async with self.get(*args, **kwargs) as resp: + return await resp.json() + + async def get_read( + self, *args, **kwargs + ) -> bytes: + async with self.get(*args, **kwargs) as resp: + return await resp.read() + def options( self, url: aiohttp.client.StrOrURL, *, allow_redirects: bool = True, **kwargs: Any ) -> aiohttp.client._RequestContextManager: @@ -175,6 +187,18 @@ def post( ) -> aiohttp.client._RequestContextManager: return self.request('POST', url, data=data, **kwargs) + async def post_read_json( + self, *args, **kwargs + ) -> Any: + async with self.post(*args, **kwargs) as resp: + return await resp.json() + + async def post_read( + self, *args, **kwargs + ) -> bytes: + async with self.post(*args, **kwargs) as resp: + return await resp.read() + def put( self, url: aiohttp.client.StrOrURL, *, data: Any = None, **kwargs: Any ) -> aiohttp.client._RequestContextManager: diff --git a/hail/python/hailtop/utils/__init__.py b/hail/python/hailtop/utils/__init__.py index d32ca79d108..2d7080f85f2 100644 --- a/hail/python/hailtop/utils/__init__.py +++ b/hail/python/hailtop/utils/__init__.py @@ -3,8 +3,7 @@ time_ns) from .utils import (unzip, async_to_blocking, blocking_to_async, AsyncWorkerPool, bounded_gather, grouped, sync_sleep_and_backoff, sleep_and_backoff, is_transient_error, - request_retry_transient_errors, request_raise_transient_errors, collect_agen, - retry_all_errors, retry_transient_errors, + collect_agen, retry_all_errors, retry_transient_errors, retry_transient_errors_with_debug_string, retry_long_running, run_if_changed, run_if_changed_idempotent, LoggingTimer, WaitableSharedPool, RETRY_FUNCTION_SCRIPT, sync_retry_transient_errors, @@ -59,8 +58,6 @@ 'run_if_changed_idempotent', 'LoggingTimer', 'WaitableSharedPool', - 'request_retry_transient_errors', - 'request_raise_transient_errors', 'collect_agen', 'RETRY_FUNCTION_SCRIPT', 'sync_retry_transient_errors', diff --git a/hail/python/hailtop/utils/utils.py b/hail/python/hailtop/utils/utils.py index 4d449518679..114db1d9866 100644 --- a/hail/python/hailtop/utils/utils.py +++ b/hail/python/hailtop/utils/utils.py @@ -14,7 +14,6 @@ import logging import asyncio import aiohttp -from aiohttp import web import urllib import urllib3 import secrets @@ -838,32 +837,6 @@ def sync_retry_transient_errors(f, *args, **kwargs): delay = sync_sleep_and_backoff(delay) -async def request_retry_transient_errors( - session, # : Union[httpx.ClientSession, aiohttp.ClientSession] - method: str, - url, - **kwargs -) -> aiohttp.ClientResponse: - return await retry_transient_errors(session.request, method, url, **kwargs) - - -async def request_raise_transient_errors( - session, # : Union[httpx.ClientSession, aiohttp.ClientSession] - method: str, - url, - **kwargs -) -> aiohttp.ClientResponse: - try: - return await session.request(method, url, **kwargs) - except KeyboardInterrupt: - raise - except Exception as e: - if is_transient_error(e): - log.exception('request failed with transient exception: {method} {url}') - raise web.HTTPServiceUnavailable() - raise - - def retry_response_returning_functions(fun, *args, **kwargs): delay = 0.1 errors = 0 diff --git a/memory/memory/client.py b/memory/memory/client.py index 1677515bc5b..ef8f1fe7159 100644 --- a/memory/memory/client.py +++ b/memory/memory/client.py @@ -4,7 +4,7 @@ from hailtop.auth import hail_credentials from hailtop.config import get_deploy_config from hailtop.httpx import client_session -from hailtop.utils import request_retry_transient_errors +from hailtop.utils import retry_transient_errors class MemoryClient: @@ -34,10 +34,9 @@ async def async_init(self): async def _get_file_if_exists(self, filename): params = {'q': filename} try: - async with await request_retry_transient_errors( - self._session, 'get', self.objects_url, params=params, headers=self._headers - ) as response: - return await response.read() + return await retry_transient_errors( + self._session.get_read, self.objects_url, params=params, headers=self._headers + ) except aiohttp.ClientResponseError as e: if e.status == 404: return None @@ -51,10 +50,10 @@ async def read_file(self, filename): async def write_file(self, filename, data): params = {'q': filename} - async with await request_retry_transient_errors( - self._session, 'post', self.objects_url, params=params, headers=self._headers, data=data - ) as response: - assert response.status == 200 + response = await retry_transient_errors( + self._session.post, self.objects_url, params=params, headers=self._headers, data=data + ) + assert response.status == 200 async def close(self): await self._session.close() diff --git a/monitoring/test/test_monitoring.py b/monitoring/test/test_monitoring.py index f1840465414..207b87921e6 100644 --- a/monitoring/test/test_monitoring.py +++ b/monitoring/test/test_monitoring.py @@ -3,10 +3,10 @@ import pytest -from hailtop import utils from hailtop.auth import hail_credentials from hailtop.config import get_deploy_config from hailtop.httpx import client_session +from hailtop.utils import retry_transient_errors logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) @@ -22,10 +22,9 @@ async def test_billing_monitoring(): async def wait_forever(): data = None while data is None: - resp = await utils.request_retry_transient_errors( - session, 'GET', f'{monitoring_deploy_config_url}', headers=headers + data = await retry_transient_errors( + session.get_read_json, monitoring_deploy_config_url, headers=headers ) - data = await resp.json() await asyncio.sleep(5) return data