Skip to content

Commit

Permalink
[batch] Create read-only fuse mounts shared between concurrent jvm jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-goldstein committed Mar 14, 2023
1 parent 6f30189 commit dcc7466
Showing 1 changed file with 187 additions and 32 deletions.
219 changes: 187 additions & 32 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import asyncio
import base64
import concurrent
Expand Down Expand Up @@ -366,6 +367,67 @@ async def _free(self, netns: NetworkNamespace):
self.public_networks.put_nowait(netns)


# Handling read-only cloudfuse mounts that can be shared across jobs


class FuseMount:
    """One live FUSE mount of a (user, bucket) pair.

    `path` is the host directory where the bucket is fuse-mounted.
    `bind_mounts` is the set of destination paths currently sharing this
    mount via bind mounts; it doubles as a reference count — when it
    empties, the fuse mount itself can be torn down.
    """

    def __init__(self, path: str):
        self.path = path
        self.bind_mounts = set()


class ReadOnlyCloudfuseManager:
    """Manages read-only cloudfuse mounts shared between concurrent jobs.

    Each (user, bucket) pair is fuse-mounted at most once, under
    `cloudfuse_dir`; individual consumers get views of that single fuse
    mount via bind mounts. A per-(user, bucket) asyncio.Lock serializes
    mount/unmount so concurrent jobs cannot race on the reference-counted
    FuseMount bookkeeping.
    """

    def __init__(self):
        self.cloudfuse_dir = '/cloudfuse/readonly_cache'
        self.fuse_mounts: Dict[Tuple[str, str], FuseMount] = {}
        # defaultdict lazily creates one lock per (user, bucket) key.
        self.user_bucket_locks: Dict[Tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)

    async def mount(
        self, bucket: str, destination: str, *, user: str, credentials_path: str, tmp_path: str, config: dict
    ):
        """Bind-mount `bucket` at `destination`, creating the shared fuse
        mount first if this is the first consumer for (user, bucket)."""
        async with self.user_bucket_locks[(user, bucket)]:
            if (user, bucket) not in self.fuse_mounts:
                local_path = self._new_path()
                await self._fuse_mount(local_path, credentials_path=credentials_path, tmp_path=tmp_path, config=config)
                self.fuse_mounts[(user, bucket)] = FuseMount(local_path)
            mount = self.fuse_mounts[(user, bucket)]
            # Bind first, record second: if the bind mount fails we must not
            # leave a dangling entry in bind_mounts that a later unmount()
            # would try (and fail) to umount, corrupting the refcount.
            await self._bind_mount(mount.path, destination)
            mount.bind_mounts.add(destination)

    async def unmount(self, destination, *, user: str, bucket: str):
        """Remove the bind mount at `destination`; tear down the underlying
        fuse mount once the last bind mount referencing it is gone."""
        async with self.user_bucket_locks[(user, bucket)]:
            mount = self.fuse_mounts[(user, bucket)]
            await self._bind_unmount(destination)
            mount.bind_mounts.remove(destination)
            if len(mount.bind_mounts) == 0:
                await self._fuse_unmount(mount.path)
                del self.fuse_mounts[(user, bucket)]

    async def _fuse_mount(self, destination: str, *, credentials_path: str, tmp_path: str, config: dict):
        # Delegates to the cloud-specific worker API (GCS/ABS fuse).
        assert CLOUD_WORKER_API
        await CLOUD_WORKER_API.mount_cloudfuse(
            credentials_path,
            destination,
            tmp_path,
            config,
        )

    async def _fuse_unmount(self, path: str):
        assert CLOUD_WORKER_API
        await CLOUD_WORKER_API.unmount_cloudfuse(path)

    async def _bind_mount(self, src: str, dst: str):
        await check_exec_output('mount', '--bind', src, dst)

    async def _bind_unmount(self, dst: str):
        await check_exec_output('umount', dst)

    def _new_path(self) -> str:
        """Create and return a fresh, unique directory for a fuse mount.

        NOTE(review): os.makedirs is a blocking call on the event loop;
        acceptable for a one-off mkdir but worth confirming.
        """
        path = f'{self.cloudfuse_dir}/{uuid.uuid4().hex}'
        os.makedirs(path)
        return path


def docker_call_retry(timeout, name, f, *args, **kwargs):
debug_string = f'In docker call to {f.__name__} for {name}'

Expand Down Expand Up @@ -1019,6 +1081,7 @@ async def container_config(self):
},
},
'linux': {
'rootfsPropagation': 'slave',
'namespaces': [
{'type': 'pid'},
{
Expand Down Expand Up @@ -1295,7 +1358,7 @@ def copy_container(
)


class Job:
class Job(abc.ABC):
quota_project_id = 100

@staticmethod
Expand All @@ -1310,23 +1373,17 @@ def secret_host_path(self, secret) -> str:
def io_host_path(self) -> str:
return f'{self.scratch}/io'

@abc.abstractmethod
def cloudfuse_base_path(self):
# Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
path = f'/cloudfuse/{self.token}'
assert os.path.commonpath([path, self.scratch]) == '/'
return path
raise NotImplementedError

@abc.abstractmethod
def cloudfuse_data_path(self, bucket: str) -> str:
# Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
path = f'{self.cloudfuse_base_path()}/{bucket}/data'
assert os.path.commonpath([path, self.scratch]) == '/'
return path
raise NotImplementedError

@abc.abstractmethod
def cloudfuse_tmp_path(self, bucket: str) -> str:
# Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
path = f'{self.cloudfuse_base_path()}/{bucket}/tmp'
assert os.path.commonpath([path, self.scratch]) == '/'
return path
raise NotImplementedError

def cloudfuse_credentials_path(self, bucket: str) -> str:
return f'{self.scratch}/cloudfuse/{bucket}'
Expand Down Expand Up @@ -1432,24 +1489,7 @@ def __init__(
self.main_volume_mounts.append(io_volume_mount)
self.output_volume_mounts.append(io_volume_mount)

requester_pays_project = job_spec.get('requester_pays_project')
cloudfuse = job_spec.get('cloudfuse') or job_spec.get('gcsfuse')
self.cloudfuse = cloudfuse
if cloudfuse:
for config in cloudfuse:
if requester_pays_project:
config['requester_pays_project'] = requester_pays_project
config['mounted'] = False
bucket = config['bucket']
assert bucket
self.main_volume_mounts.append(
{
'source': f'{self.cloudfuse_data_path(bucket)}',
'destination': config['mount_path'],
'type': 'none',
'options': ['rbind', 'rw', 'shared'],
}
)
self.cloudfuse = job_spec.get('cloudfuse') or job_spec.get('gcsfuse')

secrets = job_spec.get('secrets')
self.secrets = secrets
Expand Down Expand Up @@ -1580,6 +1620,22 @@ def __init__(
]
self.env += hail_extra_env

if self.cloudfuse:
for config in self.cloudfuse:
if requester_pays_project:
config['requester_pays_project'] = requester_pays_project
config['mounted'] = False
bucket = config['bucket']
assert bucket
self.main_volume_mounts.append(
{
'source': f'{self.cloudfuse_data_path(bucket)}',
'destination': config['mount_path'],
'type': 'none',
'options': ['rbind', 'rw', 'shared'],
}
)

if self.secrets:
for secret in self.secrets:
volume_mount = {
Expand Down Expand Up @@ -1885,6 +1941,24 @@ def status(self):
status['timing'] = self.timings.to_dict()
return status

def cloudfuse_base_path(self):
    """Per-job root host directory for cloudfuse mounts.

    Deliberately kept outside self.scratch: scratch is recursively
    deleted at cleanup, and deleting through a live fuse mount could
    destroy bucket contents.
    """
    # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
    path = f'/cloudfuse/{self.token}'
    # commonpath == '/' means the only shared ancestor is the root,
    # i.e. this path cannot lie inside self.scratch.
    assert os.path.commonpath([path, self.scratch]) == '/'
    return path

def cloudfuse_data_path(self, bucket: str) -> str:
    """Host directory where `bucket`'s fuse-mounted data lives for this job."""
    # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
    path = f'{self.cloudfuse_base_path()}/{bucket}/data'
    # Root-only common ancestor guarantees the path is not under scratch.
    assert os.path.commonpath([path, self.scratch]) == '/'
    return path

def cloudfuse_tmp_path(self, bucket: str) -> str:
    """Host scratch/temp directory used by the fuse process for `bucket`."""
    # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
    path = f'{self.cloudfuse_base_path()}/{bucket}/tmp'
    # Root-only common ancestor guarantees the path is not under scratch.
    assert os.path.commonpath([path, self.scratch]) == '/'
    return path

def __str__(self):
    """Return a short human-readable label for this job."""
    return f'job {self.id}'

Expand Down Expand Up @@ -1941,6 +2015,26 @@ async def run_until_done_or_deleted(self, f: Callable[..., Coroutine[Any, Any, A
def secret_host_path(self, secret):
    """Host directory under scratch where this secret's data is written,
    keyed by the secret's in-container mount path."""
    return f'{self.scratch}/secrets/{secret["mount_path"]}'

# This path must already be bind mounted into the JVM
def cloudfuse_base_path(self):
    """Base cloudfuse directory for a JVM job.

    Unlike a regular job, a JVM job reuses the JVM's pre-created
    cloudfuse directory (bind mounted read-only into the JVM container
    at creation time) instead of a per-job path.
    """
    # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
    assert self.jvm
    path = self.jvm.cloudfuse_dir
    # Root-only common ancestor guarantees the path is not under scratch.
    assert os.path.commonpath([path, self.scratch]) == '/'
    return path

def cloudfuse_data_path(self, bucket: str) -> str:
    """Host path where `bucket` is mounted for this JVM job.

    NOTE(review): layout differs from the non-JVM job ('{base}/{bucket}'
    with no '/data' suffix) — here the shared mount manager bind-mounts
    the bucket directly at this path; confirm against the mount call.
    """
    # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
    path = f'{self.cloudfuse_base_path()}/{bucket}'
    # Root-only common ancestor guarantees the path is not under scratch.
    assert os.path.commonpath([path, self.scratch]) == '/'
    return path

def cloudfuse_tmp_path(self, bucket: str) -> str:
    """Host temp directory handed to the fuse process for `bucket`."""
    # Make sure this path isn't in self.scratch to avoid accidental bucket deletions!
    path = f'{self.cloudfuse_base_path()}/tmp/{bucket}'
    # Root-only common ancestor guarantees the path is not under scratch.
    assert os.path.commonpath([path, self.scratch]) == '/'
    return path

async def download_jar(self):
assert self.worker
assert self.worker.pool
Expand Down Expand Up @@ -2002,6 +2096,37 @@ async def run(self):
f'xfs_quota -x -c "limit -p bsoft={self.data_disk_storage_in_gib} bhard={self.data_disk_storage_in_gib} {self.project_id}" /host/'
)

with self.step('adding cloudfuse support'):
if self.cloudfuse:
await check_shell_output(
f'xfs_quota -x -c "project -s -p {self.cloudfuse_base_path()} {self.project_id}" /host/'
)

assert CLOUD_WORKER_API
for config in self.cloudfuse:
bucket = config['bucket']
assert bucket

credentials = self.credentials.cloudfuse_credentials(config)
credentials_path = CLOUD_WORKER_API.write_cloudfuse_credentials(
self.scratch, credentials, bucket
)
data_path = self.cloudfuse_data_path(bucket)
tmp_path = self.cloudfuse_tmp_path(bucket)

os.makedirs(data_path, exist_ok=True)
os.makedirs(tmp_path, exist_ok=True)

await self.jvm.cloudfuse_mount_manager.mount(
bucket,
data_path,
user=self.user,
credentials_path=credentials_path,
tmp_path=tmp_path,
config=config,
)
config['mounted'] = True

if self.secrets:
for secret in self.secrets:
populate_secret_host_path(self.secret_host_path(secret), secret['data'])
Expand Down Expand Up @@ -2053,6 +2178,15 @@ async def cleanup(self):
assert self.worker
assert self.worker.file_store is not None
assert self.worker.fs
assert self.jvm

if self.cloudfuse:
for config in self.cloudfuse:
if config['mounted']:
bucket = config['bucket']
assert bucket
mount_path = self.cloudfuse_data_path(bucket)
await self.jvm.cloudfuse_mount_manager.unmount(mount_path, user=self.user, bucket=bucket)

if self.jvm is not None:
self.worker.return_jvm(self.jvm)
Expand All @@ -2064,6 +2198,8 @@ async def cleanup(self):
self.format_version, self.batch_id, self.job_id, self.attempt_id, 'main', log_contents
)

await check_shell(f'xfs_quota -x -c "limit -p bsoft=0 bhard=0 {self.project_id}" /host')

try:
await check_shell(f'xfs_quota -x -c "limit -p bsoft=0 bhard=0 {self.project_id}" /host')
await blocking_to_async(self.pool, shutil.rmtree, self.scratch, ignore_errors=True)
Expand Down Expand Up @@ -2160,6 +2296,7 @@ async def create_and_start(
n_cores: int,
socket_file: str,
root_dir: str,
cloudfuse_dir: str,
client_session: httpx.ClientSession,
pool: concurrent.futures.ThreadPoolExecutor,
fs: AsyncFS,
Expand Down Expand Up @@ -2217,6 +2354,12 @@ async def create_and_start(
'type': 'none',
'options': ['rbind', 'rw'],
},
{
'source': cloudfuse_dir,
'destination': '/cloudfuse',
'type': 'none',
'options': ['rbind', 'ro', 'rslave'],
},
]

c = Container(
Expand Down Expand Up @@ -2271,14 +2414,15 @@ async def create_container_and_connect(
n_cores: int,
socket_file: str,
root_dir: str,
cloudfuse_dir: str,
client_session: httpx.ClientSession,
pool: concurrent.futures.ThreadPoolExecutor,
fs: AsyncFS,
task_manager: aiotools.BackgroundTaskManager,
) -> JVMContainer:
try:
container = await JVMContainer.create_and_start(
index, n_cores, socket_file, root_dir, client_session, pool, fs, task_manager
index, n_cores, socket_file, root_dir, cloudfuse_dir, client_session, pool, fs, task_manager
)

attempts = 0
Expand Down Expand Up @@ -2320,15 +2464,18 @@ async def create_container_and_connect(
async def create(cls, index: int, n_cores: int, worker: 'Worker'):
token = uuid.uuid4().hex
root_dir = f'/host/jvm-{token}'
cloudfuse_dir = f'/cloudfuse/jvm-{index}-{token[:5]}'
socket_file = root_dir + '/socket'
output_file = root_dir + '/output'
should_interrupt = asyncio.Event()
await blocking_to_async(worker.pool, os.makedirs, root_dir)
await blocking_to_async(worker.pool, os.makedirs, cloudfuse_dir)
container = await cls.create_container_and_connect(
index,
n_cores,
socket_file,
root_dir,
cloudfuse_dir,
worker.client_session,
worker.pool,
worker.fs,
Expand All @@ -2339,13 +2486,15 @@ async def create(cls, index: int, n_cores: int, worker: 'Worker'):
n_cores,
socket_file,
root_dir,
cloudfuse_dir,
output_file,
should_interrupt,
container,
worker.client_session,
worker.pool,
worker.fs,
worker.task_manager,
worker.cloudfuse_mount_manager,
)

def __init__(
Expand All @@ -2354,25 +2503,29 @@ def __init__(
n_cores: int,
socket_file: str,
root_dir: str,
cloudfuse_dir: str,
output_file: str,
should_interrupt: asyncio.Event,
container: JVMContainer,
client_session: httpx.ClientSession,
pool: concurrent.futures.ThreadPoolExecutor,
fs: AsyncFS,
task_manager: aiotools.BackgroundTaskManager,
cloudfuse_mount_manager: ReadOnlyCloudfuseManager,
):
self.index = index
self.n_cores = n_cores
self.socket_file = socket_file
self.root_dir = root_dir
self.cloudfuse_dir = cloudfuse_dir
self.output_file = output_file
self.should_interrupt = should_interrupt
self.container = container
self.client_session = client_session
self.pool = pool
self.fs = fs
self.task_manager = task_manager
self.cloudfuse_mount_manager = cloudfuse_mount_manager

def __str__(self):
    """Return a short label for this JVM, e.g. 'JVM-3'."""
    return f'JVM-{self.index}'
Expand Down Expand Up @@ -2506,6 +2659,8 @@ def __init__(self, client_session: httpx.ClientSession):

self.headers: Optional[Dict[str, str]] = None

self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()

self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
self._jvms: SortedSet[JVM] = SortedSet([], key=lambda jvm: jvm.n_cores)

Expand Down

0 comments on commit dcc7466

Please sign in to comment.