From 9f03cf42650db6105b16df9ae4c08f3255d7fbf1 Mon Sep 17 00:00:00 2001
From: Daniel Goldstein
Date: Fri, 27 Jan 2023 14:41:10 -0500
Subject: [PATCH] [batch] Don't use Batch identity on workers (#12611)

* [batch] Don't use Batch identity on workers

* no longer create one fs per JVM container

* make the worker own its own task manager
---
 batch/batch/cloud/azure/worker/disk.py        |   3 -
 batch/batch/cloud/azure/worker/worker_api.py  |  16 +-
 batch/batch/cloud/gcp/worker/disk.py          |  18 +-
 batch/batch/cloud/gcp/worker/worker_api.py    |  24 +-
 batch/batch/driver/main.py                    |   5 +-
 batch/batch/front_end/front_end.py            |   2 +-
 batch/batch/worker/disk.py                    |   5 -
 batch/batch/worker/worker.py                  | 208 +++++++++---------
 batch/batch/worker/worker_api.py              |  10 +-
 gear/gear/clients.py                          |  17 +-
 .../hailtop/aiocloud/aiogoogle/credentials.py |   4 +-
 11 files changed, 164 insertions(+), 148 deletions(-)

diff --git a/batch/batch/cloud/azure/worker/disk.py b/batch/batch/cloud/azure/worker/disk.py
index 194748ccdee..8a632be8cca 100644
--- a/batch/batch/cloud/azure/worker/disk.py
+++ b/batch/batch/cloud/azure/worker/disk.py
@@ -22,6 +22,3 @@ async def create(self, labels=None):
 
     async def delete(self):
         raise NotImplementedError
-
-    async def close(self):
-        raise NotImplementedError
diff --git a/batch/batch/cloud/azure/worker/worker_api.py b/batch/batch/cloud/azure/worker/worker_api.py
index ff858459213..7a7badd081a 100644
--- a/batch/batch/cloud/azure/worker/worker_api.py
+++ b/batch/batch/cloud/azure/worker/worker_api.py
@@ -4,7 +4,9 @@
 
 import aiohttp
 
+from gear.cloud_config import get_azure_config
 from hailtop import httpx
+from hailtop.aiocloud import aioazure
 from hailtop.utils import request_retry_transient_errors, time_msecs
 
 from ....worker.worker_api import CloudWorkerAPI
@@ -14,6 +16,8 @@
 
 
 class AzureWorkerAPI(CloudWorkerAPI):
+    nameserver_ip = '168.63.129.16'
+
     @staticmethod
     def from_env():
         subscription_id = os.environ['SUBSCRIPTION_ID']
@@ -26,14 +30,18 @@ def __init__(self, subscription_id: str, resource_group: str, acr_url: str):
         self.subscription_id = subscription_id
         self.resource_group = resource_group
         self.acr_refresh_token = AcrRefreshToken(acr_url, AadAccessToken())
-
-    @property
-    def nameserver_ip(self):
-        return '168.63.129.16'
+        self.azure_credentials = aioazure.AzureCredentials.default_credentials()
 
     def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> AzureDisk:
         return AzureDisk(disk_name, instance_name, size_in_gb, mount_path)
 
+    def get_cloud_async_fs(self) -> aioazure.AzureAsyncFS:
+        return aioazure.AzureAsyncFS(credentials=self.azure_credentials)
+
+    def get_compute_client(self) -> aioazure.AzureComputeClient:
+        azure_config = get_azure_config()
+        return aioazure.AzureComputeClient(azure_config.subscription_id, azure_config.resource_group)
+
     def user_credentials(self, credentials: Dict[str, bytes]) -> AzureUserCredentials:
         return AzureUserCredentials(credentials)
diff --git a/batch/batch/cloud/gcp/worker/disk.py b/batch/batch/cloud/gcp/worker/disk.py
index 963f4e1c898..6b3a648ee85 100644
--- a/batch/batch/cloud/gcp/worker/disk.py
+++ b/batch/batch/cloud/gcp/worker/disk.py
@@ -11,22 +11,29 @@
 
 
 class GCPDisk(CloudDisk):
-    def __init__(self, name: str, zone: str, project: str, instance_name: str, size_in_gb: int, mount_path: str):
+    def __init__(
+        self,
+        name: str,
+        zone: str,
+        project: str,
+        instance_name: str,
+        size_in_gb: int,
+        mount_path: str,
+        compute_client: aiogoogle.GoogleComputeClient,  # BORROWED
+    ):
         assert size_in_gb >= 10
         # disk name must be 63 characters or less
         # https://cloud.google.com/compute/docs/reference/rest/v1/disks#resource:-disk
         # under the information for the name field
         assert len(name) <= 63
-        self.compute_client = aiogoogle.GoogleComputeClient(
-            project, credentials=aiogoogle.GoogleCredentials.from_file('/worker-key.json')
-        )
         self.name = name
         self.zone = zone
         self.project = project
         self.instance_name = instance_name
         self.size_in_gb = size_in_gb
         self.mount_path = mount_path
+        self.compute_client = compute_client
         self._created = False
         self._attached = False
@@ -49,9 +56,6 @@ async def delete(self):
         finally:
             await self._delete()
 
-    async def close(self):
-        await self.compute_client.close()
-
     async def _unmount(self):
         if self._attached:
             await retry_all_errors_n_times(
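GCPDisk now borrows the worker's single compute client instead of building its own from /worker-key.json, which is why close() disappears from the disk classes. A minimal runnable sketch of this borrow-don't-own pattern (all names below are illustrative stand-ins, not the real aiogoogle types):

    import asyncio

    class ComputeClient:
        """Stand-in for aiogoogle.GoogleComputeClient: owns network resources."""

        async def close(self) -> None:
            pass  # would release HTTP sessions, sockets, ...

    class Disk:
        def __init__(self, name: str, compute_client: ComputeClient):  # BORROWED
            self.name = name
            self.compute_client = compute_client  # shared; never closed here

    async def main() -> None:
        client = ComputeClient()  # created once, owned by the worker
        try:
            disks = [Disk(f'disk-{i}', client) for i in range(3)]
            print([d.name for d in disks])
        finally:
            await client.close()  # closed exactly once, by its owner

    asyncio.run(main())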
diff --git a/batch/batch/cloud/gcp/worker/worker_api.py b/batch/batch/cloud/gcp/worker/worker_api.py
index ce7789c57a7..adb64022c61 100644
--- a/batch/batch/cloud/gcp/worker/worker_api.py
+++ b/batch/batch/cloud/gcp/worker/worker_api.py
@@ -4,6 +4,7 @@
 import aiohttp
 
 from hailtop import httpx
+from hailtop.aiocloud import aiogoogle
 from hailtop.utils import request_retry_transient_errors
 
 from ....worker.worker_api import CloudWorkerAPI
@@ -13,19 +14,21 @@
 
 
 class GCPWorkerAPI(CloudWorkerAPI):
+    nameserver_ip = '169.254.169.254'
+
+    # async because GoogleSession must be created inside a running event loop
     @staticmethod
-    def from_env():
+    async def from_env() -> 'GCPWorkerAPI':
         project = os.environ['PROJECT']
         zone = os.environ['ZONE'].rsplit('/', 1)[1]
-        return GCPWorkerAPI(project, zone)
+        session = aiogoogle.GoogleSession()
+        return GCPWorkerAPI(project, zone, session)
 
-    def __init__(self, project: str, zone: str):
+    def __init__(self, project: str, zone: str, session: aiogoogle.GoogleSession):
         self.project = project
         self.zone = zone
-
-    @property
-    def nameserver_ip(self):
-        return '169.254.169.254'
+        self._google_session = session
+        self._compute_client = aiogoogle.GoogleComputeClient(project, session=session)
 
     def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> GCPDisk:
         return GCPDisk(
@@ -35,8 +38,15 @@ def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount
             name=disk_name,
             size_in_gb=size_in_gb,
             mount_path=mount_path,
+            compute_client=self._compute_client,
         )
 
+    def get_cloud_async_fs(self) -> aiogoogle.GoogleStorageAsyncFS:
+        return aiogoogle.GoogleStorageAsyncFS(session=self._google_session)
+
+    def get_compute_client(self) -> aiogoogle.GoogleComputeClient:
+        return self._compute_client
+
     def user_credentials(self, credentials: Dict[str, bytes]) -> GCPUserCredentials:
         return GCPUserCredentials(credentials)
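GCPWorkerAPI.from_env is now async because, per the comment above, a GoogleSession must be created inside a running event loop. A minimal runnable sketch of what that implies for callers (GCPWorkerAPIStub and _Session are stand-ins for the real classes):

    import asyncio
    import os

    class _Session:
        """Stand-in for aiogoogle.GoogleSession, which must be created
        inside a running event loop (it builds an aiohttp session)."""

    class GCPWorkerAPIStub:
        def __init__(self, project: str, zone: str, session: _Session):
            self.project, self.zone, self._session = project, zone, session

        @staticmethod
        async def from_env() -> 'GCPWorkerAPIStub':
            project = os.environ.get('PROJECT', 'my-project')
            zone = os.environ.get('ZONE', 'zones/us-central1-a').rsplit('/', 1)[1]
            return GCPWorkerAPIStub(project, zone, _Session())

    async def main() -> None:
        api = await GCPWorkerAPIStub.from_env()  # must be awaited, unlike before
        print(api.project, api.zone)

    asyncio.run(main())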
diff --git a/batch/batch/driver/main.py b/batch/batch/driver/main.py
index d3b3b42dccf..9ffb84c0fb9 100644
--- a/batch/batch/driver/main.py
+++ b/batch/batch/driver/main.py
@@ -263,6 +263,7 @@ async def get_gsa_key(request, instance):  # pylint: disable=unused-argument
     return await asyncio.shield(get_gsa_key_1(instance))
 
 
+# deprecated
 @routes.get('/api/v1alpha/instances/credentials')
 @activating_instances_only
 async def get_credentials(request, instance):  # pylint: disable=unused-argument
@@ -1325,12 +1326,12 @@ async def on_startup(app):
     app['cancel_running_state_changed'] = asyncio.Event()
     app['async_worker_pool'] = AsyncWorkerPool(100, queue_size=100)
 
-    credentials_file = '/gsa-key/key.json'
-    fs = get_cloud_async_fs(credentials_file=credentials_file)
+    fs = get_cloud_async_fs()
     app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
 
     inst_coll_configs = await InstanceCollectionConfigs.create(db)
 
+    credentials_file = '/gsa-key/key.json'
     app['driver'] = await get_cloud_driver(
         app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs, credentials_file, task_manager
     )
diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py
index aea8ccdca42..cacf6f7a0e9 100644
--- a/batch/batch/front_end/front_end.py
+++ b/batch/batch/front_end/front_end.py
@@ -2806,7 +2806,7 @@ async def on_startup(app):
     assert max(regions.values()) < 64, str(regions)
     app['regions'] = regions
 
-    fs = get_cloud_async_fs(credentials_file='/gsa-key/key.json')
+    fs = get_cloud_async_fs()
     app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
 
     app['inst_coll_configs'] = await InstanceCollectionConfigs.create(db)
diff --git a/batch/batch/worker/disk.py b/batch/batch/worker/disk.py
index acf138a693d..4e5940ab15e 100644
--- a/batch/batch/worker/disk.py
+++ b/batch/batch/worker/disk.py
@@ -13,7 +13,6 @@ async def __aenter__(self, labels=None):
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.delete()
-        await self.close()
 
     @abc.abstractmethod
     async def create(self, labels=None):
@@ -22,7 +21,3 @@ async def create(self, labels=None):
     @abc.abstractmethod
     async def delete(self):
         pass
-
-    @abc.abstractmethod
-    async def close(self):
-        pass
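With close() gone from the abstract base, a CloudDisk's whole lifecycle is create-on-enter, delete-on-exit. A minimal sketch of that contract (SketchDisk is illustrative, not the real GCPDisk/AzureDisk):

    import asyncio

    class SketchDisk:
        def __init__(self, name: str):
            self.name = name

        async def __aenter__(self):
            await self.create()
            return self

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            await self.delete()  # no close(): the disk owns no client to tear down

        async def create(self, labels=None):
            print(f'created {self.name}')

        async def delete(self):
            print(f'deleted {self.name}')

    async def main() -> None:
        async with SketchDisk('disk-0'):
            pass  # attach, format, mount, ...

    asyncio.run(main())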
diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py
index 715d176c943..6b6b23518fc 100644
--- a/batch/batch/worker/worker.py
+++ b/batch/batch/worker/worker.py
@@ -15,20 +15,8 @@
 import uuid
 import warnings
 from collections import defaultdict
-from contextlib import AsyncExitStack, ExitStack, contextmanager
-from typing import (
-    Any,
-    Awaitable,
-    Callable,
-    ContextManager,
-    Dict,
-    Iterator,
-    List,
-    MutableMapping,
-    Optional,
-    Tuple,
-    Union,
-)
+from contextlib import AsyncExitStack, ExitStack
+from typing import Any, Awaitable, Callable, ContextManager, Dict, List, MutableMapping, Optional, Tuple, Union
 
 import aiodocker  # type: ignore
 import aiodocker.images
@@ -41,7 +29,6 @@
 from aiohttp import web
 from sortedcontainers import SortedSet
 
-from gear.clients import get_cloud_async_fs, get_compute_client
 from hailtop import aiotools, httpx
 from hailtop.aiotools import AsyncFS, LocalAsyncFS
 from hailtop.aiotools.router_fs import RouterAsyncFS
@@ -79,6 +66,7 @@
 )
 from ..file_store import FileStore
 from ..globals import HTTP_CLIENT_MAX_SIZE, RESERVED_STORAGE_GB_PER_CORE, STATUS_FORMAT_VERSION
+from ..instance_config import InstanceConfig
 from ..publicly_available_images import publicly_available_images
 from ..resource_usage import ResourceUsageMonitor
 from ..semaphore import FIFOWeightedSemaphore
@@ -173,7 +161,7 @@ def compose(auth: Union[MutableMapping, str, bytes], registry_addr: Optional[str
 ACCEPTABLE_QUERY_JAR_URL_PREFIX = os.environ['ACCEPTABLE_QUERY_JAR_URL_PREFIX']
 assert len(ACCEPTABLE_QUERY_JAR_URL_PREFIX) > 3  # x:// where x is one or more characters
 
-CLOUD_WORKER_API: CloudWorkerAPI = GCPWorkerAPI.from_env() if CLOUD == 'gcp' else AzureWorkerAPI.from_env()
+CLOUD_WORKER_API: Optional[CloudWorkerAPI] = None
 
 log.info(f'CLOUD {CLOUD}')
 log.info(f'CORES {CORES}')
@@ -185,7 +173,6 @@
 log.info(f'INSTANCE_ID {INSTANCE_ID}')
 log.info(f'DOCKER_PREFIX {DOCKER_PREFIX}')
 log.info(f'INSTANCE_CONFIG {INSTANCE_CONFIG}')
-log.info(f'CLOUD_WORKER_API {CLOUD_WORKER_API}')
 log.info(f'MAX_IDLE_TIME_MSECS {MAX_IDLE_TIME_MSECS}')
 log.info(f'BATCH_WORKER_IMAGE {BATCH_WORKER_IMAGE}')
 log.info(f'BATCH_WORKER_IMAGE_ID {BATCH_WORKER_IMAGE_ID}')
@@ -194,9 +181,7 @@ def compose(auth: Union[MutableMapping, str, bytes], registry_addr: Optional[str
 log.info(f'ACCEPTABLE_QUERY_JAR_URL_PREFIX {ACCEPTABLE_QUERY_JAR_URL_PREFIX}')
 log.info(f'REGION {REGION}')
 
-instance_config = CLOUD_WORKER_API.instance_config_from_config_dict(INSTANCE_CONFIG)
-assert instance_config.cores == CORES
-assert instance_config.cloud == CLOUD
+instance_config: Optional[InstanceConfig] = None
 
 N_SLOTS = 4 * CORES  # Jobs are allowed at minimum a quarter core
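CLOUD_WORKER_API and instance_config become Optional module globals that async_main fills in once an event loop is running, which is why assert statements appear at each use site in the hunks below. The pattern in miniature (WorkerAPI and WORKER_API are illustrative names):

    import asyncio
    from typing import Optional

    class WorkerAPI:
        nameserver_ip = '169.254.169.254'

    WORKER_API: Optional[WorkerAPI] = None

    def use_api() -> str:
        assert WORKER_API  # narrows Optional[WorkerAPI] and fails fast if init was skipped
        return WORKER_API.nameserver_ip

    async def async_main() -> None:
        global WORKER_API
        WORKER_API = WorkerAPI()  # would be `await GCPWorkerAPI.from_env()` in the worker
        print(use_api())

    asyncio.run(async_main())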
@@ -268,6 +253,7 @@ async def init(self):
         # resolver.
         with open(f'/etc/netns/{self.network_ns_name}/resolv.conf', 'w', encoding='utf-8') as resolv:
             if self.private:
+                assert CLOUD_WORKER_API
                 resolv.write(f'nameserver {CLOUD_WORKER_API.nameserver_ip}\n')
                 if CLOUD == 'gcp':
                     resolv.write('search c.hail-vdc.internal google.internal\n')
@@ -334,7 +320,8 @@ async def cleanup(self):
 
 
 class NetworkAllocator:
-    def __init__(self):
+    def __init__(self, task_manager: aiotools.BackgroundTaskManager):
+        self.task_manager = task_manager
         self.private_networks: asyncio.Queue[NetworkNamespace] = asyncio.Queue()
         self.public_networks: asyncio.Queue[NetworkNamespace] = asyncio.Queue()
         self.internet_interface = INTERNET_INTERFACE
@@ -357,7 +344,7 @@ async def allocate_public(self) -> NetworkNamespace:
         return await self.public_networks.get()
 
     def free(self, netns: NetworkNamespace):
-        asyncio.ensure_future(self._free(netns))
+        self.task_manager.ensure_future(self._free(netns))
 
     async def _free(self, netns: NetworkNamespace):
         await netns.cleanup()
@@ -497,6 +484,7 @@ async def _ensure_image_is_pulled(self, auth: Optional[Callable[..., Awaitable[O
             raise
 
     async def _batch_worker_access_token(self) -> Dict[str, str]:
+        assert CLOUD_WORKER_API
         return await CLOUD_WORKER_API.worker_access_token(self.client_session)
 
     async def _current_user_access_token(self) -> Dict[str, str]:
@@ -550,9 +538,15 @@ class StepInterruptedError(Exception):
     pass
 
 
-async def run_until_done_or_deleted(event: asyncio.Event, f: Callable[..., Awaitable[Any]], *args, **kwargs):
-    step = asyncio.ensure_future(f(*args, **kwargs))
-    deleted = asyncio.ensure_future(event.wait())
+async def run_until_done_or_deleted(
+    task_manager: aiotools.BackgroundTaskManager,
+    event: asyncio.Event,
+    f: Callable[..., Awaitable[Any]],
+    *args,
+    **kwargs,
+):
+    step = task_manager.ensure_future(f(*args, **kwargs))
+    deleted = task_manager.ensure_future(event.wait())
     try:
         await asyncio.wait([deleted, step], return_when=asyncio.FIRST_COMPLETED)
         if deleted.done():
@@ -630,6 +624,7 @@ def user_error(e):
 class Container:
     def __init__(
         self,
+        task_manager: aiotools.BackgroundTaskManager,
         fs: AsyncFS,
         name: str,
         image: Image,
@@ -645,8 +640,8 @@ def __init__(
         env: Optional[List[str]] = None,
         stdin: Optional[str] = None,
     ):
+        self.task_manager = task_manager
         self.fs = fs
-        assert self.fs
 
         self.name = name
         self.image = image
@@ -724,7 +719,7 @@ async def create(self):
                 raise ContainerCreateError from e
             raise
 
-    async def start(self):
+    def start(self):
         async def _run():
             self.state = 'running'
             try:
@@ -757,7 +752,7 @@ async def _run():
                 raise ContainerStartError from e
             raise
 
-        self._run_fut = asyncio.ensure_future(self._run_until_done_or_deleted(_run))
+        self._run_fut = self.task_manager.ensure_future(self._run_until_done_or_deleted(_run))
 
     async def wait(self):
         assert self._run_fut
@@ -775,7 +770,7 @@ async def run(
         async with self._cleanup_lock:
             try:
                 await self.create()
-                await self.start()
+                self.start()
                 await self.wait()
             finally:
                 try:
@@ -867,7 +862,7 @@ async def remove(self):
 
     async def _run_until_done_or_deleted(self, f: Callable[..., Awaitable[Any]], *args, **kwargs):
         try:
-            return await run_until_done_or_deleted(self.deleted_event, f, *args, **kwargs)
+            return await run_until_done_or_deleted(self.task_manager, self.deleted_event, f, *args, **kwargs)
         except StepInterruptedError as e:
             raise ContainerDeletedError from e
@@ -1260,6 +1255,7 @@ def copy_container(
     ]
 
     return Container(
+        task_manager=job.task_manager,
         fs=job.worker.fs,
         name=job.container_name(task_name),
         image=Image(BATCH_WORKER_IMAGE, CopyStepCredentials(), client_session, job.pool),
@@ -1378,6 +1374,7 @@ def __init__(
             extra_storage_in_gib = job_spec['resources']['storage_gib']
             assert extra_storage_in_gib == 0 or is_valid_storage_request(CLOUD, extra_storage_in_gib)
 
+            assert instance_config
             if instance_config.job_private:
                 self.external_storage_in_gib = 0
                 self.data_disk_storage_in_gib = extra_storage_in_gib
@@ -1592,6 +1589,7 @@ def __init__(
         assert self.worker.fs
 
         containers['main'] = Container(
+            task_manager=self.task_manager,
             fs=self.worker.fs,
             name=self.container_name('main'),
             image=Image(job_spec['process']['image'], self.credentials, client_session, pool),
@@ -1629,12 +1627,14 @@ def container_name(self, task_name: str):
         return f'batch-{self.batch_id}-job-{self.job_id}-{task_name}'
 
     async def setup_io(self):
+        assert instance_config
         if not instance_config.job_private:
             if self.worker.data_disk_space_remaining.value < self.external_storage_in_gib:
                 log.info(
                     f'worker data disk storage is full: {self.external_storage_in_gib}Gi requested and {self.worker.data_disk_space_remaining}Gi remaining'
                 )
 
+                assert CLOUD_WORKER_API
                 # disk name must be 63 characters or less
                 # https://cloud.google.com/compute/docs/reference/rest/v1/disks#resource:-disk
                 # under the information for the name field
@@ -1730,6 +1730,7 @@ async def run(self):
                     f'xfs_quota -x -c "project -s -p {self.cloudfuse_base_path()} {self.project_id}" /host/'
                 )
 
+                assert CLOUD_WORKER_API
                 for config in self.cloudfuse:
                     bucket = config['bucket']
                     assert bucket
@@ -1806,8 +1807,6 @@ async def cleanup(self):
                 raise
             except Exception:
                 log.exception(f'while detaching and deleting disk {self.disk.name} for {self.id}')
-            finally:
-                await self.disk.close()
         else:
             self.worker.data_disk_space_remaining.value += self.external_storage_in_gib
@@ -1819,6 +1818,7 @@ async def cleanup(self):
                 mount_path = self.cloudfuse_data_path(bucket)
 
                 try:
+                    assert CLOUD_WORKER_API
                     await CLOUD_WORKER_API.unmount_cloudfuse(mount_path)
                     log.info(f'unmounted fuse blob storage {bucket} from {mount_path}')
                     config['mounted'] = False
@@ -1906,7 +1906,7 @@ def step(self, name):
 
     async def run_until_done_or_deleted(self, f: Callable[..., Awaitable[Any]], *args, **kwargs):
         try:
-            return await run_until_done_or_deleted(self.deleted_event, f, *args, **kwargs)
+            return await run_until_done_or_deleted(self.worker.task_manager, self.deleted_event, f, *args, **kwargs)
         except StepInterruptedError as e:
             raise JobDeletedError from e
@@ -2104,15 +2104,6 @@ def __str__(self):
     )
 
 
-@contextmanager
-def scoped_ensure_future(coro_or_future, *, loop=None) -> Iterator[asyncio.Future]:
-    fut = asyncio.ensure_future(coro_or_future, loop=loop)
-    try:
-        yield fut
-    finally:
-        fut.cancel()
-
-
 class JVMCreationError(Exception):
     pass
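Throughout these hunks, bare asyncio.ensure_future calls (and the scoped_ensure_future helper just removed) give way to a task manager the worker owns, so shutdown can cancel and await stragglers instead of leaking them. An illustrative stand-in for hailtop.aiotools.BackgroundTaskManager — the real class may differ in detail:

    import asyncio
    from typing import Coroutine, Set

    class TaskManagerSketch:
        """Illustrative stand-in for hailtop.aiotools.BackgroundTaskManager."""

        def __init__(self) -> None:
            self.tasks: Set[asyncio.Task] = set()

        def ensure_future(self, coro: Coroutine) -> asyncio.Task:
            task = asyncio.ensure_future(coro)
            self.tasks.add(task)  # track it so shutdown can find it
            task.add_done_callback(self.tasks.discard)
            return task

        async def shutdown_and_wait(self) -> None:
            for task in self.tasks:
                task.cancel()
            await asyncio.gather(*self.tasks, return_exceptions=True)

    async def main() -> None:
        tm = TaskManagerSketch()
        tm.ensure_future(asyncio.sleep(3600))  # would previously leak on shutdown
        await tm.shutdown_and_wait()           # cancelled and awaited instead

    asyncio.run(main())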
@@ -2134,10 +2125,13 @@ async def create_and_start(
         root_dir: str,
         client_session: httpx.ClientSession,
         pool: concurrent.futures.ThreadPoolExecutor,
+        fs: AsyncFS,
+        task_manager: aiotools.BackgroundTaskManager,
     ):
         assert os.path.commonpath([socket_file, root_dir]) == root_dir
         assert os.path.isdir(root_dir)
 
+        assert instance_config
         total_memory_bytes = n_cores * worker_memory_per_core_bytes(CLOUD, instance_config.worker_type())
 
         # We allocate 60% of memory per core to off heap memory
@@ -2188,9 +2182,8 @@ async def create_and_start(
             },
         ]
 
-        fs = LocalAsyncFS(pool)  # worker does not have a fs when initializing JVMs
-
         c = Container(
+            task_manager=task_manager,
             fs=fs,
             name=f'jvm-{index}',
             image=Image(BATCH_WORKER_IMAGE, JVMUserCredentials(), client_session, pool),
@@ -2203,13 +2196,13 @@ async def create_and_start(
         )
 
         await c.create()
-        await c.start()
+        c.start()
 
         return JVMContainer(c, fs)
 
-    def __init__(self, container: Container, fs: LocalAsyncFS):
+    def __init__(self, container: Container, fs: AsyncFS):
         self.container = container
-        self.fs: Optional[LocalAsyncFS] = fs
+        self.fs: AsyncFS = fs
 
     @property
     def returncode(self) -> Optional[int]:
@@ -2218,9 +2211,6 @@ def returncode(self) -> Optional[int]:
         return self.container.process.returncode
 
     async def remove(self):
-        if self.fs is not None:
-            await self.fs.close()
-            self.fs = None
         await self.container.remove()
@@ -2246,9 +2236,13 @@ async def create_container_and_connect(
         root_dir: str,
         client_session: httpx.ClientSession,
         pool: concurrent.futures.ThreadPoolExecutor,
+        fs: AsyncFS,
+        task_manager: aiotools.BackgroundTaskManager,
     ) -> JVMContainer:
         try:
-            container = await JVMContainer.create_and_start(index, n_cores, socket_file, root_dir, client_session, pool)
+            container = await JVMContainer.create_and_start(
+                index, n_cores, socket_file, root_dir, client_session, pool, fs, task_manager
+            )
 
             attempts = 0
             delay = 0.25
@@ -2290,7 +2284,14 @@ async def create(cls, index: int, n_cores: int, worker: 'Worker'):
         should_interrupt = asyncio.Event()
         await blocking_to_async(worker.pool, os.makedirs, root_dir)
         container = await cls.create_container_and_connect(
-            index, n_cores, socket_file, root_dir, worker.client_session, worker.pool
+            index,
+            n_cores,
+            socket_file,
+            root_dir,
+            worker.client_session,
+            worker.pool,
+            worker.fs,
+            worker.task_manager,
         )
         return cls(
             index,
@@ -2302,6 +2303,8 @@ async def create(cls, index: int, n_cores: int, worker: 'Worker'):
             container,
             worker.client_session,
             worker.pool,
+            worker.fs,
+            worker.task_manager,
         )
 
     def __init__(
@@ -2315,6 +2318,8 @@ def __init__(
         container: JVMContainer,
         client_session: httpx.ClientSession,
         pool: concurrent.futures.ThreadPoolExecutor,
+        fs: AsyncFS,
+        task_manager: aiotools.BackgroundTaskManager,
     ):
         self.index = index
         self.n_cores = n_cores
@@ -2325,6 +2330,8 @@ def __init__(
         self.container = container
         self.client_session = client_session
         self.pool = pool
+        self.fs = fs
+        self.task_manager = task_manager
 
     def __str__(self):
         return f'JVM-{self.index}'
@@ -2354,7 +2361,14 @@ async def new_connection(self):
             await blocking_to_async(self.pool, shutil.rmtree, f'{self.root_dir}/container', ignore_errors=True)
 
             container = await self.create_container_and_connect(
-                self.index, self.n_cores, self.socket_file, self.root_dir, self.client_session, self.pool
+                self.index,
+                self.n_cores,
+                self.socket_file,
+                self.root_dir,
+                self.client_session,
+                self.pool,
+                self.fs,
+                self.task_manager,
             )
             self.container = container
@@ -2375,9 +2389,9 @@ async def execute(self, classpath: str, scratch_dir: str, log_file: str, jar_url
                 write_str(writer, part)
             await writer.drain()
 
-            wait_for_message_from_container: asyncio.Future = asyncio.ensure_future(read_int(reader))
+            wait_for_message_from_container: asyncio.Future = self.task_manager.ensure_future(read_int(reader))
             stack.callback(wait_for_message_from_container.cancel)
-            wait_for_interrupt: asyncio.Future = asyncio.ensure_future(self.should_interrupt.wait())
+            wait_for_interrupt: asyncio.Future = self.task_manager.ensure_future(self.should_interrupt.wait())
             stack.callback(wait_for_interrupt.cancel)
 
             await asyncio.wait(
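The execute hunk above races the container's reply against an interrupt and relies on ExitStack callbacks to cancel whichever future loses. The same pattern in isolation (names and the toy operation are illustrative):

    import asyncio
    from contextlib import ExitStack

    async def race(slow_op, interrupt_event: asyncio.Event) -> str:
        with ExitStack() as stack:
            op = asyncio.ensure_future(slow_op())
            stack.callback(op.cancel)  # losers are always cancelled on exit
            interrupted = asyncio.ensure_future(interrupt_event.wait())
            stack.callback(interrupted.cancel)

            await asyncio.wait([op, interrupted], return_when=asyncio.FIRST_COMPLETED)
            if interrupted.done():
                return 'interrupted'
            return f'finished: {op.result()}'

    async def main() -> None:
        event = asyncio.Event()
        print(await race(lambda: asyncio.sleep(0.01, result=42), event))

    asyncio.run(main())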
@@ -2437,16 +2451,25 @@ def __init__(self, client_session: httpx.ClientSession):
         self.image_data: Dict[str, ImageData] = defaultdict(ImageData)
         self.image_data[BATCH_WORKER_IMAGE_ID] += 1
 
-        # filled in during activation
-        self.fs: Optional[RouterAsyncFS] = None
-        self.file_store: Optional[FileStore] = None
+        assert CLOUD_WORKER_API
+        fs = CLOUD_WORKER_API.get_cloud_async_fs()
+        self.fs = RouterAsyncFS(
+            'file',
+            filesystems=[
+                LocalAsyncFS(self.pool),
+                fs,
+            ],
+        )
+        self.file_store = FileStore(fs, BATCH_LOGS_STORAGE_URI, INSTANCE_ID)
+        self.compute_client = CLOUD_WORKER_API.get_compute_client()
+
         self.headers: Optional[Dict[str, str]] = None
-        self.compute_client = None
 
-        self._jvm_initializer_task = asyncio.ensure_future(self._initialize_jvms())
+        self._jvm_initializer_task = self.task_manager.ensure_future(self._initialize_jvms())
         self._jvms: SortedSet[JVM] = SortedSet([], key=lambda jvm: jvm.n_cores)
 
     async def _initialize_jvms(self):
+        assert instance_config
         if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'):
             jvms: List[Awaitable[JVM]] = []
             for jvm_cores in (1, 2, 4, 8):
@@ -2456,6 +2479,7 @@ async def _initialize_jvms(self):
         log.info(f'JVMs initialized {self._jvms}')
 
     async def borrow_jvm(self, n_cores: int) -> JVM:
+        assert instance_config
         if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
             raise ValueError(f'no JVMs available on {instance_config.worker_type()}')
         await self._jvm_initializer_task
@@ -2546,6 +2570,7 @@ async def create_job_1(self, request):
         if not self.active:
             return web.HTTPServiceUnavailable()
 
+        assert CLOUD_WORKER_API
         credentials = CLOUD_WORKER_API.user_credentials(body['gsa_key'])
 
         job = Job.create(
@@ -2827,31 +2852,6 @@ async def post_job_started(self, job):
     async def activate(self):
         log.info('activating')
 
-        resp = await request_retry_transient_errors(
-            self.client_session,
-            'GET',
-            deploy_config.url('batch-driver', '/api/v1alpha/instances/credentials'),
-            headers={'X-Hail-Instance-Name': NAME, 'Authorization': f'Bearer {os.environ["ACTIVATION_TOKEN"]}'},
-        )
-        resp_json = await resp.json()
-
-        credentials_file = '/worker-key.json'
-        with open(credentials_file, 'w', encoding='utf-8') as f:
-            f.write(json.dumps(resp_json['key']))
-
-        self.fs = RouterAsyncFS(
-            'file',
-            filesystems=[
-                LocalAsyncFS(self.pool),
-                get_cloud_async_fs(credentials_file=credentials_file),
-            ],
-        )
-
-        fs = get_cloud_async_fs(credentials_file=credentials_file)
-        self.file_store = FileStore(fs, BATCH_LOGS_STORAGE_URI, INSTANCE_ID)
-
-        self.compute_client = get_compute_client(credentials_file=credentials_file)
-
         resp = await request_retry_transient_errors(
             self.client_session,
             'POST',
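Since activate no longer fetches a key from the driver and writes /worker-key.json, workers rely on ambient credentials. A sketch of the fallback chain that the credentials.py hunk at the end of this patch logs its way through — explicit key file, then instance metadata, then anonymous (the helpers here are illustrative; the real logic lives in hailtop.aiocloud.aiogoogle.credentials):

    import os
    from typing import Optional

    def credentials_file_from_env() -> Optional[str]:
        return os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')

    def instance_metadata_available() -> bool:
        return True  # stand-in for probing the metadata server (169.254.169.254)

    def default_credentials() -> str:
        path = credentials_file_from_env()
        if path:
            return f'file:{path}'       # explicit key file wins
        if instance_metadata_available():
            return 'instance-metadata'  # batch workers take this branch
        return 'anonymous'              # last resort, public data only

    print(default_credentials())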
@@ -2915,13 +2915,20 @@ async def update():
 
 
 async def async_main():
-    global port_allocator, network_allocator, worker, docker, image_lock
+    global port_allocator, network_allocator, worker, docker, image_lock, CLOUD_WORKER_API, instance_config
 
     image_lock = aiorwlock.RWLock()
     docker = aiodocker.Docker()
 
+    CLOUD_WORKER_API = await GCPWorkerAPI.from_env() if CLOUD == 'gcp' else AzureWorkerAPI.from_env()
+    instance_config = CLOUD_WORKER_API.instance_config_from_config_dict(INSTANCE_CONFIG)
+    assert instance_config.cores == CORES
+    assert instance_config.cloud == CLOUD
+
     port_allocator = PortAllocator()
-    network_allocator = NetworkAllocator()
+
+    network_allocator_task_manager = aiotools.BackgroundTaskManager()
+    network_allocator = NetworkAllocator(network_allocator_task_manager)
     await network_allocator.reserve()
 
     worker = Worker(httpx.client_session())
@@ -2933,19 +2940,22 @@ async def async_main():
         log.info('worker shutdown', exc_info=True)
     finally:
         try:
-            await docker.close()
-            log.info('docker closed')
+            await network_allocator_task_manager.shutdown_and_wait()
         finally:
-            asyncio.get_event_loop().set_debug(True)
-            other_tasks = [t for t in asyncio.all_tasks() if t != asyncio.current_task()]
-            if other_tasks:
-                log.warning('Tasks immediately after docker close')
-                dump_all_stacktraces()
-                _, pending = await asyncio.wait(other_tasks, timeout=10 * 60, return_when=asyncio.ALL_COMPLETED)
-                for t in pending:
-                    log.warning('Dangling task:')
-                    t.print_stack()
-                    t.cancel()
+            try:
+                await docker.close()
+                log.info('docker closed')
+            finally:
+                asyncio.get_event_loop().set_debug(True)
+                other_tasks = [t for t in asyncio.all_tasks() if t != asyncio.current_task()]
+                if other_tasks:
+                    log.warning('Tasks immediately after docker close')
+                    dump_all_stacktraces()
+                    _, pending = await asyncio.wait(other_tasks, timeout=10 * 60, return_when=asyncio.ALL_COMPLETED)
+                    for t in pending:
+                        log.warning('Dangling task:')
+                        t.print_stack()
+                        t.cancel()
 
 
 loop = asyncio.get_event_loop()
diff --git a/batch/batch/worker/worker_api.py b/batch/batch/worker/worker_api.py
index 3136aef7745..2a459827a54 100644
--- a/batch/batch/worker/worker_api.py
+++ b/batch/batch/worker/worker_api.py
@@ -2,6 +2,7 @@
 from typing import Dict
 
 from hailtop import httpx
+from hailtop.aiotools.fs import AsyncFS
 from hailtop.utils import CalledProcessError, check_shell, sleep_and_backoff
 
 from ..instance_config import InstanceConfig
@@ -10,15 +11,20 @@
 
 
 class CloudWorkerAPI(abc.ABC):
-    @property
+    nameserver_ip: str
+
     @abc.abstractmethod
-    def nameserver_ip(self):
+    def get_compute_client(self):
         raise NotImplementedError
 
     @abc.abstractmethod
     def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> CloudDisk:
         raise NotImplementedError
 
+    @abc.abstractmethod
+    def get_cloud_async_fs(self) -> AsyncFS:
+        raise NotImplementedError
+
     @abc.abstractmethod
     def user_credentials(self, credentials: Dict[str, bytes]) -> CloudUserCredentials:
         raise NotImplementedError
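The CloudWorkerAPI surface expands: nameserver_ip becomes a plain class attribute, and each cloud must now supply a filesystem and a compute client. A hypothetical minimal implementation against a simplified copy of that interface (LocalTestWorkerAPI and its return values are stand-ins):

    import abc

    class CloudWorkerAPISketch(abc.ABC):
        nameserver_ip: str

        @abc.abstractmethod
        def get_compute_client(self): ...

        @abc.abstractmethod
        def get_cloud_async_fs(self): ...

    class LocalTestWorkerAPI(CloudWorkerAPISketch):
        nameserver_ip = '127.0.0.53'  # hypothetical test value

        def get_compute_client(self):
            return object()  # stand-in for a real compute client

        def get_cloud_async_fs(self):
            return object()  # stand-in for a real async filesystem

    api = LocalTestWorkerAPI()
    print(api.nameserver_ip)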
diff --git a/gear/gear/clients.py b/gear/gear/clients.py
index 88a6d01e38e..b938bb244aa 100644
--- a/gear/gear/clients.py
+++ b/gear/gear/clients.py
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from gear.cloud_config import get_azure_config, get_gcp_config, get_global_config
+from gear.cloud_config import get_gcp_config, get_global_config
 from hailtop.aiocloud import aioazure, aiogoogle
 from hailtop.aiotools.fs import AsyncFS, AsyncFSFactory
 
@@ -23,21 +23,6 @@ def get_identity_client(credentials_file: Optional[str] = None):
     return aiogoogle.GoogleIAmClient(project, credentials_file=credentials_file)
 
 
-def get_compute_client(credentials_file: Optional[str] = None):
-    if credentials_file is None:
-        credentials_file = '/gsa-key/key.json'
-
-    cloud = get_global_config()['cloud']
-
-    if cloud == 'azure':
-        azure_config = get_azure_config()
-        return aioazure.AzureComputeClient(azure_config.subscription_id, azure_config.resource_group)
-
-    assert cloud == 'gcp', cloud
-    project = get_gcp_config().project
-    return aiogoogle.GoogleComputeClient(project, credentials_file=credentials_file)
-
-
 def get_cloud_async_fs(credentials_file: Optional[str] = None) -> AsyncFS:
     if credentials_file is None:
         credentials_file = '/gsa-key/key.json'
diff --git a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py
index fca6af3de00..9d34dc7707a 100644
--- a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py
+++ b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py
@@ -74,9 +74,9 @@ def default_credentials() -> Union['GoogleCredentials', AnonymousCloudCredential
             log.info(f'using credentials file {credentials_file}: {creds}')
             return creds
 
-        log.warning('Unable to locate Google Cloud credentials file')
+        log.info('Unable to locate Google Cloud credentials file')
         if GoogleInstanceMetadataCredentials.available():
-            log.warning('Will attempt to use instance metadata server instead')
+            log.info('Will attempt to use instance metadata server instead')
             return GoogleInstanceMetadataCredentials()
 
         log.warning('Using anonymous credentials. If accessing private data, '