diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py
index cb86e11e60ee..3a512a136bb7 100644
--- a/batch/batch/worker/worker.py
+++ b/batch/batch/worker/worker.py
@@ -1,3 +1,4 @@
+import abc
 import asyncio
 import base64
 import concurrent
@@ -366,6 +367,67 @@ async def _free(self, netns: NetworkNamespace):
         self.public_networks.put_nowait(netns)
 
 
+# Handling read-only cloudfuse mounts that can be shared across jobs
+
+
+class FuseMount:
+    def __init__(self, path):
+        self.path = path
+        self.bind_mounts = set()
+
+
+class ReadOnlyCloudfuseManager:
+    def __init__(self):
+        self.cloudfuse_dir = '/cloudfuse/readonly_cache'
+        self.fuse_mounts: Dict[Tuple[str, str], FuseMount] = {}
+        self.user_bucket_locks: Dict[Tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
+
+    async def mount(
+        self, bucket: str, destination: str, *, user: str, credentials_path: str, tmp_path: str, config: dict
+    ):
+        async with self.user_bucket_locks[(user, bucket)]:
+            if (user, bucket) not in self.fuse_mounts:
+                local_path = self._new_path()
+                await self._fuse_mount(local_path, credentials_path=credentials_path, tmp_path=tmp_path, config=config)
+                self.fuse_mounts[(user, bucket)] = FuseMount(local_path)
+            mount = self.fuse_mounts[(user, bucket)]
+            mount.bind_mounts.add(destination)
+            await self._bind_mount(mount.path, destination)
+
+    async def unmount(self, destination, *, user: str, bucket: str):
+        async with self.user_bucket_locks[(user, bucket)]:
+            mount = self.fuse_mounts[(user, bucket)]
+            await self._bind_unmount(destination)
+            mount.bind_mounts.remove(destination)
+            if len(mount.bind_mounts) == 0:
+                await self._fuse_unmount(mount.path)
+                del self.fuse_mounts[(user, bucket)]
+
+    async def _fuse_mount(self, destination: str, *, credentials_path: str, tmp_path: str, config: dict):
+        assert CLOUD_WORKER_API
+        await CLOUD_WORKER_API.mount_cloudfuse(
+            credentials_path,
+            destination,
+            tmp_path,
+            config,
+        )
+
+    async def _fuse_unmount(self, path: str):
+        assert CLOUD_WORKER_API
+        await CLOUD_WORKER_API.unmount_cloudfuse(path)
+
+    async def _bind_mount(self, src, dst):
+        await check_exec_output('mount', '--bind', src, dst)
+
+    async def _bind_unmount(self, dst):
+        await check_exec_output('umount', dst)
+
+    def _new_path(self):
+        path = f'{self.cloudfuse_dir}/{uuid.uuid4().hex}'
+        os.makedirs(path)
+        return path
+
+
 def docker_call_retry(timeout, name, f, *args, **kwargs):
     debug_string = f'In docker call to {f.__name__} for {name}'
 
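Reviewer note: here is a minimal, self-contained sketch (not part of this diff; `SketchCloudfuseManager` and the demo paths are invented) of the sharing semantics `ReadOnlyCloudfuseManager` implements: one FUSE mount per `(user, bucket)`, one bind mount per job, and the FUSE mount is torn down only when the last bind mount goes away.

```python
import asyncio
from collections import defaultdict
from typing import Dict, Set, Tuple


class SketchCloudfuseManager:
    # Same bookkeeping as ReadOnlyCloudfuseManager, with prints standing in
    # for the real FUSE mount (CLOUD_WORKER_API) and `mount --bind` calls.
    def __init__(self):
        self.fuse_mounts: Dict[Tuple[str, str], Set[str]] = {}
        self.locks: Dict[Tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)

    async def mount(self, user: str, bucket: str, destination: str):
        async with self.locks[(user, bucket)]:
            if (user, bucket) not in self.fuse_mounts:
                print(f'FUSE mount {bucket} for {user}')      # first job only
                self.fuse_mounts[(user, bucket)] = set()
            self.fuse_mounts[(user, bucket)].add(destination)
            print(f'bind mount -> {destination}')             # every job

    async def unmount(self, user: str, bucket: str, destination: str):
        async with self.locks[(user, bucket)]:
            binds = self.fuse_mounts[(user, bucket)]
            binds.discard(destination)
            print(f'bind unmount <- {destination}')
            if not binds:
                print(f'FUSE unmount {bucket} for {user}')    # last job only
                del self.fuse_mounts[(user, bucket)]


async def demo():
    m = SketchCloudfuseManager()
    await m.mount('alice', 'bkt', '/cloudfuse/jvm-0/bkt')    # FUSE mount + bind
    await m.mount('alice', 'bkt', '/cloudfuse/jvm-1/bkt')    # bind only: shared
    await m.unmount('alice', 'bkt', '/cloudfuse/jvm-0/bkt')  # FUSE mount stays
    await m.unmount('alice', 'bkt', '/cloudfuse/jvm-1/bkt')  # last one: FUSE unmount


asyncio.run(demo())
```

The per-`(user, bucket)` lock is what makes the check-then-mount in `mount()` safe against two jobs of the same user racing to mount the same bucket.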
@@ -1019,6 +1081,7 @@ async def container_config(self):
                 },
             },
             'linux': {
+                'rootfsPropagation': 'slave',
                 'namespaces': [
                     {'type': 'pid'},
                     {
@@ -1295,7 +1358,7 @@ def copy_container(
     )
 
 
-class Job:
+class Job(abc.ABC):
     quota_project_id = 100
 
     @staticmethod
@@ -1310,23 +1373,17 @@ def secret_host_path(self, secret) -> str:
     def io_host_path(self) -> str:
         return f'{self.scratch}/io'
 
+    @abc.abstractmethod
     def cloudfuse_base_path(self):
-        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
-        path = f'/cloudfuse/{self.token}'
-        assert os.path.commonpath([path, self.scratch]) == '/'
-        return path
+        raise NotImplementedError
 
+    @abc.abstractmethod
     def cloudfuse_data_path(self, bucket: str) -> str:
-        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
-        path = f'{self.cloudfuse_base_path()}/{bucket}/data'
-        assert os.path.commonpath([path, self.scratch]) == '/'
-        return path
+        raise NotImplementedError
 
+    @abc.abstractmethod
     def cloudfuse_tmp_path(self, bucket: str) -> str:
-        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
-        path = f'{self.cloudfuse_base_path()}/{bucket}/tmp'
-        assert os.path.commonpath([path, self.scratch]) == '/'
-        return path
+        raise NotImplementedError
 
     def cloudfuse_credentials_path(self, bucket: str) -> str:
         return f'{self.scratch}/cloudfuse/{bucket}'
@@ -1432,24 +1489,7 @@ def __init__(
             self.main_volume_mounts.append(io_volume_mount)
             self.output_volume_mounts.append(io_volume_mount)
 
-        requester_pays_project = job_spec.get('requester_pays_project')
-        cloudfuse = job_spec.get('cloudfuse') or job_spec.get('gcsfuse')
-        self.cloudfuse = cloudfuse
-        if cloudfuse:
-            for config in cloudfuse:
-                if requester_pays_project:
-                    config['requester_pays_project'] = requester_pays_project
-                config['mounted'] = False
-                bucket = config['bucket']
-                assert bucket
-                self.main_volume_mounts.append(
-                    {
-                        'source': f'{self.cloudfuse_data_path(bucket)}',
-                        'destination': config['mount_path'],
-                        'type': 'none',
-                        'options': ['rbind', 'rw', 'shared'],
-                    }
-                )
+        self.cloudfuse = job_spec.get('cloudfuse') or job_spec.get('gcsfuse')
 
         secrets = job_spec.get('secrets')
         self.secrets = secrets
@@ -1580,6 +1620,22 @@ def __init__(
         ]
         self.env += hail_extra_env
 
+        if self.cloudfuse:
+            for config in self.cloudfuse:
+                if requester_pays_project:
+                    config['requester_pays_project'] = requester_pays_project
+                config['mounted'] = False
+                bucket = config['bucket']
+                assert bucket
+                self.main_volume_mounts.append(
+                    {
+                        'source': f'{self.cloudfuse_data_path(bucket)}',
+                        'destination': config['mount_path'],
+                        'type': 'none',
+                        'options': ['rbind', 'rw', 'shared'],
+                    }
+                )
+
         if self.secrets:
             for secret in self.secrets:
                 volume_mount = {
@@ -1885,6 +1941,24 @@ def status(self):
             status['timing'] = self.timings.to_dict()
         return status
 
+    def cloudfuse_base_path(self):
+        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
+        path = f'/cloudfuse/{self.token}'
+        assert os.path.commonpath([path, self.scratch]) == '/'
+        return path
+
+    def cloudfuse_data_path(self, bucket: str) -> str:
+        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
+        path = f'{self.cloudfuse_base_path()}/{bucket}/data'
+        assert os.path.commonpath([path, self.scratch]) == '/'
+        return path
+
+    def cloudfuse_tmp_path(self, bucket: str) -> str:
+        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
+        path = f'{self.cloudfuse_base_path()}/{bucket}/tmp'
+        assert os.path.commonpath([path, self.scratch]) == '/'
+        return path
+
     def __str__(self):
         return f'job {self.id}'
 
@@ -1941,6 +2015,26 @@ async def run_until_done_or_deleted(self, f: Callable[..., Coroutine[Any, Any, A
     def secret_host_path(self, secret):
         return f'{self.scratch}/secrets/{secret["mount_path"]}'
 
+    # This path must already be bind mounted into the JVM
+    def cloudfuse_base_path(self):
+        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
+        assert self.jvm
+        path = self.jvm.cloudfuse_dir
+        assert os.path.commonpath([path, self.scratch]) == '/'
+        return path
+
+    def cloudfuse_data_path(self, bucket: str) -> str:
+        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
+        path = f'{self.cloudfuse_base_path()}/{bucket}'
+        assert os.path.commonpath([path, self.scratch]) == '/'
+        return path
+
+    def cloudfuse_tmp_path(self, bucket: str) -> str:
+        # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
+        path = f'{self.cloudfuse_base_path()}/tmp/{bucket}'
+        assert os.path.commonpath([path, self.scratch]) == '/'
+        return path
+
     async def download_jar(self):
         assert self.worker
         assert self.worker.pool
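Reviewer note: the repeated `commonpath` assertion is the safety net behind these path methods. Job teardown runs `rmtree(self.scratch)`, so every cloudfuse path must share no ancestor with scratch other than `/`. A short sketch (hypothetical paths) of what the assertion does and does not allow:

```python
import os

scratch = '/batch/abc123/job-1'              # hypothetical scratch dir
safe = '/cloudfuse/abc123/my-bucket/data'    # lives outside scratch
unsafe = '/batch/abc123/job-1/cloudfuse'     # would be wiped by rmtree(scratch)

# The pattern used above: accept only paths whose sole shared
# ancestor with scratch is the filesystem root.
assert os.path.commonpath([safe, scratch]) == '/'
assert os.path.commonpath([unsafe, scratch]) != '/'  # worker.py's assert would fire here
```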
@@ -2002,6 +2096,37 @@ async def run(self):
                     f'xfs_quota -x -c "limit -p bsoft={self.data_disk_storage_in_gib} bhard={self.data_disk_storage_in_gib} {self.project_id}" /host/'
                 )
 
+            with self.step('adding cloudfuse support'):
+                if self.cloudfuse:
+                    await check_shell_output(
+                        f'xfs_quota -x -c "project -s -p {self.cloudfuse_base_path()} {self.project_id}" /host/'
+                    )
+
+                    assert CLOUD_WORKER_API
+                    for config in self.cloudfuse:
+                        bucket = config['bucket']
+                        assert bucket
+
+                        credentials = self.credentials.cloudfuse_credentials(config)
+                        credentials_path = CLOUD_WORKER_API.write_cloudfuse_credentials(
+                            self.scratch, credentials, bucket
+                        )
+                        data_path = self.cloudfuse_data_path(bucket)
+                        tmp_path = self.cloudfuse_tmp_path(bucket)
+
+                        os.makedirs(data_path, exist_ok=True)
+                        os.makedirs(tmp_path, exist_ok=True)
+
+                        await self.jvm.cloudfuse_mount_manager.mount(
+                            bucket,
+                            data_path,
+                            user=self.user,
+                            credentials_path=credentials_path,
+                            tmp_path=tmp_path,
+                            config=config,
+                        )
+                        config['mounted'] = True
+
             if self.secrets:
                 for secret in self.secrets:
                     populate_secret_host_path(self.secret_host_path(secret), secret['data'])
@@ -2053,6 +2178,15 @@ async def cleanup(self):
         assert self.worker
         assert self.worker.file_store is not None
         assert self.worker.fs
+        assert self.jvm
+
+        if self.cloudfuse:
+            for config in self.cloudfuse:
+                if config['mounted']:
+                    bucket = config['bucket']
+                    assert bucket
+                    mount_path = self.cloudfuse_data_path(bucket)
+                    await self.jvm.cloudfuse_mount_manager.unmount(mount_path, user=self.user, bucket=bucket)
 
         if self.jvm is not None:
             self.worker.return_jvm(self.jvm)
@@ -2064,6 +2198,8 @@ async def cleanup(self):
                 self.format_version, self.batch_id, self.job_id, self.attempt_id, 'main', log_contents
            )
 
+        await check_shell(f'xfs_quota -x -c "limit -p bsoft=0 bhard=0 {self.project_id}" /host')
+
         try:
             await check_shell(f'xfs_quota -x -c "limit -p bsoft=0 bhard=0 {self.project_id}" /host')
             await blocking_to_async(self.pool, shutil.rmtree, self.scratch, ignore_errors=True)
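Reviewer note: the `config['mounted']` flag is the contract between these two hunks. A self-contained sketch (invented `FakeManager` and config values; the real `run()`/`cleanup()` are separate methods, folded into a `try`/`finally` here for brevity): the flag is set only after a successful mount, and cleanup unmounts exactly the buckets whose flag is set, so a failure partway through `run()` cannot unmount something that never mounted.

```python
import asyncio


class FakeManager:
    # Stand-in for jvm.cloudfuse_mount_manager; 'broken' simulates a mount failure.
    async def mount(self, bucket, destination, *, user):
        if bucket == 'broken':
            raise RuntimeError('fuse mount failed')
        print(f'mounted {bucket} at {destination} for {user}')

    async def unmount(self, destination, *, user, bucket):
        print(f'unmounted {bucket} at {destination} for {user}')


async def run_job(configs, manager, user='alice'):
    try:
        for config in configs:
            await manager.mount(config['bucket'], config['mount_path'], user=user)
            config['mounted'] = True  # recorded only after a successful mount
    finally:
        # cleanup: unmount exactly what was mounted, even when run() failed
        for config in configs:
            if config['mounted']:
                await manager.unmount(config['mount_path'], user=user, bucket=config['bucket'])


configs = [
    {'bucket': 'ok-bucket', 'mount_path': '/cloudfuse/jvm-0/ok-bucket', 'mounted': False},
    {'bucket': 'broken', 'mount_path': '/cloudfuse/jvm-0/broken', 'mounted': False},
]
try:
    asyncio.run(run_job(configs, FakeManager()))
except RuntimeError:
    pass  # the job failed, but 'ok-bucket' was still cleanly unmounted
```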
@@ -2160,6 +2296,7 @@ async def create_and_start(
         n_cores: int,
         socket_file: str,
         root_dir: str,
+        cloudfuse_dir: str,
         client_session: httpx.ClientSession,
         pool: concurrent.futures.ThreadPoolExecutor,
         fs: AsyncFS,
@@ -2217,6 +2354,12 @@ async def create_and_start(
                 'type': 'none',
                 'options': ['rbind', 'rw'],
             },
+            {
+                'source': cloudfuse_dir,
+                'destination': '/cloudfuse',
+                'type': 'none',
+                'options': ['rbind', 'ro', 'rslave'],
+            },
         ]
 
         c = Container(
@@ -2271,6 +2414,7 @@ async def create_container_and_connect(
         n_cores: int,
         socket_file: str,
         root_dir: str,
+        cloudfuse_dir: str,
         client_session: httpx.ClientSession,
         pool: concurrent.futures.ThreadPoolExecutor,
         fs: AsyncFS,
@@ -2278,7 +2422,7 @@ async def create_container_and_connect(
     ) -> JVMContainer:
         try:
             container = await JVMContainer.create_and_start(
-                index, n_cores, socket_file, root_dir, client_session, pool, fs, task_manager
+                index, n_cores, socket_file, root_dir, cloudfuse_dir, client_session, pool, fs, task_manager
             )
 
             attempts = 0
@@ -2320,15 +2464,18 @@
     async def create(cls, index: int, n_cores: int, worker: 'Worker'):
         token = uuid.uuid4().hex
         root_dir = f'/host/jvm-{token}'
+        cloudfuse_dir = f'/cloudfuse/jvm-{index}-{token[:5]}'
         socket_file = root_dir + '/socket'
         output_file = root_dir + '/output'
         should_interrupt = asyncio.Event()
         await blocking_to_async(worker.pool, os.makedirs, root_dir)
+        await blocking_to_async(worker.pool, os.makedirs, cloudfuse_dir)
         container = await cls.create_container_and_connect(
             index,
             n_cores,
             socket_file,
             root_dir,
+            cloudfuse_dir,
             worker.client_session,
             worker.pool,
             worker.fs,
@@ -2339,6 +2486,7 @@ async def create(cls, index: int, n_cores: int, worker: 'Worker'):
             n_cores,
             socket_file,
             root_dir,
+            cloudfuse_dir,
             output_file,
             should_interrupt,
             container,
@@ -2346,6 +2494,7 @@ async def create(cls, index: int, n_cores: int, worker: 'Worker'):
             worker.pool,
             worker.fs,
             worker.task_manager,
+            worker.cloudfuse_mount_manager,
         )
 
     def __init__(
@@ -2354,6 +2503,7 @@ def __init__(
         n_cores: int,
         socket_file: str,
         root_dir: str,
+        cloudfuse_dir: str,
         output_file: str,
         should_interrupt: asyncio.Event,
         container: JVMContainer,
@@ -2361,11 +2511,13 @@ def __init__(
         pool: concurrent.futures.ThreadPoolExecutor,
         fs: AsyncFS,
         task_manager: aiotools.BackgroundTaskManager,
+        cloudfuse_mount_manager: ReadOnlyCloudfuseManager,
     ):
         self.index = index
         self.n_cores = n_cores
         self.socket_file = socket_file
         self.root_dir = root_dir
+        self.cloudfuse_dir = cloudfuse_dir
         self.output_file = output_file
         self.should_interrupt = should_interrupt
         self.container = container
@@ -2373,6 +2525,7 @@ def __init__(
         self.pool = pool
         self.fs = fs
         self.task_manager = task_manager
+        self.cloudfuse_mount_manager = cloudfuse_mount_manager
 
     def __str__(self):
         return f'JVM-{self.index}'
@@ -2506,6 +2659,8 @@ def __init__(self, client_session: httpx.ClientSession):
 
         self.headers: Optional[Dict[str, str]] = None
 
+        self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()
+
         self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
         self._jvms: SortedSet[JVM] = SortedSet([], key=lambda jvm: jvm.n_cores)
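Reviewer note: a short sketch (values invented, shapes copied from the hunks above) of the per-JVM wiring: each JVM gets a private host directory that is bind-mounted read-only into its container at `/cloudfuse` with `rslave` propagation, so FUSE mounts the worker makes on the host afterwards become visible inside the already-running JVM. This is also why `rootfsPropagation` is set to `slave` in the container config.

```python
import uuid

index = 0                          # JVM index, as used in JVM.create above
token = uuid.uuid4().hex
cloudfuse_dir = f'/cloudfuse/jvm-{index}-{token[:5]}'   # private host dir per JVM

# The OCI mount entry JVMContainer.create_and_start adds for it:
jvm_cloudfuse_mount = {
    'source': cloudfuse_dir,
    'destination': '/cloudfuse',
    'type': 'none',
    # ro: jobs cannot write through the bind; rslave: mounts made on the
    # host under cloudfuse_dir propagate into the running container.
    'options': ['rbind', 'ro', 'rslave'],
}
print(jvm_cloudfuse_mount)
```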