Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[batch] Fix async exit stacks #13969

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions batch/batch/cloud/azure/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ async def create(
machine_name_prefix: str,
namespace: str,
inst_coll_configs: InstanceCollectionConfigs,
task_manager: aiotools.BackgroundTaskManager, # BORROWED
) -> 'AzureDriver':
azure_config = get_azure_config()
subscription_id = azure_config.subscription_id
Expand Down Expand Up @@ -68,6 +67,8 @@ async def create(
app, subscription_id, resource_group, ssh_public_key, arm_client, compute_client, billing_manager
)

task_manager = aiotools.BackgroundTaskManager()

create_pools_coros = [
Pool.create(
app,
Expand Down Expand Up @@ -110,6 +111,7 @@ async def create(
inst_coll_manager,
jpim,
billing_manager,
task_manager,
)

task_manager.ensure_future(periodically_call(60, driver.delete_orphaned_nics))
Expand All @@ -135,6 +137,7 @@ def __init__(
inst_coll_manager: InstanceCollectionManager,
job_private_inst_manager: JobPrivateInstanceManager,
billing_manager: AzureBillingManager,
task_manager: aiotools.BackgroundTaskManager,
):
self.db = db
self.machine_name_prefix = machine_name_prefix
Expand All @@ -150,6 +153,7 @@ def __init__(
self.job_private_inst_manager = job_private_inst_manager
self._billing_manager = billing_manager
self._inst_coll_manager = inst_coll_manager
self._task_manager = task_manager

@property
def billing_manager(self) -> AzureBillingManager:
Expand All @@ -161,18 +165,21 @@ def inst_coll_manager(self) -> InstanceCollectionManager:

async def shutdown(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needn't happen in this PR, but `shutdown` should use an exit stack as well.

try:
await self.arm_client.close()
await self._task_manager.shutdown_and_wait()
finally:
try:
await self.compute_client.close()
await self.arm_client.close()
finally:
try:
await self.resources_client.close()
await self.compute_client.close()
finally:
try:
await self.network_client.close()
await self.resources_client.close()
finally:
await self.pricing_client.close()
try:
await self.network_client.close()
finally:
await self.pricing_client.close()

def _resource_is_orphaned(self, resource_name: str) -> bool:
instance_name = resource_name.rsplit('-', maxsplit=1)[0]
Expand Down
6 changes: 2 additions & 4 deletions batch/batch/cloud/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from gear import Database
from gear.cloud_config import get_global_config
from hailtop import aiotools

from ..driver.driver import CloudDriver
from ..inst_coll_config import InstanceCollectionConfigs
Expand All @@ -14,12 +13,11 @@ async def get_cloud_driver(
machine_name_prefix: str,
namespace: str,
inst_coll_configs: InstanceCollectionConfigs,
task_manager: aiotools.BackgroundTaskManager,
) -> CloudDriver:
cloud = get_global_config()['cloud']

if cloud == 'azure':
return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager)
return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs)

assert cloud == 'gcp', cloud
return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager)
return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs)
15 changes: 11 additions & 4 deletions batch/batch/cloud/gcp/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ async def create(
machine_name_prefix: str,
namespace: str,
inst_coll_configs: InstanceCollectionConfigs,
task_manager: aiotools.BackgroundTaskManager, # BORROWED
) -> 'GCPDriver':
gcp_config = get_gcp_config()
project = gcp_config.project
Expand Down Expand Up @@ -67,6 +66,8 @@ async def create(
inst_coll_manager = InstanceCollectionManager(db, machine_name_prefix, zone_monitor, region, regions)
resource_manager = GCPResourceManager(project, compute_client, billing_manager)

task_manager = aiotools.BackgroundTaskManager()

create_pools_coros = [
Pool.create(
app,
Expand Down Expand Up @@ -105,6 +106,7 @@ async def create(
inst_coll_manager,
jpim,
billing_manager,
task_manager,
)

task_manager.ensure_future(periodically_call(15, driver.process_activity_logs))
Expand All @@ -126,6 +128,7 @@ def __init__(
inst_coll_manager: InstanceCollectionManager,
job_private_inst_manager: JobPrivateInstanceManager,
billing_manager: GCPBillingManager,
task_manager: aiotools.BackgroundTaskManager,
):
self.db = db
self.machine_name_prefix = machine_name_prefix
Expand All @@ -137,6 +140,7 @@ def __init__(
self.job_private_inst_manager = job_private_inst_manager
self._billing_manager = billing_manager
self._inst_coll_manager = inst_coll_manager
self._task_manager = task_manager

@property
def billing_manager(self) -> GCPBillingManager:
Expand All @@ -148,12 +152,15 @@ def inst_coll_manager(self) -> InstanceCollectionManager:

async def shutdown(self) -> None:
    """Wait for background tasks to stop, then close the driver's clients.

    Steps run in this order: ``self._task_manager.shutdown_and_wait()``,
    ``self.compute_client.close()``, ``self.activity_logs_client.close()``
    and ``self._billing_manager.close()``. Each step is awaited exactly once.
    """
    # Every later step sits in a ``finally`` so that it still runs when an
    # earlier step raises.
    try:
        await self._task_manager.shutdown_and_wait()
    finally:
        try:
            await self.compute_client.close()
        finally:
            try:
                await self.activity_logs_client.close()
            finally:
                await self._billing_manager.close()

async def process_activity_logs(self) -> None:
async def _process_activity_log_events_since(mark):
Expand Down
6 changes: 3 additions & 3 deletions batch/batch/driver/canceller.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def __init__(self, app):

self.task_manager = aiotools.BackgroundTaskManager()

def shutdown(self):
async def shutdown_and_wait(self):
try:
self.task_manager.shutdown()
await self.task_manager.shutdown_and_wait()
finally:
self.async_worker_pool.shutdown()
await self.async_worker_pool.shutdown_and_wait()

async def cancel_cancelled_ready_jobs_loop_body(self):
records = self.db.select_and_fetchall(
Expand Down
50 changes: 25 additions & 25 deletions batch/batch/driver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,18 +1558,25 @@ def log(self, request, response, time):


async def on_startup(app):
task_manager = aiotools.BackgroundTaskManager()
app['task_manager'] = task_manager

app['client_session'] = httpx.client_session()
exit_stack = AsyncExitStack()
app['exit_stack'] = exit_stack

kubernetes_asyncio.config.load_incluster_config()
app['k8s_client'] = kubernetes_asyncio.client.CoreV1Api()
app['k8s_cache'] = K8sCache(app['k8s_client'])

async def close_and_wait():
# - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045
# - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925
await app['k8s_client'].api_client.close()
await asyncio.sleep(0.250)

exit_stack.push_async_callback(close_and_wait)

db = Database()
await db.async_init(maxsize=50)
app['db'] = db
exit_stack.push_async_callback(app['db'].async_close)

row = await db.select_and_fetchone(
'''
Expand All @@ -1590,18 +1597,28 @@ async def on_startup(app):
app['cancel_ready_state_changed'] = asyncio.Event()
app['cancel_creating_state_changed'] = asyncio.Event()
app['cancel_running_state_changed'] = asyncio.Event()

app['async_worker_pool'] = AsyncWorkerPool(100, queue_size=100)
exit_stack.push_async_callback(app['async_worker_pool'].shutdown_and_wait)

fs = get_cloud_async_fs()
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
exit_stack.push_async_callback(app['file_store'].close)

inst_coll_configs = await InstanceCollectionConfigs.create(db)

app['driver'] = await get_cloud_driver(
app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs, task_manager
)
app['client_session'] = httpx.client_session()
exit_stack.push_async_callback(app['client_session'].close)

app['driver'] = await get_cloud_driver(app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs)
exit_stack.push_async_callback(app['driver'].shutdown)

app['canceller'] = await Canceller.create(app)
exit_stack.push_async_callback(app['canceller'].shutdown_and_wait)

task_manager = aiotools.BackgroundTaskManager()
app['task_manager'] = task_manager
exit_stack.push_async_callback(app['task_manager'].shutdown_and_wait)

task_manager.ensure_future(periodically_call(10, monitor_billing_limits, app))
task_manager.ensure_future(periodically_call(10, cancel_fast_failing_batches, app))
Expand All @@ -1614,24 +1631,7 @@ async def on_startup(app):

async def on_cleanup(app):
try:
async with AsyncExitStack() as cleanup:
cleanup.callback(app['canceller'].shutdown)
cleanup.callback(app['task_manager'].shutdown)
cleanup.push_async_callback(app['driver'].shutdown)
cleanup.push_async_callback(app['file_store'].shutdown)
cleanup.push_async_callback(app['client_session'].close)
cleanup.callback(app['async_worker_pool'].shutdown)
cleanup.push_async_callback(app['db'].async_close)

k8s: kubernetes_asyncio.client.CoreV1Api = app['k8s_client']

async def close_and_wait():
# - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045
# - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925
await k8s.api_client.close()
await asyncio.sleep(0.250)

cleanup.push_async_callback(close_and_wait)
await app['exit_stack'].aclose()
finally:
await asyncio.gather(*(t for t in asyncio.all_tasks() if t is not asyncio.current_task()))

Expand Down
19 changes: 12 additions & 7 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -2903,12 +2903,16 @@ def log(self, request, response, time):


async def on_startup(app):
app['task_manager'] = aiotools.BackgroundTaskManager()
exit_stack = AsyncExitStack()
app['exit_stack'] = exit_stack

app['client_session'] = httpx.client_session()
exit_stack.push_async_callback(app['client_session'].close)

db = Database()
await db.async_init()
app['db'] = db
exit_stack.push_async_callback(app['db'].async_close)

row = await db.select_and_fetchone(
'''
Expand All @@ -2923,6 +2927,7 @@ async def on_startup(app):
app['instance_id'] = instance_id

app['hail_credentials'] = hail_credentials()
exit_stack.push_async_callback(app['hail_credentials'].close)

app['frozen'] = row['frozen']

Expand All @@ -2937,8 +2942,13 @@ async def on_startup(app):

fs = get_cloud_async_fs()
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
exit_stack.push_async_callback(app['file_store'].close)

app['task_manager'] = aiotools.BackgroundTaskManager()
exit_stack.callback(app['task_manager'].shutdown)

app['inst_coll_configs'] = await InstanceCollectionConfigs.create(db)
exit_stack.push_async_callback(app['file_store'].close)

cancel_batch_state_changed = asyncio.Event()
app['cancel_batch_state_changed'] = cancel_batch_state_changed
Expand All @@ -2958,12 +2968,7 @@ async def on_startup(app):


async def on_cleanup(app):
async with AsyncExitStack() as stack:
stack.callback(app['task_manager'].shutdown)
stack.push_async_callback(app['hail_credentials'].close)
stack.push_async_callback(app['client_session'].close)
stack.push_async_callback(app['file_store'].close)
stack.push_async_callback(app['db'].async_close)
await app['exit_stack'].aclose()


def run():
Expand Down
17 changes: 8 additions & 9 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3040,15 +3040,15 @@ async def shutdown(self):
log.info('Worker.shutdown')
self._jvm_initializer_task.cancel()
async with AsyncExitStack() as cleanup:
cleanup.push_async_callback(self.client_session.close)
if self.fs:
cleanup.push_async_callback(self.fs.close)
if self.file_store:
cleanup.push_async_callback(self.file_store.close)
for jvmqueue in self._jvmpools_by_cores.values():
while not jvmqueue.queue.empty():
cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
if self.file_store:
cleanup.push_async_callback(self.file_store.close)
if self.fs:
cleanup.push_async_callback(self.fs.close)
cleanup.push_async_callback(self.client_session.close)

async def run_job(self, job):
try:
Expand Down Expand Up @@ -3476,11 +3476,10 @@ async def async_main():
with aiomonitor.start_monitor(asyncio.get_event_loop(), locals=locals()):
try:
async with AsyncExitStack() as cleanup:
cleanup.push_async_callback(worker.shutdown)
cleanup.push_async_callback(CLOUD_WORKER_API.close)
cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait)
cleanup.push_async_callback(docker.close)

cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait)
cleanup.push_async_callback(CLOUD_WORKER_API.close)
cleanup.push_async_callback(worker.shutdown)
await worker.run()
finally:
asyncio.get_event_loop().set_debug(True)
Expand Down
3 changes: 2 additions & 1 deletion hail/python/hailtop/aiotools/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ def shutdown(self):

async def shutdown_and_wait(self):
    """Call ``shutdown()`` and then wait for every task in ``self.tasks``.

    Returns once all tracked tasks have completed, or immediately after
    ``shutdown()`` when no tasks are tracked.
    """
    self.shutdown()
    # asyncio.wait() raises ValueError when handed an empty collection, so
    # only wait when at least one task is being tracked.
    if self.tasks:
        await asyncio.wait(self.tasks, return_when=asyncio.ALL_COMPLETED)
4 changes: 4 additions & 0 deletions hail/python/hailtop/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ def shutdown(self):
except Exception:
pass

async def shutdown_and_wait(self):
    """Call ``shutdown()`` and then wait for all of ``self.workers`` to finish.

    The workers are gathered with ``return_exceptions=True``, so an exception
    raised by a worker is collected as a result rather than re-raised here.
    The gathered results are discarded.
    """
    self.shutdown()
    workers = self.workers
    await asyncio.gather(*workers, return_exceptions=True)


class WaitableSharedPool:
def __init__(self, worker_pool: AsyncWorkerPool):
Expand Down