Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[services] reliably retry all requests #13029

Merged
merged 7 commits on May 12, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions batch/batch/cloud/azure/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from gear.cloud_config import get_azure_config
from hailtop import httpx
from hailtop.aiocloud import aioazure
from hailtop.utils import request_retry_transient_errors, time_msecs
from hailtop.utils import retry_transient_errors, time_msecs

from ....worker.worker_api import CloudWorkerAPI
from ..instance_config import AzureSlimInstanceConfig
Expand Down Expand Up @@ -116,18 +116,16 @@ class AadAccessToken(LazyShortLivedToken):
async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]:
# https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
params = {'api-version': '2018-02-01', 'resource': 'https://management.azure.com/'}
async with await request_retry_transient_errors(
session,
'GET',
resp_json = await retry_transient_errors(
session.get_read_json,
'http://169.254.169.254/metadata/identity/oauth2/token',
headers={'Metadata': 'true'},
params=params,
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
) as resp:
resp_json = await resp.json()
access_token: str = resp_json['access_token']
expiration_time_ms = int(resp_json['expires_on']) * 1000
return access_token, expiration_time_ms
)
access_token: str = resp_json['access_token']
expiration_time_ms = int(resp_json['expires_on']) * 1000
return access_token, expiration_time_ms


class AcrRefreshToken(LazyShortLivedToken):
Expand All @@ -143,14 +141,13 @@ async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]:
'service': self.acr_url,
'access_token': await self.aad_access_token.token(session),
}
async with await request_retry_transient_errors(
session,
'POST',
resp_json = await retry_transient_errors(
session.post_read_json,
f'https://{self.acr_url}/oauth2/exchange',
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data=data,
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
) as resp:
refresh_token: str = (await resp.json())['refresh_token']
expiration_time_ms = time_msecs() + 60 * 60 * 1000 # token expires in 3 hours so we refresh after 1 hour
return refresh_token, expiration_time_ms
)
refresh_token: str = resp_json['refresh_token']
expiration_time_ms = time_msecs() + 60 * 60 * 1000 # token expires in 3 hours so we refresh after 1 hour
return refresh_token, expiration_time_ms
13 changes: 6 additions & 7 deletions batch/batch/cloud/gcp/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from hailtop import httpx
from hailtop.aiocloud import aiogoogle
from hailtop.utils import request_retry_transient_errors
from hailtop.utils import retry_transient_errors

from ....worker.worker_api import CloudWorkerAPI
from ..instance_config import GCPSlimInstanceConfig
Expand Down Expand Up @@ -51,15 +51,14 @@ def user_credentials(self, credentials: Dict[str, str]) -> GCPUserCredentials:
return GCPUserCredentials(credentials)

async def worker_access_token(self, session: httpx.ClientSession) -> Dict[str, str]:
async with await request_retry_transient_errors(
session,
'POST',
token_dict = await retry_transient_errors(
session.post_read_json,
'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'},
timeout=aiohttp.ClientTimeout(total=60), # type: ignore
) as resp:
access_token = (await resp.json())['access_token']
return {'username': 'oauth2accesstoken', 'password': access_token}
)
access_token = token_dict['access_token']
return {'username': 'oauth2accesstoken', 'password': access_token}

def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig:
return GCPSlimInstanceConfig.from_dict(config_dict)
Expand Down
2 changes: 1 addition & 1 deletion batch/batch/driver/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def request(session):
# only jobs from CI may use batch's TLS identity
await request(client_session)
else:
async with aiohttp.ClientSession(raise_for_status=True, timeout=aiohttp.ClientTimeout(total=5)) as session:
async with httpx.client_session() as session:
await request(session)
except asyncio.CancelledError:
raise
Expand Down
37 changes: 15 additions & 22 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
dump_all_stacktraces,
humanize_timedelta_msecs,
periodically_call,
request_retry_transient_errors,
retry_long_running,
retry_transient_errors,
run_if_changed,
time_msecs,
time_msecs_str,
Expand Down Expand Up @@ -419,12 +419,10 @@ def attempt_id_from_spec(record) -> Optional[str]:

async def _get_job_container_log_from_worker(client_session, batch_id, job_id, container, ip_address) -> bytes:
try:
async with await request_retry_transient_errors(
client_session,
'GET',
return await retry_transient_errors(
client_session.get_read,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log/{container}',
) as resp:
return await resp.read()
)
except aiohttp.ClientResponseError:
log.exception(f'while getting log for {(batch_id, job_id)}')
return b'ERROR: encountered a problem while fetching the log'
Expand Down Expand Up @@ -487,12 +485,10 @@ async def _get_job_resource_usage(app, batch_id, job_id):

if state == 'Running':
try:
resp = await request_retry_transient_errors(
client_session,
'GET',
data = await retry_transient_errors(
client_session.get_read_json,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/resource_usage',
)
data = await resp.json()
return {
task: ResourceUsageMonitor.decode_to_df(base64.b64decode(encoded_df))
for task, encoded_df in data.items()
Expand Down Expand Up @@ -595,10 +591,10 @@ async def _get_full_job_status(app, record):

ip_address = record['ip_address']
try:
resp = await request_retry_transient_errors(
client_session, 'GET', f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status'
return await retry_transient_errors(
client_session.get_read_json,
f'http://{ip_address}:5000/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status',
)
return await resp.json()
except aiohttp.ClientResponseError as e:
if e.status == 404:
return None
Expand Down Expand Up @@ -1714,9 +1710,8 @@ async def _commit_update(app: web.Application, batch_id: int, update_id: int, us
raise

app['task_manager'].ensure_future(
request_retry_transient_errors(
client_session,
'PATCH',
retry_transient_errors(
client_session.patch,
deploy_config.url('batch-driver', f'/api/v1alpha/batches/{user}/{batch_id}/update'),
headers=app['batch_headers'],
)
Expand Down Expand Up @@ -2837,9 +2832,8 @@ async def index(request, userdata): # pylint: disable=unused-argument

async def cancel_batch_loop_body(app):
client_session: httpx.ClientSession = app['client_session']
await request_retry_transient_errors(
client_session,
'POST',
await retry_transient_errors(
client_session.post,
deploy_config.url('batch-driver', '/api/v1alpha/batches/cancel'),
headers=app['batch_headers'],
)
Expand All @@ -2850,9 +2844,8 @@ async def cancel_batch_loop_body(app):

async def delete_batch_loop_body(app):
client_session: httpx.ClientSession = app['client_session']
await request_retry_transient_errors(
client_session,
'POST',
await retry_transient_errors(
client_session.post,
deploy_config.url('batch-driver', '/api/v1alpha/batches/delete'),
headers=app['batch_headers'],
)
Expand Down
7 changes: 2 additions & 5 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
is_delayed_warning_error,
parse_docker_image_reference,
periodically_call,
request_retry_transient_errors,
retry_transient_errors,
retry_transient_errors_with_debug_string,
retry_transient_errors_with_delayed_warnings,
Expand Down Expand Up @@ -3130,14 +3129,12 @@ async def post_job_started(self, job):

async def activate(self):
log.info('activating')
resp = await request_retry_transient_errors(
self.client_session,
'POST',
resp_json = await retry_transient_errors(
self.client_session.post_read_json,
deploy_config.url('batch-driver', '/api/v1alpha/instances/activate'),
json={'ip_address': os.environ['IP_ADDRESS']},
headers={'X-Hail-Instance-Name': NAME, 'Authorization': f'Bearer {os.environ["ACTIVATION_TOKEN"]}'},
)
resp_json = await resp.json()
self.headers = {'X-Hail-Instance-Name': NAME, 'Authorization': f'Bearer {resp_json["token"]}'}
self.active = True
self.last_updated = time_msecs()
Expand Down
5 changes: 2 additions & 3 deletions batch/test/test_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import aiohttp
import pytest

from hailtop import utils
from hailtop.auth import hail_credentials
from hailtop.config import get_deploy_config
from hailtop.httpx import client_session
from hailtop.utils import retry_transient_errors

pytestmark = pytest.mark.asyncio

Expand All @@ -20,8 +20,7 @@ async def test_invariants():
headers = await hail_credentials().auth_headers()
async with client_session(timeout=aiohttp.ClientTimeout(total=60)) as session:

resp = await utils.request_retry_transient_errors(session, 'GET', url, headers=headers)
data = await resp.json()
data = await retry_transient_errors(session.get_read_json, url, headers=headers)

assert data['check_incremental_error'] is None, data
assert data['check_resource_aggregation_error'] is None, data
7 changes: 3 additions & 4 deletions ci/test/test_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import pytest

from hailtop import utils
from hailtop.auth import hail_credentials
from hailtop.config import get_deploy_config
from hailtop.httpx import client_session
from hailtop.utils import retry_transient_errors

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
Expand All @@ -24,10 +24,9 @@ async def wait_forever():
deploy_state = None
failure_information = None
while deploy_state is None:
resp = await utils.request_retry_transient_errors(
session, 'GET', f'{ci_deploy_status_url}', headers=headers
deploy_statuses = await retry_transient_errors(
session.get_read_json, ci_deploy_status_url, headers=headers
)
deploy_statuses = await resp.json()
log.info(f'deploy_statuses:\n{json.dumps(deploy_statuses, indent=2)}')
assert len(deploy_statuses) == 1, deploy_statuses
deploy_status = deploy_statuses[0]
Expand Down
5 changes: 2 additions & 3 deletions gear/gear/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from hailtop import httpx
from hailtop.config import get_deploy_config
from hailtop.utils import request_retry_transient_errors
from hailtop.utils import retry_transient_errors

from .time_limited_max_size_cache import TimeLimitedMaxSizeCache

Expand Down Expand Up @@ -129,8 +129,7 @@ async def impersonate_user_and_get_info(session_id: str, client_session: httpx.C
headers = {'Authorization': f'Bearer {session_id}'}
userinfo_url = deploy_config.url('auth', '/api/v1alpha/userinfo')
try:
resp = await request_retry_transient_errors(client_session, 'GET', userinfo_url, headers=headers)
return await resp.json()
return await retry_transient_errors(client_session.get_read_json, userinfo_url, headers=headers)
except aiohttp.ClientResponseError as err:
if err.status == 401:
return None
Expand Down
63 changes: 33 additions & 30 deletions hail/python/hailtop/aiocloud/aiogoogle/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import socket
from urllib.parse import urlencode
import jwt
from hailtop.utils import request_retry_transient_errors
from hailtop.utils import retry_transient_errors
import hailtop.httpx
from ..common.credentials import AnonymousCloudCredentials, CloudCredentials

Expand Down Expand Up @@ -108,19 +108,20 @@ def __str__(self):
return 'ApplicationDefaultCredentials'

async def _get_access_token(self) -> GoogleExpiringAccessToken:
async with await request_retry_transient_errors(
self._http_session, 'POST',
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'refresh_token',
'client_id': self.credentials['client_id'],
'client_secret': self.credentials['client_secret'],
'refresh_token': self.credentials['refresh_token']
})) as resp:
return GoogleExpiringAccessToken.from_dict(await resp.json())
token_dict = await retry_transient_errors(
self._http_session.post_read_json,
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'refresh_token',
'client_id': self.credentials['client_id'],
'client_secret': self.credentials['client_secret'],
'refresh_token': self.credentials['refresh_token']
})
)
return GoogleExpiringAccessToken.from_dict(token_dict)


# protocol documented here:
Expand All @@ -145,27 +146,29 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken:
"iss": self.key['client_email']
}
encoded_assertion = jwt.encode(assertion, self.key['private_key'], algorithm='RS256')
async with await request_retry_transient_errors(
self._http_session, 'POST',
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
'assertion': encoded_assertion
})) as resp:
return GoogleExpiringAccessToken.from_dict(await resp.json())
token_dict = await retry_transient_errors(
self._http_session.post_read_json,
'https://www.googleapis.com/oauth2/v4/token',
headers={
'content-type': 'application/x-www-form-urlencoded'
},
data=urlencode({
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
'assertion': encoded_assertion
})
)
return GoogleExpiringAccessToken.from_dict(token_dict)


# https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#applications
class GoogleInstanceMetadataCredentials(GoogleCredentials):
async def _get_access_token(self) -> GoogleExpiringAccessToken:
async with await request_retry_transient_errors(
self._http_session, 'GET',
'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'}) as resp:
return GoogleExpiringAccessToken.from_dict(await resp.json())
token_dict = await retry_transient_errors(
self._http_session.get_read_json,
'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token',
headers={'Metadata-Flavor': 'Google'}
)
return GoogleExpiringAccessToken.from_dict(token_dict)

@staticmethod
def available():
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hailtop/aiocloud/common/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import aiohttp
import abc
from hailtop import httpx
from hailtop.utils import request_retry_transient_errors, RateLimit, RateLimiter
from hailtop.utils import retry_transient_errors, RateLimit, RateLimiter
from .credentials import CloudCredentials

SessionType = TypeVar('SessionType', bound='BaseSession')
Expand Down Expand Up @@ -104,7 +104,7 @@ async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientRespon
# retry by default
retry = kwargs.pop('retry', True)
if retry:
return await request_retry_transient_errors(self._http_session, method, url, **kwargs)
return await retry_transient_errors(self._http_session.request, method, url, **kwargs)
return await self._http_session.request(method, url, **kwargs)

async def close(self) -> None:
Expand Down
Loading