[batch] maybe reduce average JVMJob "connecting to jvm" time (#13870)
danking authored Oct 25, 2023
1 parent e739a95 commit 2f69f8a
Showing 1 changed file with 78 additions and 28 deletions.
106 changes: 78 additions & 28 deletions batch/batch/worker/worker.py
@@ -42,7 +42,6 @@
import orjson
from aiodocker.exceptions import DockerError # type: ignore
from aiohttp import web
from sortedcontainers import SortedSet

from gear import json_request, json_response
from hailtop import aiotools, httpx
@@ -2373,6 +2372,7 @@ async def cleanup(self):
except asyncio.CancelledError:
raise
except Exception as e:
await self.worker.return_broken_jvm(self.jvm)
raise IncompleteJVMCleanupError(
f'while unmounting fuse blob storage {bucket} from {mount_path} for {self.jvm_name} for job {self.id}'
) from e
@@ -2910,6 +2910,47 @@ async def get_job_resource_usage(self) -> bytes:
return await self.container.get_job_resource_usage()


class JVMPool:
global_jvm_index = 0

def __init__(self, n_cores: int, worker: 'Worker'):
self.queue: asyncio.Queue[JVM] = asyncio.Queue()
self.total_jvms_including_borrowed = 0
self.max_jvms = CORES // n_cores
self.n_cores = n_cores
self.worker = worker

def borrow_jvm_nowait(self) -> JVM:
return self.queue.get_nowait()

async def borrow_jvm(self) -> JVM:
return await self.queue.get()

def return_jvm(self, jvm: JVM):
assert self.n_cores == jvm.n_cores
assert self.queue.qsize() < self.max_jvms
self.queue.put_nowait(jvm)

async def return_broken_jvm(self, jvm: JVM):
await jvm.kill()
self.total_jvms_including_borrowed -= 1
await self.create_jvm()
log.info(f'killed {jvm} and recreated a new jvm')

async def create_jvm(self):
assert self.queue.qsize() < self.max_jvms
assert self.total_jvms_including_borrowed < self.max_jvms
self.queue.put_nowait(await JVM.create(JVMPool.global_jvm_index, self.n_cores, self.worker))
self.total_jvms_including_borrowed += 1
JVMPool.global_jvm_index += 1

def full(self) -> bool:
return self.total_jvms_including_borrowed == self.max_jvms

def __repr__(self):
return f'JVMPool({self.queue!r}, {self.total_jvms_including_borrowed!r}, {self.max_jvms!r}, {self.n_cores!r})'


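For context, a minimal, self-contained sketch of the asyncio.Queue behavior that JVMPool builds on (the string 'jvm-0' stands in for a real JVM instance; nothing below is code from this commit): get_nowait raises asyncio.QueueEmpty when nothing is idle, which is what lets borrow_jvm_nowait fail fast, while an awaited get suspends until return_jvm or create_jvm puts a JVM back.

import asyncio


async def main():
    # Stand-in for one JVMPool's queue; the string plays the role of an idle JVM.
    queue: asyncio.Queue = asyncio.Queue()

    # The borrow_jvm_nowait path: get_nowait raises asyncio.QueueEmpty when no JVM is idle.
    try:
        queue.get_nowait()
    except asyncio.QueueEmpty:
        print('pool empty, falling back to the awaiting path')

    # The borrow_jvm / return_jvm path: an awaited get suspends until something is put back.
    borrower = asyncio.create_task(queue.get())
    await asyncio.sleep(0)      # let the borrower start waiting
    queue.put_nowait('jvm-0')   # return_jvm (or create_jvm) hands a JVM back
    print(await borrower)       # prints: jvm-0


asyncio.run(main())
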
class Worker:
def __init__(self, client_session: httpx.ClientSession):
self.active = False
@@ -2942,39 +2983,52 @@ def __init__(self, client_session: httpx.ClientSession):

self.cloudfuse_mount_manager = ReadOnlyCloudfuseManager()

self._jvmpools_by_cores: Dict[int, JVMPool] = {n_cores: JVMPool(n_cores, self) for n_cores in (1, 2, 4, 8)}
self._waiting_for_jvm_with_n_cores: asyncio.Queue[int] = asyncio.Queue()
self._jvm_initializer_task = asyncio.create_task(self._initialize_jvms())
self._jvms = SortedSet([], key=lambda jvm: jvm.n_cores)

async def _initialize_jvms(self):
assert instance_config
if instance_config.worker_type() in ('standard', 'D', 'highmem', 'E'):
jvms: List[Awaitable[JVM]] = []
for jvm_cores in (1, 2, 4, 8):
for _ in range(CORES // jvm_cores):
jvms.append(JVM.create(len(jvms), jvm_cores, self))
assert len(jvms) == N_JVM_CONTAINERS
self._jvms.update(await asyncio.gather(*jvms))
log.info(f'JVMs initialized {self._jvms}')
if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
log.info('no JVMs initialized')

while True:
try:
requested_n_cores = self._waiting_for_jvm_with_n_cores.get_nowait()
await self._jvmpools_by_cores[requested_n_cores].create_jvm()
except asyncio.QueueEmpty:
next_unfull_jvmpool = None
for jvmpool in self._jvmpools_by_cores.values():
if not jvmpool.full():
next_unfull_jvmpool = jvmpool
break

if next_unfull_jvmpool is None:
break
await next_unfull_jvmpool.create_jvm()

assert self._waiting_for_jvm_with_n_cores.empty()
assert all(jvmpool.full() for jvmpool in self._jvmpools_by_cores.values())
log.info(f'JVMs initialized {self._jvmpools_by_cores}')

async def borrow_jvm(self, n_cores: int) -> JVM:
assert instance_config
if instance_config.worker_type() not in ('standard', 'D', 'highmem', 'E'):
raise ValueError(f'no JVMs available on {instance_config.worker_type()}')
await self._jvm_initializer_task
assert self._jvms
index = self._jvms.bisect_key_left(n_cores)
assert index < len(self._jvms), index
return self._jvms.pop(index)

jvmpool = self._jvmpools_by_cores[n_cores]
try:
return jvmpool.borrow_jvm_nowait()
except asyncio.QueueEmpty:
self._waiting_for_jvm_with_n_cores.put_nowait(n_cores)
return await jvmpool.borrow_jvm()

def return_jvm(self, jvm: JVM):
jvm.reset()
self._jvms.add(jvm)
self._jvmpools_by_cores[jvm.n_cores].return_jvm(jvm)

async def recreate_jvm(self, jvm: JVM):
self._jvms.remove(jvm)
log.info(f'quarantined {jvm} and recreated a new jvm')
new_jvm = await JVM.create(jvm.index, jvm.n_cores, self)
self._jvms.add(new_jvm)
async def return_broken_jvm(self, jvm: JVM):
return await self._jvmpools_by_cores[jvm.n_cores].return_broken_jvm(jvm)

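To see how borrow_jvm and _initialize_jvms cooperate through _waiting_for_jvm_with_n_cores, here is a hedged, self-contained model of the intended behavior: a borrower that finds its pool empty records the core count it wants, and the initializer creates a JVM of that size next instead of finishing the fixed 1-, 2-, 4-, 8-core creation order first. The function names, the 0.01-second sleep standing in for JVM.create, and the 8-core machine are illustrative assumptions, not code from the commit.

import asyncio

CORES = 8  # assumed machine size, for illustration only


async def initialize_jvms(pools, waiting, capacity):
    # Simplified model of Worker._initialize_jvms: create JVMs one at a time,
    # preferring whatever core count a borrower is already waiting on.
    created = {n: 0 for n in pools}
    while True:
        try:
            n_cores = waiting.get_nowait()
        except asyncio.QueueEmpty:
            unfull = [n for n in pools if created[n] < capacity[n]]
            if not unfull:
                return
            n_cores = unfull[0]
        await asyncio.sleep(0.01)  # stands in for the slow JVM.create call
        pools[n_cores].put_nowait(f'jvm-{n_cores}c-{created[n_cores]}')
        created[n_cores] += 1


async def borrow_jvm(pools, waiting, n_cores):
    # Simplified model of Worker.borrow_jvm: take an idle JVM if one exists,
    # otherwise record the requested size and wait for the initializer.
    try:
        return pools[n_cores].get_nowait()
    except asyncio.QueueEmpty:
        waiting.put_nowait(n_cores)
        return await pools[n_cores].get()


async def main():
    pools = {n: asyncio.Queue() for n in (1, 2, 4, 8)}
    capacity = {n: CORES // n for n in pools}
    waiting = asyncio.Queue()
    init = asyncio.create_task(initialize_jvms(pools, waiting, capacity))
    # An early 8-core job is served after one creation instead of waiting
    # for every 1-, 2- and 4-core JVM to be created first.
    print(await borrow_jvm(pools, waiting, 8))
    await init


asyncio.run(main())
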
@property
def headers(self) -> Dict[str, str]:
@@ -2984,8 +3038,9 @@ async def shutdown(self):
log.info('Worker.shutdown')
self._jvm_initializer_task.cancel()
async with AsyncExitStack() as cleanup:
for jvm in self._jvms:
cleanup.push_async_callback(jvm.kill)
for jvmqueue in self._jvmpools_by_cores.values():
while not jvmqueue.queue.empty():
cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
if self.file_store:
cleanup.push_async_callback(self.file_store.close)
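
One note on the shutdown loop above, with a small stdlib example (the names are illustrative): AsyncExitStack runs push_async_callback callbacks in reverse registration order when the block exits, so the JVM kills pushed first while draining the pools run last, after the task manager and file store callbacks registered below them.

import asyncio
from contextlib import AsyncExitStack


async def main():
    async def report(name: str):
        print(f'cleaned up {name}')

    async with AsyncExitStack() as cleanup:
        # Callbacks run in LIFO order when the block exits, mirroring Worker.shutdown:
        # the JVM kills registered first end up running last.
        cleanup.push_async_callback(report, 'jvm-0')
        cleanup.push_async_callback(report, 'task_manager')
        cleanup.push_async_callback(report, 'file_store')
    # prints: file_store, task_manager, jvm-0 (in that order)


asyncio.run(main())
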
@@ -3000,11 +3055,6 @@ async def run_job(self, job):
raise
except JVMCreationError:
self.stop_event.set()
except IncompleteJVMCleanupError:
assert isinstance(job, JVMJob)
assert job.jvm is not None
await self.recreate_jvm(job.jvm)
log.exception(f'while running {job}, ignoring')
except Exception as e:
if not user_error(e):
log.exception(f'while running {job}, ignoring')
