Skip to content

Commit

Permalink
[services] reliably retry all requests (#13029)
Browse files Browse the repository at this point in the history
`request_retry_transient_errors` is a charlatan. It does not retry
errors that occur while reading the response body. I eliminated it in
favor of the tried-and-true `retry_transient_errors` and some new helper
methods that initiate the request *and* read the response.
  • Loading branch information
danking authored May 12, 2023
1 parent d42d89e commit e553ebc
Show file tree
Hide file tree
Showing 16 changed files with 137 additions and 159 deletions.
29 changes: 13 additions & 16 deletions batch/batch/cloud/azure/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gear.cloud_config import get_azure_config
from hailtop import httpx
from hailtop.aiocloud import aioazure
from hailtop.utils import check_exec_output, request_retry_transient_errors, time_msecs
from hailtop.utils import check_exec_output, retry_transient_errors, time_msecs

from ....worker.worker_api import CloudWorkerAPI
from ..instance_config import AzureSlimInstanceConfig
Expand Down Expand Up @@ -137,18 +137,16 @@ class AadAccessToken(LazyShortLivedToken):
async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]:
# https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
params = {'api-version': '2018-02-01', 'resource': 'https://management.azure.com/'}
async with await request_retry_transient_errors(
session,
'GET',
resp_json = await retry_transient_errors(
session.get_read_json,
'http://169.254.169.254/metadata/identity/oauth2/token',
headers={'Metadata': 'true'},
params=params,
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
) as resp:
resp_json = await resp.json()
access_token: str = resp_json['access_token']
expiration_time_ms = int(resp_json['expires_on']) * 1000
return access_token, expiration_time_ms
)
access_token: str = resp_json['access_token']
expiration_time_ms = int(resp_json['expires_on']) * 1000
return access_token, expiration_time_ms


class AcrRefreshToken(LazyShortLivedToken):
Expand All @@ -164,14 +162,13 @@ async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]:
'service': self.acr_url,
'access_token': await self.aad_access_token.token(session),
}
async with await request_retry_transient_errors(
session,
'POST',
resp_json = await retry_transient_errors(
session.post_read_json,
f'https://{self.acr_url}/oauth2/exchange',
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data=data,
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
) as resp:
refresh_token: str = (await resp.json())['refresh_token']
expiration_time_ms = time_msecs() + 60 * 60 * 1000 # token expires in 3 hours so we refresh after 1 hour
return refresh_token, expiration_time_ms
)
refresh_token: str = resp_json['refresh_token']
expiration_time_ms = time_msecs() + 60 * 60 * 1000 # token expires in 3 hours so we refresh after 1 hour
return refresh_token, expiration_time_ms
13 changes: 6 additions & 7 deletions batch/batch/cloud/gcp/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from hailtop import httpx
from hailtop.aiocloud import aiogoogle
from hailtop.utils import check_exec_output, request_retry_transient_errors
from hailtop.utils import check_exec_output, retry_transient_errors

from ....worker.worker_api import CloudWorkerAPI
from ..instance_config import GCPSlimInstanceConfig
Expand Down Expand Up @@ -53,15 +53,14 @@ def user_credentials(self, credentials: Dict[str, str]) -> GCPUserCredentials:
return GCPUserCredentials(credentials)

async def worker_access_token(self, session: httpx.ClientSession) -> Dict[str, str]:
async with await request_retry_transient_errors(
session,
'POST',
token_dict = await retry_transient_errors(
session.post_read_json,
'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'},
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
) as resp:
access_token = (await resp.json())['access_token']
return {'username': 'oauth2accesstoken', 'password': access_token}
)
access_token = token_dict['access_token']
return {'username': 'oauth2accesstoken', 'password': access_token}

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig:
return GCPSlimInstanceConfig.from_dict(config_dict)
Expand Down
2 changes: 1 addition & 1 deletion batch/batch/driver/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def request(session):
# only jobs from CI may use batch's TLS identity
await request(client_session)
else:
async with aiohttp.ClientSession(raise_for_status=True, timeout=aiohttp.ClientTimeout(total=5)) as session:
async with httpx.client_session() as session:
await request(session)
except asyncio.CancelledError:
raise
Expand Down
37 changes: 15 additions & 22 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
dump_all_stacktraces,
humanize_timedelta_msecs,
periodically_call,
request_retry_transient_errors,
retry_long_running,
retry_transient_errors,
run_if_changed,
time_msecs,
time_msecs_str,
Expand Down Expand Up @@ -421,12 +421,10 @@ def attempt_id_from_spec(record) -> Optional[str]:

async def _get_job_container_log_from_worker(client_session, batch_id, job_id, container, ip_address) -> bytes:
try:
async with await request_retry_transient_errors(
client_session,
'GET',
return await retry_transient_errors(
client_session.get_read,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}',
) as resp:
return await resp.read()
)
except aiohttp.ClientResponseError:
log.exception(f'while getting log for {(batch_id, job_id)}')
return b'ERROR: encountered a problem while fetching the log'
Expand Down Expand Up @@ -489,12 +487,10 @@ async def _get_job_resource_usage(app, batch_id, job_id):

if state == 'Running':
try:
resp = await request_retry_transient_errors(
client_session,
'GET',
data = await retry_transient_errors(
client_session.get_read_json,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage',
)
data = await resp.json()
return {
task: ResourceUsageMonitor.decode_to_df(base64.b64decode(encoded_df))
for task, encoded_df in data.items()
Expand Down Expand Up @@ -621,10 +617,10 @@ async def _get_full_job_status(app, record):

ip_address = record['ip_address']
try:
resp = await request_retry_transient_errors(
client_session, 'GET', f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status'
return await retry_transient_errors(
client_session.get_read_json,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status',
)
return await resp.json()
except aiohttp.ClientResponseError as e:
if e.status == 404:
return None
Expand Down Expand Up @@ -1740,9 +1736,8 @@ async def _commit_update(app: web.Application, batch_id: int, update_id: int, us
raise

app['task_manager'].ensure_future(
request_retry_transient_errors(
client_session,
'PATCH',
retry_transient_errors(
client_session.patch,
deploy_config.url('batch-driver', f'/api/v1alpha/batches/{user}/{batch_id}/update'),
headers=app['batch_headers'],
)
Expand Down Expand Up @@ -2879,9 +2874,8 @@ async def index(request, userdata): # pylint: disable=unused-argument

async def cancel_batch_loop_body(app):
client_session: httpx.ClientSession = app['client_session']
await request_retry_transient_errors(
client_session,
'POST',
await retry_transient_errors(
client_session.post,
deploy_config.url('batch-driver', '/api/v1alpha/batches/cancel'),
headers=app['batch_headers'],
)
Expand All @@ -2892,9 +2886,8 @@ async def cancel_batch_loop_body(app):

async def delete_batch_loop_body(app):
client_session: httpx.ClientSession = app['client_session']
await request_retry_transient_errors(
client_session,
'POST',
await retry_transient_errors(
client_session.post,
deploy_config.url('batch-driver', '/api/v1alpha/batches/delete'),
headers=app['batch_headers'],
)
Expand Down
7 changes: 2 additions & 5 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
is_delayed_warning_error,
parse_docker_image_reference,
periodically_call,
request_retry_transient_errors,
retry_transient_errors,
retry_transient_errors_with_debug_string,
retry_transient_errors_with_delayed_warnings,
Expand Down Expand Up @@ -3209,14 +3208,12 @@ async def post_job_started(self, job):

async def activate(self):
log.info('activating')
resp = await request_retry_transient_errors(
self.client_session,
'POST',
resp_json = await retry_transient_errors(
self.client_session.post_read_json,
deploy_config.url('batch-driver', '/api/v1alpha/instances/activate'),
json={'ip_address': os.environ['IP_ADDRESS']},
headers={'X-Hail-Instance-Name': NAME, 'Authorization': f'Bearer {os.environ["ACTIVATION_TOKEN"]}'},
)
resp_json = await resp.json()
self.headers = {'X-Hail-Instance-Name': NAME, 'Authorization': f'Bearer {resp_json["token"]}'}
self.active = True
self.last_updated = time_msecs()
Expand Down
5 changes: 2 additions & 3 deletions batch/test/test_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import aiohttp
import pytest

from hailtop import utils
from hailtop.auth import hail_credentials
from hailtop.config import get_deploy_config
from hailtop.httpx import client_session
from hailtop.utils import retry_transient_errors

pytestmark = pytest.mark.asyncio

Expand All @@ -20,8 +20,7 @@ async def test_invariants():
headers = await hail_credentials().auth_headers()
async with client_session(timeout=aiohttp.ClientTimeout(total=60)) as session:

resp = await utils.request_retry_transient_errors(session, 'GET', url, headers=headers)
data = await resp.json()
data = await retry_transient_errors(session.get_read_json, url, headers=headers)

assert data['check_incremental_error'] is None, data
assert data['check_resource_aggregation_error'] is None, data
7 changes: 3 additions & 4 deletions ci/test/test_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import pytest

from hailtop import utils
from hailtop.auth import hail_credentials
from hailtop.config import get_deploy_config
from hailtop.httpx import client_session
from hailtop.utils import retry_transient_errors

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
Expand All @@ -24,10 +24,9 @@ async def wait_forever():
deploy_state = None
failure_information = None
while deploy_state is None:
resp = await utils.request_retry_transient_errors(
session, 'GET', f'{ci_deploy_status_url}', headers=headers
deploy_statuses = await retry_transient_errors(
session.get_read_json, ci_deploy_status_url, headers=headers
)
deploy_statuses = await resp.json()
log.info(f'deploy_statuses:\n{json.dumps(deploy_statuses, indent=2)}')
assert len(deploy_statuses) == 1, deploy_statuses
deploy_status = deploy_statuses[0]
Expand Down
5 changes: 2 additions & 3 deletions gear/gear/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from hailtop import httpx
from hailtop.config import get_deploy_config
from hailtop.utils import request_retry_transient_errors
from hailtop.utils import retry_transient_errors

from .time_limited_max_size_cache import TimeLimitedMaxSizeCache

Expand Down Expand Up @@ -129,8 +129,7 @@ async def impersonate_user_and_get_info(session_id: str, client_session: httpx.C
headers = {'Authorization': f'Bearer {session_id}'}
userinfo_url = deploy_config.url('auth', '/api/v1alpha/userinfo')
try:
resp = await request_retry_transient_errors(client_session, 'GET', userinfo_url, headers=headers)
return await resp.json()
return await retry_transient_errors(client_session.get_read_json, userinfo_url, headers=headers)
except aiohttp.ClientResponseError as err:
if err.status == 401:
return None
Expand Down
63 changes: 33 additions & 30 deletions hail/python/hailtop/aiocloud/aiogoogle/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import socket
from urllib.parse import urlencode
import jwt
from hailtop.utils import request_retry_transient_errors
from hailtop.utils import retry_transient_errors
import hailtop.httpx
from ..common.credentials import AnonymousCloudCredentials, CloudCredentials

Expand Down Expand Up @@ -108,19 +108,20 @@ def __str__(self):
return 'ApplicationDefaultCredentials'

async def _get_access_token(self) -> GoogleExpiringAccessToken:
async with await request_retry_transient_errors(
self._http_session, 'POST',
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'refresh_token',
'client_id': self.credentials['client_id'],
'client_secret': self.credentials['client_secret'],
'refresh_token': self.credentials['refresh_token']
})) as resp:
return GoogleExpiringAccessToken.from_dict(await resp.json())
token_dict = await retry_transient_errors(
self._http_session.post_read_json,
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'refresh_token',
'client_id': self.credentials['client_id'],
'client_secret': self.credentials['client_secret'],
'refresh_token': self.credentials['refresh_token']
})
)
return GoogleExpiringAccessToken.from_dict(token_dict)


# protocol documented here:
Expand All @@ -145,27 +146,29 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken:
"iss": self.key['client_email']
}
encoded_assertion = jwt.encode(assertion, self.key['private_key'], algorithm='RS256')
async with await request_retry_transient_errors(
self._http_session, 'POST',
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
'assertion': encoded_assertion
})) as resp:
return GoogleExpiringAccessToken.from_dict(await resp.json())
token_dict = await retry_transient_errors(
self._http_session.post_read_json,
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
'assertion': encoded_assertion
})
)
return GoogleExpiringAccessToken.from_dict(token_dict)


# https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#applications
class GoogleInstanceMetadataCredentials(GoogleCredentials):
async def _get_access_token(self) -> GoogleExpiringAccessToken:
async with await request_retry_transient_errors(
self._http_session, 'GET',
'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'}) as resp:
return GoogleExpiringAccessToken.from_dict(await resp.json())
token_dict = await retry_transient_errors(
self._http_session.get_read_json,
'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'}
)
return GoogleExpiringAccessToken.from_dict(token_dict)

@staticmethod
def available():
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hailtop/aiocloud/common/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import aiohttp
import abc
from hailtop import httpx
from hailtop.utils import request_retry_transient_errors, RateLimit, RateLimiter
from hailtop.utils import retry_transient_errors, RateLimit, RateLimiter
from .credentials import CloudCredentials

SessionType = TypeVar('SessionType', bound='BaseSession')
Expand Down Expand Up @@ -104,7 +104,7 @@ async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientRespon
# retry by default
retry = kwargs.pop('retry', True)
if retry:
return await request_retry_transient_errors(self._http_session, method, url, **kwargs)
return await retry_transient_errors(self._http_session.request, method, url, **kwargs)
return await self._http_session.request(method, url, **kwargs)

async def close(self) -> None:
Expand Down
Loading

0 comments on commit e553ebc

Please sign in to comment.