From 77671413ea02b6750a7575c840ccf5bd21127afc Mon Sep 17 00:00:00 2001 From: Dan King Date: Fri, 20 Oct 2023 01:04:28 -0400 Subject: [PATCH] [batch] maybe reduce average JVMJob "connecting to jvm" time --- batch/batch/worker/worker.py | 67 ++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index 6cd263620aba..e9b378991f2b 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -2373,6 +2373,7 @@ async def cleanup(self): except asyncio.CancelledError: raise except Exception as e: + self.worker.return_jvm(await self.worker.recreate_jvm(self.jvm)) raise IncompleteJVMCleanupError( f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}' ) from e @@ -2910,6 +2911,14 @@ async def get_job_resource_usage(self) -> bytes: return await self.container.get_job_resource_usage() +class JVMQueue: + def __init__(self, n_cores): + self.queue: asyncio.Queue[JVM] = asyncio.Queue() + self.total = 0 + self.target = CORES // n_cores + self.n_cores = n_cores + + class Worker: def __init__(self, client_session: httpx.ClientSession): self.active = False @@ -2942,39 +2951,51 @@ def __init__(self, client_session: httpx.ClientSession): self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager() + self._jvms_by_cores: Dict[int, JVMQueue] = { + n_cores: JVMQueue(n_cores) for n_cores in (1, 2, 4, 8) + } + self._jvm_waiters: asyncio.Queue[int] = asyncio.Queue() self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms()) - self._jvms = SortedSet([], key=lambda jvm: jvm.n_cores) async def _initialize_jvms(self): assert instance_config if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'): - jvms: List[Awaitable[JVM]] = [] - for jvm_cores in (1, 2, 4, 8): - for _ in range(CORES // jvm_cores): - jvms.append(JVM.create(len(jvms), jvm_cores, self)) - assert len(jvms) == N_JVM_CONTAINERS - self._jvms.update(await asyncio.gather(*jvms)) - log.info(f'JVMs initialized {self._jvms}') + global_jvm_index = 0 + while True: + try: + self._jvm_waiters.get_nowait() + except asyncio.QueueEmpty: + for n_cores, jvmqueue in self._jvms_by_cores.items(): + if jvmqueue.target != jvmqueue.total: + jvmqueue.queue.put_nowait(await JVM.create(global_jvm_index, n_cores, self)) + jvmqueue.total += 1 + global_jvm_index += 1 + continue + break + assert self._jvm_waiters.empty() + log.info(f'JVMs initialized {self._jvms_by_cores}') async def borrow_jvm(self, n_cores: int) -> JVM: assert instance_config if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'): raise ValueError(f'no JVMs available on {instance_config.worker_type()}') - await self._jvm_initializer_task - assert self._jvms - index = self._jvms.bisect_key_left(n_cores) - assert index < len(self._jvms), index - return self._jvms.pop(index) + + jvmqueue = self._jvms_by_cores[n_cores] + try: + return jvmqueue.queue.get_nowait() + except asyncio.QueueEmpty: + assert not self._jvm_initializer_task.done() + self._jvm_waiters.put_nowait(n_cores) + return await jvmqueue.queue.get() def return_jvm(self, jvm: JVM): jvm.reset() - self._jvms.add(jvm) + self._jvms_by_cores[jvm.n_cores].queue.put_nowait(jvm) async def recreate_jvm(self, jvm: JVM): - self._jvms.remove(jvm) - log.info(f'quarantined {jvm} and recreated a new jvm') - new_jvm = await JVM.create(jvm.index, jvm.n_cores, self) - self._jvms.add(new_jvm) + await jvm.kill() # is this OK to do? Seems like we ought to, no? + log.info(f'killed {jvm} and recreated a new jvm') + return await JVM.create(jvm.index, jvm.n_cores, self) @property def headers(self) -> Dict[str, str]: @@ -2984,8 +3005,9 @@ async def shutdown(self): log.info('Worker.shutdown') self._jvm_initializer_task.cancel() async with AsyncExitStack() as cleanup: - for jvm in self._jvms: - cleanup.push_async_callback(jvm.kill) + for jvmqueue in self._jvms_by_cores.values(): + while not jvmqueue.queue.empty(): + cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill) cleanup.push_async_callback(self.task_manager.shutdown_and_wait) if self.file_store: cleanup.push_async_callback(self.file_store.close) @@ -3000,11 +3022,6 @@ async def run_job(self, job): raise except JVMCreationError: self.stop_event.set() - except IncompleteJVMCleanupError: - assert isinstance(job, JVMJob) - assert job.jvm is not None - await self.recreate_jvm(job.jvm) - log.exception(f'while running {job}, ignoring') except Exception as e: if not user_error(e): log.exception(f'while running {job}, ignoring')