[batch] Replace GCS with GoogleStorageAsyncFS (attempt #2) (hail-is#10814)

* [batch] replace GCS with GoogleStorageAsyncFS in LogStore

* delint

* address review comments

* fix s3 exception handling

* fix catching nonexistent blob in azurefs

* flip version number of instance
jigold authored Sep 27, 2021
1 parent 674844b commit d9bb866
Showing 23 changed files with 249 additions and 402 deletions.
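In short, the commit replaces the synchronous GCS client (which needed a ThreadPoolExecutor) with the async GoogleStorageAsyncFS and renames LogStore to FileStore. A condensed before/after excerpt of the on_startup wiring, drawn from the hunks below (not standalone code; names are as they appear in the diff):

# Before: blocking client driven through a thread pool
pool = concurrent.futures.ThreadPoolExecutor()
credentials = google.oauth2.service_account.Credentials.from_service_account_file('/gsa-key/key.json')
app['log_store'] = LogStore(BATCH_BUCKET_NAME, instance_id, pool, credentials=credentials)

# After: async filesystem; no blocking pool to create or shut down
credentials = aiogoogle.auth.credentials.Credentials.from_file('/gsa-key/key.json')
fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials)
app['file_store'] = FileStore(fs, BATCH_BUCKET_NAME, instance_id)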
9 changes: 5 additions & 4 deletions batch/batch/driver/create_instance.py
@@ -10,7 +10,7 @@
from ..batch_configuration import PROJECT, DOCKER_ROOT_IMAGE, DOCKER_PREFIX, DEFAULT_NAMESPACE
from ..inst_coll_config import machine_type_to_dict
from ..worker_config import WorkerConfig
from ..log_store import LogStore
from ..file_store import FileStore
from ..utils import unreserved_worker_data_disk_size_gib

log = logging.getLogger('create_instance')
@@ -34,7 +34,8 @@ def create_instance_config(
preemptible,
job_private,
) -> Tuple[Dict[str, Any], WorkerConfig]:
log_store: LogStore = app['log_store']
file_store: FileStore = app['file_store']

cores = int(machine_type_to_dict(machine_type)['cores'])

if worker_local_ssd_data_disk:
@@ -324,8 +325,8 @@ def create_instance_config(
{'key': 'docker_prefix', 'value': DOCKER_PREFIX},
{'key': 'namespace', 'value': DEFAULT_NAMESPACE},
{'key': 'internal_ip', 'value': INTERNAL_GATEWAY_IP},
{'key': 'batch_logs_bucket_name', 'value': log_store.batch_logs_bucket_name},
{'key': 'instance_id', 'value': log_store.instance_id},
{'key': 'batch_logs_bucket_name', 'value': file_store.batch_logs_bucket_name},
{'key': 'instance_id', 'value': file_store.instance_id},
{'key': 'max_idle_time_msecs', 'value': max_idle_time_msecs},
]
},
6 changes: 3 additions & 3 deletions batch/batch/driver/job.py
@@ -15,7 +15,7 @@
from ..batch_configuration import KUBERNETES_TIMEOUT_IN_SECONDS, KUBERNETES_SERVER_URL
from ..batch_format_version import BatchFormatVersion
from ..spec_writer import SpecWriter
from ..log_store import LogStore
from ..file_store import FileStore

from .k8s_cache import K8sCache

@@ -379,7 +379,7 @@ async def job_config(app, record, attempt_id):
async def schedule_job(app, record, instance):
assert instance.state == 'active'

log_store: LogStore = app['log_store']
file_store: FileStore = app['file_store']
db: Database = app['db']

batch_id = record['batch_id']
@@ -407,7 +407,7 @@ async def schedule_job(app, record, instance):
}

if format_version.has_full_status_in_gcs():
await log_store.write_status_file(batch_id, job_id, attempt_id, json.dumps(status))
await file_store.write_status_file(batch_id, job_id, attempt_id, json.dumps(status))

db_status = format_version.db_status(status)
resources = []
24 changes: 10 additions & 14 deletions batch/batch/driver/main.py
@@ -3,15 +3,13 @@
from typing import Dict
from functools import wraps
from collections import namedtuple, defaultdict
import concurrent
import copy
import asyncio
import signal
import dictdiffer
from aiohttp import web
import aiohttp_session
import kubernetes_asyncio as kube
import google.oauth2.service_account
from prometheus_async.aio.web import server_stats
import prometheus_client as pc # type: ignore
from gear import (
@@ -40,7 +38,7 @@
import googlecloudprofiler
import uvloop

from ..log_store import LogStore
from ..file_store import FileStore
from ..batch import cancel_batch_in_db
from ..batch_configuration import (
REFRESH_INTERVAL_IN_SECONDS,
@@ -1042,8 +1040,6 @@ async def scheduling_cancelling_bump(app):

async def on_startup(app):
app['task_manager'] = aiotools.BackgroundTaskManager()
pool = concurrent.futures.ThreadPoolExecutor()
app['blocking_pool'] = pool

kube.config.load_incluster_config()
k8s_client = kube.client.CoreV1Api()
@@ -1105,9 +1101,9 @@ async def on_startup(app):
async_worker_pool = AsyncWorkerPool(100, queue_size=100)
app['async_worker_pool'] = async_worker_pool

credentials = google.oauth2.service_account.Credentials.from_service_account_file('/gsa-key/key.json')
log_store = LogStore(BATCH_BUCKET_NAME, instance_id, pool, credentials=credentials)
app['log_store'] = log_store
credentials = aiogoogle.auth.credentials.Credentials.from_file('/gsa-key/key.json')
fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials)
app['file_store'] = FileStore(fs, BATCH_BUCKET_NAME, instance_id)

zone_monitor = ZoneMonitor(app)
app['zone_monitor'] = zone_monitor
@@ -1146,22 +1142,22 @@

async def on_cleanup(app):
try:
app['blocking_pool'].shutdown()
await app['db'].async_close()
finally:
try:
await app['db'].async_close()
app['zone_monitor'].shutdown()
finally:
try:
app['zone_monitor'].shutdown()
app['inst_coll_manager'].shutdown()
finally:
try:
app['inst_coll_manager'].shutdown()
app['canceller'].shutdown()
finally:
try:
app['canceller'].shutdown()
app['gce_event_monitor'].shutdown()
finally:
try:
app['gce_event_monitor'].shutdown()
await app['file_store'].close()
finally:
try:
app['task_manager'].shutdown()
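With the thread pool gone, the driver's on_cleanup drops the blocking_pool entry and adds FileStore.close() to its nested try/finally ladder, so a failure while releasing one resource still lets the rest shut down. A trimmed sketch of that pattern (only a few of the driver's resources shown):

async def on_cleanup(app):
    try:
        await app['db'].async_close()
    finally:
        try:
            app['zone_monitor'].shutdown()
        finally:
            try:
                # new in this commit: close the async filesystem behind the FileStore
                await app['file_store'].close()
            finally:
                app['task_manager'].shutdown()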
58 changes: 33 additions & 25 deletions batch/batch/log_store.py → batch/batch/file_store.py
@@ -1,7 +1,7 @@
import logging
import asyncio

from hailtop.google_storage import GCS
from hailtop.aiotools.fs import AsyncFS

from .spec_writer import SpecWriter
from .globals import BATCH_FORMAT_VERSION
@@ -10,12 +10,13 @@
log = logging.getLogger('logstore')


class LogStore:
def __init__(self, batch_logs_bucket_name, instance_id, blocking_pool, *, project=None, credentials=None):
class FileStore:
def __init__(self, fs: AsyncFS, batch_logs_bucket_name, instance_id):
self.fs = fs
self.batch_logs_bucket_name = batch_logs_bucket_name
self.instance_id = instance_id

self.batch_logs_root = f'gs://{batch_logs_bucket_name}/batch/logs/{instance_id}/batch'
self.gcs = GCS(blocking_pool, project=project, credentials=credentials)

log.info(f'BATCH_LOGS_ROOT {self.batch_logs_root}')
format_version = BatchFormatVersion(BATCH_FORMAT_VERSION)
@@ -27,34 +28,36 @@ def batch_log_dir(self, batch_id):
def log_path(self, format_version, batch_id, job_id, attempt_id, task):
if not format_version.has_attempt_in_log_path():
return f'{self.batch_log_dir(batch_id)}/{job_id}/{task}/log'

return f'{self.batch_log_dir(batch_id)}/{job_id}/{attempt_id}/{task}/log'

async def read_log_file(self, format_version, batch_id, job_id, attempt_id, task):
path = self.log_path(format_version, batch_id, job_id, attempt_id, task)
return await self.gcs.read_gs_file(path)
url = self.log_path(format_version, batch_id, job_id, attempt_id, task)
data = await self.fs.read(url)
return data.decode('utf-8')

async def write_log_file(self, format_version, batch_id, job_id, attempt_id, task, data):
path = self.log_path(format_version, batch_id, job_id, attempt_id, task)
return await self.gcs.write_gs_file_from_string(path, data)
url = self.log_path(format_version, batch_id, job_id, attempt_id, task)
await self.fs.write(url, data.encode('utf-8'))

async def delete_batch_logs(self, batch_id):
await self.gcs.delete_gs_files(self.batch_log_dir(batch_id))
url = self.batch_log_dir(batch_id)
await self.fs.rmtree(None, url)

def status_path(self, batch_id, job_id, attempt_id):
return f'{self.batch_log_dir(batch_id)}/{job_id}/{attempt_id}/status.json'

async def read_status_file(self, batch_id, job_id, attempt_id):
path = self.status_path(batch_id, job_id, attempt_id)
return await self.gcs.read_gs_file(path)
url = self.status_path(batch_id, job_id, attempt_id)
data = await self.fs.read(url)
return data.decode('utf-8')

async def write_status_file(self, batch_id, job_id, attempt_id, status):
path = self.status_path(batch_id, job_id, attempt_id)
return await self.gcs.write_gs_file_from_string(path, status)
url = self.status_path(batch_id, job_id, attempt_id)
await self.fs.write(url, status.encode('utf-8'))

async def delete_status_file(self, batch_id, job_id, attempt_id):
path = self.status_path(batch_id, job_id, attempt_id)
return await self.gcs.delete_gs_file(path)
url = self.status_path(batch_id, job_id, attempt_id)
return await self.fs.remove(url)

def specs_dir(self, batch_id, token):
return f'{self.batch_logs_root}/{batch_id}/bunch/{token}'
@@ -66,22 +69,27 @@ def specs_index_path(self, batch_id, token):
return f'{self.specs_dir(batch_id, token)}/specs.idx'

async def read_spec_file(self, batch_id, token, start_job_id, job_id):
idx_path = self.specs_index_path(batch_id, token)
idx_url = self.specs_index_path(batch_id, token)
idx_start, idx_end = SpecWriter.get_index_file_offsets(job_id, start_job_id)
offsets = await self.gcs.read_binary_gs_file(idx_path, start=idx_start, end=idx_end)
offsets = await self.fs.read_range(idx_url, idx_start, idx_end)

spec_path = self.specs_path(batch_id, token)
spec_url = self.specs_path(batch_id, token)
spec_start, spec_end = SpecWriter.get_spec_file_offsets(offsets)
return await self.gcs.read_gs_file(spec_path, start=spec_start, end=spec_end)
data = await self.fs.read_range(spec_url, spec_start, spec_end)
return data.decode('utf-8')

async def write_spec_file(self, batch_id, token, data_bytes, offsets_bytes):
idx_path = self.specs_index_path(batch_id, token)
write1 = self.gcs.write_gs_file_from_string(idx_path, offsets_bytes, content_type='application/octet-stream')
idx_url = self.specs_index_path(batch_id, token)
write1 = self.fs.write(idx_url, offsets_bytes)

specs_path = self.specs_path(batch_id, token)
write2 = self.gcs.write_gs_file_from_string(specs_path, data_bytes)
specs_url = self.specs_path(batch_id, token)
write2 = self.fs.write(specs_url, data_bytes)

await asyncio.gather(write1, write2)

async def delete_spec_file(self, batch_id, token):
await self.gcs.delete_gs_files(self.specs_dir(batch_id, token))
url = self.specs_dir(batch_id, token)
await self.fs.rmtree(None, url)

async def close(self):
await self.fs.close()
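A minimal usage sketch of the new FileStore, with an illustrative bucket name and instance id; the str-to-bytes encoding mirrors the read/write methods above, and the module/credential paths are assumptions, not part of this diff:

import asyncio

import hailtop.aiogoogle as aiogoogle

from batch.file_store import FileStore  # module path as renamed in this commit


async def main():
    # credentials path mirrors the service's mounted GSA key; substitute any key file you have
    credentials = aiogoogle.auth.credentials.Credentials.from_file('/gsa-key/key.json')
    fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials)
    file_store = FileStore(fs, 'my-batch-logs-bucket', 'my-instance-id')  # illustrative names
    try:
        # write_status_file encodes the JSON string to bytes before calling fs.write
        await file_store.write_status_file(1, 1, 'attempt-0', '{"state": "Success"}')
        # read_status_file reads the bytes back and decodes them to str
        print(await file_store.read_status_file(1, 1, 'attempt-0'))
    finally:
        await file_store.close()


asyncio.run(main())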
31 changes: 15 additions & 16 deletions batch/batch/front_end/front_end.py
@@ -1,6 +1,5 @@
from numbers import Number
import os
import concurrent
import logging
import json
import random
@@ -34,6 +33,7 @@
periodically_call,
)
from hailtop.batch_client.parse import parse_cpu_in_mcpu, parse_memory_in_bytes, parse_storage_in_bytes
import hailtop.aiogoogle as aiogoogle
from hailtop.config import get_deploy_config
from hailtop.tls import internal_server_ssl_context
from hailtop.httpx import client_session
@@ -69,7 +69,7 @@
BatchOperationAlreadyCompletedError,
)
from ..inst_coll_config import InstanceCollectionConfigs
from ..log_store import LogStore
from ..file_store import FileStore
from ..database import CallError, check_call_procedure
from ..batch_configuration import BATCH_BUCKET_NAME, DEFAULT_NAMESPACE, SCOPE
from ..globals import HTTP_CLIENT_MAX_SIZE, BATCH_FORMAT_VERSION, memory_to_worker_type
@@ -335,12 +335,12 @@ async def _get_job_log_from_record(app, batch_id, job_id, record):
raise

if state in ('Error', 'Failed', 'Success'):
log_store: LogStore = app['log_store']
file_store: FileStore = app['file_store']
batch_format_version = BatchFormatVersion(record['format_version'])

async def _read_log_from_gcs(task):
try:
data = await log_store.read_log_file(batch_format_version, batch_id, job_id, record['attempt_id'], task)
data = await file_store.read_log_file(batch_format_version, batch_id, job_id, record['attempt_id'], task)
except google.api_core.exceptions.NotFound:
id = (batch_id, job_id)
log.exception(f'missing log file for {id} and task {task}')
@@ -411,7 +411,7 @@ async def _get_attributes(app, record):

async def _get_full_job_spec(app, record):
db: Database = app['db']
log_store: LogStore = app['log_store']
file_store: FileStore = app['file_store']

batch_id = record['batch_id']
job_id = record['job_id']
@@ -423,7 +423,7 @@
token, start_job_id = await SpecWriter.get_token_start_id(db, batch_id, job_id)

try:
spec = await log_store.read_spec_file(batch_id, token, start_job_id, job_id)
spec = await file_store.read_spec_file(batch_id, token, start_job_id, job_id)
return json.loads(spec)
except google.api_core.exceptions.NotFound:
id = (batch_id, job_id)
@@ -432,7 +432,7 @@


async def _get_full_job_status(app, record):
log_store: LogStore = app['log_store']
file_store: FileStore = app['file_store']

batch_id = record['batch_id']
job_id = record['job_id']
@@ -448,7 +448,7 @@
return json.loads(record['status'])

try:
status = await log_store.read_status_file(batch_id, job_id, attempt_id)
status = await file_store.read_status_file(batch_id, job_id, attempt_id)
return json.loads(status)
except google.api_core.exceptions.NotFound:
id = (batch_id, job_id)
@@ -616,7 +616,7 @@ def check_service_account_permissions(user, sa):
async def create_jobs(request, userdata):
app = request.app
db: Database = app['db']
log_store: LogStore = app['log_store']
file_store: FileStore = app['file_store']

if app['frozen']:
log.info('ignoring batch create request; batch is frozen')
@@ -660,7 +660,7 @@ async def create_jobs(request, userdata):
raise web.HTTPBadRequest(reason=e.reason)

async with timer.step('build db args'):
spec_writer = SpecWriter(log_store, batch_id)
spec_writer = SpecWriter(file_store, batch_id)

jobs_args = []
job_parents_args = []
@@ -2123,8 +2123,6 @@ async def delete_batch_loop_body(app):

async def on_startup(app):
app['task_manager'] = aiotools.BackgroundTaskManager()
pool = concurrent.futures.ThreadPoolExecutor()
app['blocking_pool'] = pool

db = Database()
await db.async_init()
@@ -2148,8 +2146,9 @@ async def on_startup(app):

app['frozen'] = row['frozen']

credentials = google.oauth2.service_account.Credentials.from_service_account_file('/gsa-key/key.json')
app['log_store'] = LogStore(BATCH_BUCKET_NAME, instance_id, pool, credentials=credentials)
credentials = aiogoogle.auth.credentials.Credentials.from_file('/gsa-key/key.json')
fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials)
app['file_store'] = FileStore(fs, BATCH_BUCKET_NAME, instance_id)

inst_coll_configs = InstanceCollectionConfigs(app)
app['inst_coll_configs'] = inst_coll_configs
@@ -2176,9 +2175,9 @@

async def on_cleanup(app):
try:
app['blocking_pool'].shutdown()
finally:
app['task_manager'].shutdown()
finally:
await app['file_store'].close()


def run():
2 changes: 1 addition & 1 deletion batch/batch/globals.py
@@ -30,7 +30,7 @@

BATCH_FORMAT_VERSION = 6
STATUS_FORMAT_VERSION = 5
INSTANCE_VERSION = 20
INSTANCE_VERSION = 21
WORKER_CONFIG_VERSION = 3

MAX_PERSISTENT_SSD_SIZE_GIB = 64 * 1024
6 changes: 3 additions & 3 deletions batch/batch/spec_writer.py
@@ -41,8 +41,8 @@ async def get_token_start_id(db, batch_id, job_id):
start_job_id = bunch_record['start_job_id']
return (token, start_job_id)

def __init__(self, log_store, batch_id):
self.log_store = log_store
def __init__(self, file_store, batch_id):
self.file_store = file_store
self.batch_id = batch_id
self.token = secret_alnum_string(16)

@@ -63,7 +63,7 @@ async def write(self):
end = len(self._data_bytes)
self._offsets_bytes.extend(end.to_bytes(8, byteorder=SpecWriter.byteorder, signed=SpecWriter.signed))

await self.log_store.write_spec_file(
await self.file_store.write_spec_file(
self.batch_id, self.token, bytes(self._data_bytes), bytes(self._offsets_bytes)
)
return self.token
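For context on the ranged reads in FileStore.read_spec_file above: SpecWriter.write appends one fixed-width (8-byte) end offset per spec to the index blob, so a job's spec can later be fetched with one read_range into specs.idx followed by one read_range into specs.json. A rough standalone illustration of the index layout; byte order and signedness are whatever SpecWriter defines, and the values here are made up:

data_bytes = bytearray()
offsets_bytes = bytearray()

for spec_json in ['{"job_id": 1}', '{"job_id": 2}']:
    data_bytes.extend(spec_json.encode('utf-8'))
    end = len(data_bytes)
    # one index entry per spec: the running end offset of data_bytes;
    # the reader derives (spec_start, spec_end) for read_range from the relevant entries
    offsets_bytes.extend(end.to_bytes(8, byteorder='little', signed=False))  # assumed byte order

print(len(offsets_bytes))  # 16 bytes: two 8-byte entries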