From d9bb8668b25fe3af146162c2cc45f0351e14506f Mon Sep 17 00:00:00 2001 From: jigold Date: Mon, 27 Sep 2021 13:23:53 -0400 Subject: [PATCH] [batch] Replace GCS with GoogleStorageAsyncFS (attempt #2) (#10814) * [batch] replace GCS with GoogleStorageAsyncFS in LogStore * delint * addr comments * fix s3 exception handling * fix catching nonexistent blob in azurefs * flip version number of instance --- batch/batch/driver/create_instance.py | 9 +- batch/batch/driver/job.py | 6 +- batch/batch/driver/main.py | 24 +-- batch/batch/{log_store.py => file_store.py} | 58 +++--- batch/batch/front_end/front_end.py | 31 ++- batch/batch/globals.py | 2 +- batch/batch/spec_writer.py | 6 +- batch/batch/worker/worker.py | 21 +- benchmark-service/benchmark/benchmark.py | 51 ++--- benchmark-service/benchmark/utils.py | 50 +---- benchmark-service/test/test_update_commits.py | 2 - build.yaml | 8 +- .../hailtop/aiogoogle/auth/credentials.py | 3 + .../aiogoogle/client/storage_client.py | 11 +- hail/python/hailtop/aiotools/azurefs.py | 10 +- hail/python/hailtop/aiotools/fs.py | 9 + hail/python/hailtop/aiotools/s3asyncfs.py | 25 ++- .../hailtop/batch/batch_pool_executor.py | 29 ++- hail/python/hailtop/google_storage.py | 190 ------------------ hail/python/test/hailtop/test_fs.py | 30 +++ memory/memory/client.py | 12 +- memory/memory/memory.py | 39 ++-- memory/test/test_memory.py | 25 ++- 23 files changed, 249 insertions(+), 402 deletions(-) rename batch/batch/{log_store.py => file_store.py} (59%) delete mode 100644 hail/python/hailtop/google_storage.py diff --git a/batch/batch/driver/create_instance.py b/batch/batch/driver/create_instance.py index 810e5f353f1..28f1dae7130 100644 --- a/batch/batch/driver/create_instance.py +++ b/batch/batch/driver/create_instance.py @@ -10,7 +10,7 @@ from ..batch_configuration import PROJECT, DOCKER_ROOT_IMAGE, DOCKER_PREFIX, DEFAULT_NAMESPACE from ..inst_coll_config import machine_type_to_dict from ..worker_config import WorkerConfig -from ..log_store import LogStore +from ..file_store import FileStore from ..utils import unreserved_worker_data_disk_size_gib log = logging.getLogger('create_instance') @@ -34,7 +34,8 @@ def create_instance_config( preemptible, job_private, ) -> Tuple[Dict[str, Any], WorkerConfig]: - log_store: LogStore = app['log_store'] + file_store: FileStore = app['file_store'] + cores = int(machine_type_to_dict(machine_type)['cores']) if worker_local_ssd_data_disk: @@ -324,8 +325,8 @@ def create_instance_config( {'key': 'docker_prefix', 'value': DOCKER_PREFIX}, {'key': 'namespace', 'value': DEFAULT_NAMESPACE}, {'key': 'internal_ip', 'value': INTERNAL_GATEWAY_IP}, - {'key': 'batch_logs_bucket_name', 'value': log_store.batch_logs_bucket_name}, - {'key': 'instance_id', 'value': log_store.instance_id}, + {'key': 'batch_logs_bucket_name', 'value': file_store.batch_logs_bucket_name}, + {'key': 'instance_id', 'value': file_store.instance_id}, {'key': 'max_idle_time_msecs', 'value': max_idle_time_msecs}, ] }, diff --git a/batch/batch/driver/job.py b/batch/batch/driver/job.py index e2854c5a372..2d62b8c983a 100644 --- a/batch/batch/driver/job.py +++ b/batch/batch/driver/job.py @@ -15,7 +15,7 @@ from ..batch_configuration import KUBERNETES_TIMEOUT_IN_SECONDS, KUBERNETES_SERVER_URL from ..batch_format_version import BatchFormatVersion from ..spec_writer import SpecWriter -from ..log_store import LogStore +from ..file_store import FileStore from .k8s_cache import K8sCache @@ -379,7 +379,7 @@ async def job_config(app, record, attempt_id): async def schedule_job(app, record, 
instance): assert instance.state == 'active' - log_store: LogStore = app['log_store'] + file_store: FileStore = app['file_store'] db: Database = app['db'] batch_id = record['batch_id'] @@ -407,7 +407,7 @@ async def schedule_job(app, record, instance): } if format_version.has_full_status_in_gcs(): - await log_store.write_status_file(batch_id, job_id, attempt_id, json.dumps(status)) + await file_store.write_status_file(batch_id, job_id, attempt_id, json.dumps(status)) db_status = format_version.db_status(status) resources = [] diff --git a/batch/batch/driver/main.py b/batch/batch/driver/main.py index 9cfe9715343..0efbffc4ddb 100644 --- a/batch/batch/driver/main.py +++ b/batch/batch/driver/main.py @@ -3,7 +3,6 @@ from typing import Dict from functools import wraps from collections import namedtuple, defaultdict -import concurrent import copy import asyncio import signal @@ -11,7 +10,6 @@ from aiohttp import web import aiohttp_session import kubernetes_asyncio as kube -import google.oauth2.service_account from prometheus_async.aio.web import server_stats import prometheus_client as pc # type: ignore from gear import ( @@ -40,7 +38,7 @@ import googlecloudprofiler import uvloop -from ..log_store import LogStore +from ..file_store import FileStore from ..batch import cancel_batch_in_db from ..batch_configuration import ( REFRESH_INTERVAL_IN_SECONDS, @@ -1042,8 +1040,6 @@ async def scheduling_cancelling_bump(app): async def on_startup(app): app['task_manager'] = aiotools.BackgroundTaskManager() - pool = concurrent.futures.ThreadPoolExecutor() - app['blocking_pool'] = pool kube.config.load_incluster_config() k8s_client = kube.client.CoreV1Api() @@ -1105,9 +1101,9 @@ async def on_startup(app): async_worker_pool = AsyncWorkerPool(100, queue_size=100) app['async_worker_pool'] = async_worker_pool - credentials = google.oauth2.service_account.Credentials.from_service_account_file('/gsa-key/key.json') - log_store = LogStore(BATCH_BUCKET_NAME, instance_id, pool, credentials=credentials) - app['log_store'] = log_store + credentials = aiogoogle.auth.credentials.Credentials.from_file('/gsa-key/key.json') + fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials) + app['file_store'] = FileStore(fs, BATCH_BUCKET_NAME, instance_id) zone_monitor = ZoneMonitor(app) app['zone_monitor'] = zone_monitor @@ -1146,22 +1142,22 @@ async def on_startup(app): async def on_cleanup(app): try: - app['blocking_pool'].shutdown() + await app['db'].async_close() finally: try: - await app['db'].async_close() + app['zone_monitor'].shutdown() finally: try: - app['zone_monitor'].shutdown() + app['inst_coll_manager'].shutdown() finally: try: - app['inst_coll_manager'].shutdown() + app['canceller'].shutdown() finally: try: - app['canceller'].shutdown() + app['gce_event_monitor'].shutdown() finally: try: - app['gce_event_monitor'].shutdown() + await app['file_store'].close() finally: try: app['task_manager'].shutdown() diff --git a/batch/batch/log_store.py b/batch/batch/file_store.py similarity index 59% rename from batch/batch/log_store.py rename to batch/batch/file_store.py index ab6afb0425d..46b6875f565 100644 --- a/batch/batch/log_store.py +++ b/batch/batch/file_store.py @@ -1,7 +1,7 @@ import logging import asyncio -from hailtop.google_storage import GCS +from hailtop.aiotools.fs import AsyncFS from .spec_writer import SpecWriter from .globals import BATCH_FORMAT_VERSION @@ -10,12 +10,13 @@ log = logging.getLogger('logstore') -class LogStore: - def __init__(self, batch_logs_bucket_name, instance_id, blocking_pool, *, 
project=None, credentials=None): +class FileStore: + def __init__(self, fs: AsyncFS, batch_logs_bucket_name, instance_id): + self.fs = fs self.batch_logs_bucket_name = batch_logs_bucket_name self.instance_id = instance_id + self.batch_logs_root = f'gs://{batch_logs_bucket_name}/batch/logs/{instance_id}/batch' - self.gcs = GCS(blocking_pool, project=project, credentials=credentials) log.info(f'BATCH_LOGS_ROOT {self.batch_logs_root}') format_version = BatchFormatVersion(BATCH_FORMAT_VERSION) @@ -27,34 +28,36 @@ def batch_log_dir(self, batch_id): def log_path(self, format_version, batch_id, job_id, attempt_id, task): if not format_version.has_attempt_in_log_path(): return f'{self.batch_log_dir(batch_id)}/{job_id}/{task}/log' - return f'{self.batch_log_dir(batch_id)}/{job_id}/{attempt_id}/{task}/log' async def read_log_file(self, format_version, batch_id, job_id, attempt_id, task): - path = self.log_path(format_version, batch_id, job_id, attempt_id, task) - return await self.gcs.read_gs_file(path) + url = self.log_path(format_version, batch_id, job_id, attempt_id, task) + data = await self.fs.read(url) + return data.decode('utf-8') async def write_log_file(self, format_version, batch_id, job_id, attempt_id, task, data): - path = self.log_path(format_version, batch_id, job_id, attempt_id, task) - return await self.gcs.write_gs_file_from_string(path, data) + url = self.log_path(format_version, batch_id, job_id, attempt_id, task) + await self.fs.write(url, data.encode('utf-8')) async def delete_batch_logs(self, batch_id): - await self.gcs.delete_gs_files(self.batch_log_dir(batch_id)) + url = self.batch_log_dir(batch_id) + await self.fs.rmtree(None, url) def status_path(self, batch_id, job_id, attempt_id): return f'{self.batch_log_dir(batch_id)}/{job_id}/{attempt_id}/status.json' async def read_status_file(self, batch_id, job_id, attempt_id): - path = self.status_path(batch_id, job_id, attempt_id) - return await self.gcs.read_gs_file(path) + url = self.status_path(batch_id, job_id, attempt_id) + data = await self.fs.read(url) + return data.decode('utf-8') async def write_status_file(self, batch_id, job_id, attempt_id, status): - path = self.status_path(batch_id, job_id, attempt_id) - return await self.gcs.write_gs_file_from_string(path, status) + url = self.status_path(batch_id, job_id, attempt_id) + await self.fs.write(url, status.encode('utf-8')) async def delete_status_file(self, batch_id, job_id, attempt_id): - path = self.status_path(batch_id, job_id, attempt_id) - return await self.gcs.delete_gs_file(path) + url = self.status_path(batch_id, job_id, attempt_id) + return await self.fs.remove(url) def specs_dir(self, batch_id, token): return f'{self.batch_logs_root}/{batch_id}/bunch/{token}' @@ -66,22 +69,27 @@ def specs_index_path(self, batch_id, token): return f'{self.specs_dir(batch_id, token)}/specs.idx' async def read_spec_file(self, batch_id, token, start_job_id, job_id): - idx_path = self.specs_index_path(batch_id, token) + idx_url = self.specs_index_path(batch_id, token) idx_start, idx_end = SpecWriter.get_index_file_offsets(job_id, start_job_id) - offsets = await self.gcs.read_binary_gs_file(idx_path, start=idx_start, end=idx_end) + offsets = await self.fs.read_range(idx_url, idx_start, idx_end) - spec_path = self.specs_path(batch_id, token) + spec_url = self.specs_path(batch_id, token) spec_start, spec_end = SpecWriter.get_spec_file_offsets(offsets) - return await self.gcs.read_gs_file(spec_path, start=spec_start, end=spec_end) + data = await self.fs.read_range(spec_url, spec_start, 
spec_end) + return data.decode('utf-8') async def write_spec_file(self, batch_id, token, data_bytes, offsets_bytes): - idx_path = self.specs_index_path(batch_id, token) - write1 = self.gcs.write_gs_file_from_string(idx_path, offsets_bytes, content_type='application/octet-stream') + idx_url = self.specs_index_path(batch_id, token) + write1 = self.fs.write(idx_url, offsets_bytes) - specs_path = self.specs_path(batch_id, token) - write2 = self.gcs.write_gs_file_from_string(specs_path, data_bytes) + specs_url = self.specs_path(batch_id, token) + write2 = self.fs.write(specs_url, data_bytes) await asyncio.gather(write1, write2) async def delete_spec_file(self, batch_id, token): - await self.gcs.delete_gs_files(self.specs_dir(batch_id, token)) + url = self.specs_dir(batch_id, token) + await self.fs.rmtree(None, url) + + async def close(self): + await self.fs.close() diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py index 9315824f8fe..eede3dcba67 100644 --- a/batch/batch/front_end/front_end.py +++ b/batch/batch/front_end/front_end.py @@ -1,6 +1,5 @@ from numbers import Number import os -import concurrent import logging import json import random @@ -34,6 +33,7 @@ periodically_call, ) from hailtop.batch_client.parse import parse_cpu_in_mcpu, parse_memory_in_bytes, parse_storage_in_bytes +import hailtop.aiogoogle as aiogoogle from hailtop.config import get_deploy_config from hailtop.tls import internal_server_ssl_context from hailtop.httpx import client_session @@ -69,7 +69,7 @@ BatchOperationAlreadyCompletedError, ) from ..inst_coll_config import InstanceCollectionConfigs -from ..log_store import LogStore +from ..file_store import FileStore from ..database import CallError, check_call_procedure from ..batch_configuration import BATCH_BUCKET_NAME, DEFAULT_NAMESPACE, SCOPE from ..globals import HTTP_CLIENT_MAX_SIZE, BATCH_FORMAT_VERSION, memory_to_worker_type @@ -335,12 +335,12 @@ async def _get_job_log_from_record(app, batch_id, job_id, record): raise if state in ('Error', 'Failed', 'Success'): - log_store: LogStore = app['log_store'] + file_store: FileStore = app['file_store'] batch_format_version = BatchFormatVersion(record['format_version']) async def _read_log_from_gcs(task): try: - data = await log_store.read_log_file(batch_format_version, batch_id, job_id, record['attempt_id'], task) + data = await file_store.read_log_file(batch_format_version, batch_id, job_id, record['attempt_id'], task) except google.api_core.exceptions.NotFound: id = (batch_id, job_id) log.exception(f'missing log file for {id} and task {task}') @@ -411,7 +411,7 @@ async def _get_attributes(app, record): async def _get_full_job_spec(app, record): db: Database = app['db'] - log_store: LogStore = app['log_store'] + file_store: FileStore = app['file_store'] batch_id = record['batch_id'] job_id = record['job_id'] @@ -423,7 +423,7 @@ async def _get_full_job_spec(app, record): token, start_job_id = await SpecWriter.get_token_start_id(db, batch_id, job_id) try: - spec = await log_store.read_spec_file(batch_id, token, start_job_id, job_id) + spec = await file_store.read_spec_file(batch_id, token, start_job_id, job_id) return json.loads(spec) except google.api_core.exceptions.NotFound: id = (batch_id, job_id) @@ -432,7 +432,7 @@ async def _get_full_job_spec(app, record): async def _get_full_job_status(app, record): - log_store: LogStore = app['log_store'] + file_store: FileStore = app['file_store'] batch_id = record['batch_id'] job_id = record['job_id'] @@ -448,7 +448,7 @@ async def 
_get_full_job_status(app, record): return json.loads(record['status']) try: - status = await log_store.read_status_file(batch_id, job_id, attempt_id) + status = await file_store.read_status_file(batch_id, job_id, attempt_id) return json.loads(status) except google.api_core.exceptions.NotFound: id = (batch_id, job_id) @@ -616,7 +616,7 @@ def check_service_account_permissions(user, sa): async def create_jobs(request, userdata): app = request.app db: Database = app['db'] - log_store: LogStore = app['log_store'] + file_store: FileStore = app['file_store'] if app['frozen']: log.info('ignoring batch create request; batch is frozen') @@ -660,7 +660,7 @@ async def create_jobs(request, userdata): raise web.HTTPBadRequest(reason=e.reason) async with timer.step('build db args'): - spec_writer = SpecWriter(log_store, batch_id) + spec_writer = SpecWriter(file_store, batch_id) jobs_args = [] job_parents_args = [] @@ -2123,8 +2123,6 @@ async def delete_batch_loop_body(app): async def on_startup(app): app['task_manager'] = aiotools.BackgroundTaskManager() - pool = concurrent.futures.ThreadPoolExecutor() - app['blocking_pool'] = pool db = Database() await db.async_init() @@ -2148,8 +2146,9 @@ async def on_startup(app): app['frozen'] = row['frozen'] - credentials = google.oauth2.service_account.Credentials.from_service_account_file('/gsa-key/key.json') - app['log_store'] = LogStore(BATCH_BUCKET_NAME, instance_id, pool, credentials=credentials) + credentials = aiogoogle.auth.credentials.Credentials.from_file('/gsa-key/key.json') + fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials) + app['file_store'] = FileStore(fs, BATCH_BUCKET_NAME, instance_id) inst_coll_configs = InstanceCollectionConfigs(app) app['inst_coll_configs'] = inst_coll_configs @@ -2176,9 +2175,9 @@ async def on_startup(app): async def on_cleanup(app): try: - app['blocking_pool'].shutdown() - finally: app['task_manager'].shutdown() + finally: + await app['file_store'].close() def run(): diff --git a/batch/batch/globals.py b/batch/batch/globals.py index e9c64bd1b31..59eb9dc98d3 100644 --- a/batch/batch/globals.py +++ b/batch/batch/globals.py @@ -30,7 +30,7 @@ BATCH_FORMAT_VERSION = 6 STATUS_FORMAT_VERSION = 5 -INSTANCE_VERSION = 20 +INSTANCE_VERSION = 21 WORKER_CONFIG_VERSION = 3 MAX_PERSISTENT_SSD_SIZE_GIB = 64 * 1024 diff --git a/batch/batch/spec_writer.py b/batch/batch/spec_writer.py index 64bc29a9246..1546ced2c8d 100644 --- a/batch/batch/spec_writer.py +++ b/batch/batch/spec_writer.py @@ -41,8 +41,8 @@ async def get_token_start_id(db, batch_id, job_id): start_job_id = bunch_record['start_job_id'] return (token, start_job_id) - def __init__(self, log_store, batch_id): - self.log_store = log_store + def __init__(self, file_store, batch_id): + self.file_store = file_store self.batch_id = batch_id self.token = secret_alnum_string(16) @@ -63,7 +63,7 @@ async def write(self): end = len(self._data_bytes) self._offsets_bytes.extend(end.to_bytes(8, byteorder=SpecWriter.byteorder, signed=SpecWriter.signed)) - await self.log_store.write_spec_file( + await self.file_store.write_spec_file( self.batch_id, self.token, bytes(self._data_bytes), bytes(self._offsets_bytes) ) return self.token diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index dec79d273a8..cb2f72f284f 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -21,7 +21,6 @@ from collections import defaultdict import psutil from aiodocker.exceptions import DockerError # type: ignore -import google.oauth2.service_account # type: ignore 
from hailtop.utils import ( time_msecs, time_msecs_str, @@ -60,7 +59,7 @@ cores_mcpu_to_storage_bytes, ) from ..semaphore import FIFOWeightedSemaphore -from ..log_store import LogStore +from ..file_store import FileStore from ..globals import ( HTTP_CLIENT_MAX_SIZE, STATUS_FORMAT_VERSION, @@ -923,7 +922,7 @@ def container_finished(self): return self.process is not None and self.process.returncode is not None async def upload_log(self): - await worker.log_store.write_log_file( + await worker.file_store.write_log_file( self.job.format_version, self.job.batch_id, self.job.job_id, @@ -1615,7 +1614,7 @@ async def run(self, worker): log.info(f'finished {self} with return code {self.process.returncode}') - await worker.log_store.write_log_file( + await worker.file_store.write_log_file( self.format_version, self.batch_id, self.job_id, self.attempt_id, 'main', self.logbuffer.decode() ) @@ -1732,7 +1731,7 @@ def __init__(self): self.image_data[BATCH_WORKER_IMAGE_ID] += 1 # filled in during activation - self.log_store = None + self.file_store = None self.headers = None self.compute_client = None @@ -1763,7 +1762,7 @@ async def create_job_1(self, request): start_job_id = body['start_job_id'] addtl_spec = body['job_spec'] - job_spec = await self.log_store.read_spec_file(batch_id, token, start_job_id, job_id) + job_spec = await self.file_store.read_spec_file(batch_id, token, start_job_id, job_id) job_spec = json.loads(job_spec) job_spec['attempt_id'] = addtl_spec['attempt_id'] @@ -1889,6 +1888,7 @@ async def run(self): finally: self.active = False log.info('shutting down') + await self.file_store.close() await site.stop() log.info('stopped site') await app_runner.cleanup() @@ -1920,7 +1920,7 @@ async def post_job_complete_1(self, job): if job.format_version.has_full_status_in_gcs(): await retry_all_errors(f'error while writing status file to gcs for {job}')( - self.log_store.write_status_file, job.batch_id, job.job_id, job.attempt_id, json.dumps(full_status) + self.file_store.write_status_file, job.batch_id, job.job_id, job.attempt_id, json.dumps(full_status) ) db_status = job.format_version.db_status(full_status) @@ -2026,10 +2026,9 @@ async def activate(self): with open('/worker-key.json', 'w') as f: f.write(json.dumps(resp_json['key'])) - credentials = google.oauth2.service_account.Credentials.from_service_account_file('/worker-key.json') - self.log_store = LogStore( - BATCH_LOGS_BUCKET_NAME, INSTANCE_ID, self.pool, project=PROJECT, credentials=credentials - ) + credentials = aiogoogle.auth.credentials.Credentials.from_file('/worker-key.json') + fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials) + self.file_store = FileStore(fs, BATCH_LOGS_BUCKET_NAME, INSTANCE_ID) credentials = aiogoogle.Credentials.from_file('/worker-key.json') self.compute_client = aiogoogle.ComputeClient(PROJECT, credentials=credentials) diff --git a/benchmark-service/benchmark/benchmark.py b/benchmark-service/benchmark/benchmark.py index d1719147d75..08bb335d2f0 100644 --- a/benchmark-service/benchmark/benchmark.py +++ b/benchmark-service/benchmark/benchmark.py @@ -9,10 +9,10 @@ from hailtop.hail_logging import AccessLogger, configure_logging from hailtop.utils import retry_long_running, collect_agen, humanize_timedelta_msecs from hailtop import aiotools +import hailtop.aiogoogle as aiogoogle import hailtop.batch_client.aioclient as bc from web_common import setup_aiohttp_jinja2, setup_common_static_routes, render_template from benchmark.utils import ( - ReadGoogleStorage, get_geometric_mean, parse_file_path, 
enumerate_list_of_trials, @@ -30,7 +30,6 @@ import gidgethub import gidgethub.aiohttp from .config import START_POINT, BENCHMARK_RESULTS_PATH -import google configure_logging() router = web.RouteTableDef() @@ -53,13 +52,13 @@ oauth_token = f.read().strip() -def get_benchmarks(app, file_path): +async def get_benchmarks(app, file_path): log.info(f'get_benchmarks file_path={file_path}') - gs_reader = app['gs_reader'] + fs: aiotools.AsyncFS = app['fs'] try: - json_data = gs_reader.get_data_as_string(file_path) + json_data = (await fs.read(file_path)).decode('utf-8') pre_data = json.loads(json_data) - except google.api_core.exceptions.NotFound: + except FileNotFoundError: message = f'could not find file, {file_path}' log.info('could not get blob: ' + message, exc_info=True) return None @@ -167,7 +166,7 @@ async def healthcheck(request: web.Request) -> web.Response: # pylint: disable= @web_authenticated_developers_only(redirect=False) async def show_name(request: web.Request, userdata) -> web.Response: # pylint: disable=unused-argument file_path = request.query.get('file') - benchmarks = get_benchmarks(request.app, file_path) + benchmarks = await get_benchmarks(request.app, file_path) name_data = benchmarks['data'][str(request.match_info['name'])] try: @@ -217,11 +216,11 @@ async def lookup(request, userdata): # pylint: disable=unused-argument if file is None: benchmarks_context = None else: - benchmarks_context = get_benchmarks(request.app, file) + benchmarks_context = await get_benchmarks(request.app, file) context = { 'file': file, 'benchmarks': benchmarks_context, - 'benchmark_file_list': list_benchmark_files(app['gs_reader']), + 'benchmark_file_list': await list_benchmark_files(app['fs']), } return await render_template('benchmark', request, userdata, 'lookup.html', context) @@ -238,8 +237,8 @@ async def compare(request, userdata): # pylint: disable=unused-argument benchmarks_context2 = None comparisons = None else: - benchmarks_context1 = get_benchmarks(app, file1) - benchmarks_context2 = get_benchmarks(app, file2) + benchmarks_context1 = await get_benchmarks(app, file1) + benchmarks_context2 = await get_benchmarks(app, file2) comparisons = final_comparisons(get_comparisons(benchmarks_context1, benchmarks_context2, metric)) context = { 'file1': file1, @@ -248,7 +247,7 @@ async def compare(request, userdata): # pylint: disable=unused-argument 'benchmarks1': benchmarks_context1, 'benchmarks2': benchmarks_context2, 'comparisons': comparisons, - 'benchmark_file_list': list_benchmark_files(app['gs_reader']), + 'benchmark_file_list': await list_benchmark_files(app['fs']), } return await render_template('benchmark', request, userdata, 'compare.html', context) @@ -305,7 +304,7 @@ async def get_commit(app, sha): # pylint: disable=unused-argument log.info(f'get_commit sha={sha}') github_client = app['github_client'] batch_client = app['batch_client'] - gs_reader = app['gs_reader'] + fs: aiotools.AsyncFS = app['fs'] file_path = f'{BENCHMARK_RESULTS_PATH}/0-{sha}.json' request_string = f'/repos/hail-is/hail/commits/{sha}' @@ -317,7 +316,7 @@ async def get_commit(app, sha): # pylint: disable=unused-argument pr_id = message_dict['pr_id'] title = message_dict['title'] - has_results_file = gs_reader.file_exists(file_path) + has_results_file = await fs.exists(file_path) batch_statuses = [b._last_known_status async for b in batch_client.list_batches(q=f'sha={sha} user:benchmark')] complete_batch_statuses = [bs for bs in batch_statuses if bs['complete']] running_batch_statuses = [bs for bs in batch_statuses 
if not bs['complete']] @@ -353,7 +352,7 @@ async def get_commit(app, sha): # pylint: disable=unused-argument async def update_commit(app, sha): # pylint: disable=unused-argument log.info('in update_commit') global benchmark_data - gs_reader = app['gs_reader'] + fs: aiotools.AsyncFS = app['fs'] commit = await get_commit(app, sha) file_path = f'{BENCHMARK_RESULTS_PATH}/0-{sha}.json' @@ -367,9 +366,9 @@ async def update_commit(app, sha): # pylint: disable=unused-argument benchmark_data['commits'][sha] = commit return commit - has_results_file = gs_reader.file_exists(file_path) + has_results_file = await fs.exists(file_path) if has_results_file and sha in benchmark_data['commits']: - benchmarks = get_benchmarks(app, file_path) + benchmarks = await get_benchmarks(app, file_path) commit['geo_mean'] = benchmarks['geometric_mean'] geo_mean = commit['geo_mean'] log.info(f'geo mean is {geo_mean}') @@ -394,13 +393,13 @@ async def get_status(request): # pylint: disable=unused-argument async def delete_commit(request): # pylint: disable=unused-argument global benchmark_data app = request.app - gs_reader = app['gs_reader'] + fs: aiotools.AsyncFS = app['fs'] batch_client = app['batch_client'] sha = str(request.match_info['sha']) file_path = f'{BENCHMARK_RESULTS_PATH}/0-{sha}.json' - if gs_reader.file_exists(file_path): - gs_reader.delete_file(file_path) + if await fs.exists(file_path): + await fs.remove(file_path) log.info(f'deleted file for sha {sha}') async for b in batch_client.list_batches(q=f'sha={sha} user:benchmark'): @@ -431,7 +430,8 @@ async def github_polling_loop(app): async def on_startup(app): - app['gs_reader'] = ReadGoogleStorage(service_account_key_file='/benchmark-gsa-key/key.json') + credentials = aiogoogle.auth.Credentials.from_file('/benchmark-gsa-key/key.json') + app['fs'] = aiogoogle.GoogleStorageAsyncFS(credentials=credentials) app['gh_client_session'] = aiohttp.ClientSession() app['github_client'] = gidgethub.aiohttp.GitHubAPI( app['gh_client_session'], 'hail-is/hail', oauth_token=oauth_token @@ -442,8 +442,13 @@ async def on_startup(app): async def on_cleanup(app): - await app['gh_client_session'].close() - app['task_manager'].shutdown() + try: + await app['gh_client_session'].close() + finally: + try: + await app['fs'].close() + finally: + app['task_manager'].shutdown() def run(): diff --git a/benchmark-service/benchmark/utils.py b/benchmark-service/benchmark/utils.py index 7b5a6c25c58..36eba18265a 100644 --- a/benchmark-service/benchmark/utils.py +++ b/benchmark-service/benchmark/utils.py @@ -1,8 +1,9 @@ -from google.cloud import storage import re import logging + +import hailtop.aiogoogle as aiogoogle + from .config import BENCHMARK_RESULTS_PATH -import google log = logging.getLogger('benchmark') @@ -40,10 +41,11 @@ def enumerate_list_of_trials(list_of_trials): return res_dict -def list_benchmark_files(read_gs): +async def list_benchmark_files(fs: aiogoogle.GoogleStorageAsyncFS): list_of_files = [] for bucket in BENCHMARK_BUCKETS: - list_of_files.extend(read_gs.list_files(bucket_name=bucket)) + files = await fs.listfiles(f'gs://{bucket}/', recursive=True) + list_of_files.extend(files) return list_of_files @@ -61,43 +63,3 @@ async def submit_test_batch(batch_client, sha): await batch.submit(disable_progress_bar=True) log.info(f'submitting batch for commit {sha}') return job.batch_id - - -class ReadGoogleStorage: - def __init__(self, service_account_key_file=None): - self.storage_client = storage.Client.from_service_account_json(service_account_key_file) - - def 
get_data_as_string(self, file_path): - file_info = parse_file_path(FILE_PATH_REGEX, file_path) - bucket = self.storage_client.get_bucket(file_info['bucket']) - path = file_info['path'] - try: - blob = bucket.blob(path) - data = blob.download_as_string() - except google.api_core.exceptions.NotFound as e: - log.exception(f'error while reading file {file_path}: {e}') - data = None - return data - - def list_files(self, bucket_name): - list_of_files = [] - bucket = self.storage_client.get_bucket(bucket_name) - for blob in bucket.list_blobs(): - list_of_files.append('gs://' + bucket_name + '/' + blob.name) - return list_of_files - - def file_exists(self, file_path): - file_info = parse_file_path(FILE_PATH_REGEX, file_path) - bucket_name = file_info['bucket'] - bucket = self.storage_client.bucket(bucket_name) - path = file_info['path'] - exists = storage.Blob(bucket=bucket, name=path).exists() - log.info(f'file {bucket_name}/{path} in bucket {bucket_name} exists? {exists}') - return exists - - def delete_file(self, file_path): - file_info = parse_file_path(FILE_PATH_REGEX, file_path) - bucket_name = file_info['bucket'] - bucket = self.storage_client.bucket(bucket_name) - path = file_info['path'] - storage.Blob(bucket=bucket, name=path).delete() diff --git a/benchmark-service/test/test_update_commits.py b/benchmark-service/test/test_update_commits.py index 7e3957c70e5..5d2bbd3fb28 100644 --- a/benchmark-service/test/test_update_commits.py +++ b/benchmark-service/test/test_update_commits.py @@ -1,8 +1,6 @@ -import json import logging import asyncio import pytest -import aiohttp from hailtop.config import get_deploy_config from hailtop.auth import service_auth_headers diff --git a/build.yaml b/build.yaml index 29542340fa7..c818a9fa79b 100644 --- a/build.yaml +++ b/build.yaml @@ -2588,7 +2588,7 @@ steps: memory: 3.75Gi cpu: '1' script: | - export HAIL_GSA_KEY_FILE=/test-gsa-key/key.json + export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json export PROJECT={{ global.project }} hailctl config set batch/bucket {{ global.hail_test_gcs_bucket }} python3 -m pytest --log-cli-level=INFO -s -vv --instafail --durations=50 /io/test/ @@ -3155,7 +3155,6 @@ steps: script: | cd /io/hailtop set -ex - export HAIL_GSA_KEY_FILE=/test-gsa-key/key.json export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json export PYTEST_SPLITS=5 export PYTEST_SPLIT_INDEX=0 @@ -3209,7 +3208,6 @@ steps: script: | cd /io/hailtop set -ex - export HAIL_GSA_KEY_FILE=/test-gsa-key/key.json export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json export PYTEST_SPLITS=5 export PYTEST_SPLIT_INDEX=1 @@ -3263,7 +3261,6 @@ steps: script: | cd /io/hailtop set -ex - export HAIL_GSA_KEY_FILE=/test-gsa-key/key.json export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json export PYTEST_SPLITS=5 export PYTEST_SPLIT_INDEX=2 @@ -3317,7 +3314,6 @@ steps: script: | cd /io/hailtop set -ex - export HAIL_GSA_KEY_FILE=/test-gsa-key/key.json export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json export PYTEST_SPLITS=5 export PYTEST_SPLIT_INDEX=3 @@ -3371,7 +3367,6 @@ steps: script: | cd /io/hailtop set -ex - export HAIL_GSA_KEY_FILE=/test-gsa-key/key.json export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json export PYTEST_SPLITS=5 export PYTEST_SPLIT_INDEX=4 @@ -3424,7 +3419,6 @@ steps: valueFrom: service_base_image.image script: | set -ex - export HAIL_GSA_KEY_FILE=/test-gsa-key/key.json export GOOGLE_APPLICATION_CREDENTIALS=/test-gsa-key/key.json cd /io/hailtop/batch hailctl config set batch/billing_project test diff --git 
a/hail/python/hailtop/aiogoogle/auth/credentials.py b/hail/python/hailtop/aiogoogle/auth/credentials.py index 9bd4336689d..c93ea26389d 100644 --- a/hail/python/hailtop/aiogoogle/auth/credentials.py +++ b/hail/python/hailtop/aiogoogle/auth/credentials.py @@ -15,7 +15,10 @@ class Credentials(abc.ABC): def from_file(credentials_file): with open(credentials_file) as f: credentials = json.load(f) + return Credentials.from_credentials_data(credentials) + @staticmethod + def from_credentials_data(credentials): credentials_type = credentials['type'] if credentials_type == 'service_account': return ServiceAccountCredentials(credentials) diff --git a/hail/python/hailtop/aiogoogle/client/storage_client.py b/hail/python/hailtop/aiogoogle/client/storage_client.py index d2e474fdf3c..c17ad1c439c 100644 --- a/hail/python/hailtop/aiogoogle/client/storage_client.py +++ b/hail/python/hailtop/aiogoogle/client/storage_client.py @@ -322,9 +322,14 @@ async def get_object(self, bucket: str, name: str, **kwargs) -> GetObjectStream: assert 'alt' not in params params['alt'] = 'media' - resp = await self._session.get( - f'https://storage.googleapis.com/storage/v1/b/{bucket}/o/{urllib.parse.quote(name, safe="")}', **kwargs) - return GetObjectStream(resp) + try: + resp = await self._session.get( + f'https://storage.googleapis.com/storage/v1/b/{bucket}/o/{urllib.parse.quote(name, safe="")}', **kwargs) + return GetObjectStream(resp) + except aiohttp.ClientResponseError as e: + if e.status == 404: + raise FileNotFoundError from e + raise async def get_object_metadata(self, bucket: str, name: str, **kwargs) -> Dict[str, str]: assert name diff --git a/hail/python/hailtop/aiotools/azurefs.py b/hail/python/hailtop/aiotools/azurefs.py index 156ec0e07be..1b6f07c4f7a 100644 --- a/hail/python/hailtop/aiotools/azurefs.py +++ b/hail/python/hailtop/aiotools/azurefs.py @@ -153,13 +153,19 @@ async def read(self, n: int = -1) -> bytes: return b'' if n == -1: - downloader = await self._client.download_blob(offset=self._offset) + try: + downloader = await self._client.download_blob(offset=self._offset) + except azure.core.exceptions.ResourceNotFoundError as e: + raise FileNotFoundError from e data = await downloader.readall() self._eof = True return data if self._downloader is None: - self._downloader = await self._client.download_blob(offset=self._offset) + try: + self._downloader = await self._client.download_blob(offset=self._offset) + except azure.core.exceptions.ResourceNotFoundError as e: + raise FileNotFoundError from e if self._chunk_it is None: self._chunk_it = self._downloader.chunks() diff --git a/hail/python/hailtop/aiotools/fs.py b/hail/python/hailtop/aiotools/fs.py index c8c55a1cb22..a1056e470d1 100644 --- a/hail/python/hailtop/aiotools/fs.py +++ b/hail/python/hailtop/aiotools/fs.py @@ -199,8 +199,17 @@ async def write(self, url: str, data: bytes) -> None: async def _write() -> None: async with await self.create(url, retry_writes=False) as f: await f.write(data) + await retry_transient_errors(_write) + async def exists(self, url: str) -> bool: + try: + await self.statfile(url) + except FileNotFoundError: + return False + else: + return True + async def close(self) -> None: pass diff --git a/hail/python/hailtop/aiotools/s3asyncfs.py b/hail/python/hailtop/aiotools/s3asyncfs.py index 76774dff82c..caf16899a65 100644 --- a/hail/python/hailtop/aiotools/s3asyncfs.py +++ b/hail/python/hailtop/aiotools/s3asyncfs.py @@ -7,6 +7,7 @@ import threading import asyncio import logging + import botocore.exceptions import boto3 from 
hailtop.utils import blocking_to_async @@ -272,18 +273,24 @@ def _get_bucket_name(url: str) -> Tuple[str, str]: async def open(self, url: str) -> ReadableStream: bucket, name = self._get_bucket_name(url) - resp = await blocking_to_async(self._thread_pool, self._s3.get_object, - Bucket=bucket, - Key=name) - return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body'])) + try: + resp = await blocking_to_async(self._thread_pool, self._s3.get_object, + Bucket=bucket, + Key=name) + return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body'])) + except self._s3.exceptions.NoSuchKey as e: + raise FileNotFoundError(url) from e async def open_from(self, url: str, start: int) -> ReadableStream: bucket, name = self._get_bucket_name(url) - resp = await blocking_to_async(self._thread_pool, self._s3.get_object, - Bucket=bucket, - Key=name, - Range=f'bytes={start}-') - return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body'])) + try: + resp = await blocking_to_async(self._thread_pool, self._s3.get_object, + Bucket=bucket, + Key=name, + Range=f'bytes={start}-') + return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body'])) + except self._s3.exceptions.NoSuchKey as e: + raise FileNotFoundError(url) from e async def create(self, url: str, *, retry_writes: bool = True) -> S3CreateManager: # pylint: disable=unused-argument # It may be possible to write a more efficient version of this diff --git a/hail/python/hailtop/batch/batch_pool_executor.py b/hail/python/hailtop/batch/batch_pool_executor.py index 4676afc8e4c..0368e09e336 100644 --- a/hail/python/hailtop/batch/batch_pool_executor.py +++ b/hail/python/hailtop/batch/batch_pool_executor.py @@ -10,10 +10,11 @@ from hailtop.utils import secret_alnum_string, partition import hailtop.batch_client.aioclient as low_level_batch_client from hailtop.batch_client.parse import parse_cpu_in_mcpu +import hailtop.aiogoogle as aiogoogle from .batch import Batch from .backend import ServiceBackend -from ..google_storage import GCS + if sys.version_info < (3, 7): def create_task(coro, *, name=None): # pylint: disable=unused-argument @@ -131,8 +132,7 @@ def __init__(self, *, self.directory = self.backend.remote_tmpdir + f'batch-pool-executor/{self.name}/' self.inputs = self.directory + 'inputs/' self.outputs = self.directory + 'outputs/' - self.gcs = GCS(blocking_pool=concurrent.futures.ThreadPoolExecutor(), - project=project) + self.fs = aiogoogle.GoogleStorageAsyncFS(project=project) self.futures: List[BatchPoolFuture] = [] self.finished_future_count = 0 self._shutdown = False @@ -354,9 +354,9 @@ async def async_submit(self, pipe = BytesIO() dill.dump(functools.partial(unapplied, *args, **kwargs), pipe, recurse=True) pipe.seek(0) - pickledfun_gcs = self.inputs + f'{name}/pickledfun' - await self.gcs.write_gs_file_from_file_like_object(pickledfun_gcs, pipe) - pickledfun_local = batch.read_input(pickledfun_gcs) + pickledfun_remote = self.inputs + f'{name}/pickledfun' + await self.fs.write(pickledfun_remote, pipe.getvalue()) + pickledfun_local = batch.read_input(pickledfun_remote) thread_limit = "1" if self.cpus_per_job: @@ -408,7 +408,7 @@ def _add_future(self, f): def _finish_future(self): self.finished_future_count += 1 if self._shutdown and self.finished_future_count == len(self.futures): - self._cleanup(False) + self._cleanup() def shutdown(self, wait: bool = True): """Allow temporary resources to be cleaned up. 
@@ -432,14 +432,13 @@ async def ignore_exceptions(f): async_to_blocking( asyncio.gather(*[ignore_exceptions(f) for f in self.futures])) if self.finished_future_count == len(self.futures): - self._cleanup(False) + self._cleanup() self._shutdown = True - def _cleanup(self, wait): + def _cleanup(self): if self.cleanup_bucket: - async_to_blocking( - self.gcs.delete_gs_files(self.directory)) - self.gcs.shutdown(wait) + async_to_blocking(self.fs.rmtree(None, self.directory)) + async_to_blocking(self.fs.close()) self.backend.close() @@ -448,11 +447,11 @@ def __init__(self, executor: BatchPoolExecutor, batch: low_level_batch_client.Batch, job: low_level_batch_client.Job, - output_gcs: str): + output_file: str): self.executor = executor self.batch = batch self.job = job - self.output_gcs = output_gcs + self.output_file = output_file self.fetch_coro = asyncio.ensure_future(self._async_fetch_result()) executor._add_future(self) @@ -548,7 +547,7 @@ async def _async_fetch_result(self): raise ValueError( f"submitted job failed:\n{main_container_status['error']}") value, traceback = dill.loads( - await self.executor.gcs.read_binary_gs_file(self.output_gcs)) + await self.executor.fs.read(self.output_file)) if traceback is None: return value assert isinstance(value, BaseException) diff --git a/hail/python/hailtop/google_storage.py b/hail/python/hailtop/google_storage.py deleted file mode 100644 index 645a21a9f5e..00000000000 --- a/hail/python/hailtop/google_storage.py +++ /dev/null @@ -1,190 +0,0 @@ -from typing import Optional, IO, Callable, List -import os -import logging -import concurrent.futures -from functools import wraps - -import google.api_core.exceptions -import google.oauth2.service_account -import google.cloud.storage -from google.cloud.storage.blob import Blob -from hailtop.utils import blocking_to_async, retry_transient_errors - - -logging.getLogger("google").setLevel(logging.WARNING) - - -class GCS: - @staticmethod - def _parse_uri(uri: str): - assert uri.startswith('gs://'), uri - uri_parts = uri[5:].split('/') - bucket = uri_parts[0] - path = '/'.join(uri_parts[1:]) - return bucket, path - - def __init__(self, - blocking_pool: concurrent.futures.Executor, - *, - project: Optional[str] = None, - key: Optional[str] = None, - credentials: Optional[google.oauth2.service_account.Credentials] = None): - self.blocking_pool = blocking_pool - # project=None doesn't mean default, it means no project: - # https://github.com/googleapis/google-cloud-python/blob/master/storage/google/cloud/storage/client.py#L86 - if credentials is None: - if key is not None: - credentials = google.oauth2.service_account.Credentials.from_service_account_info(key) - elif 'HAIL_GSA_KEY_FILE' in os.environ: - key_file = os.environ['HAIL_GSA_KEY_FILE'] - credentials = google.oauth2.service_account.Credentials.from_service_account_file(key_file) - - if project: - self.gcs_client = google.cloud.storage.Client( - project=project, credentials=credentials) - else: - self.gcs_client = google.cloud.storage.Client( - credentials=credentials) - self._wrapped_write_gs_file_from_string = self._wrap_network_call(GCS._write_gs_file_from_string) - self._wrapped_write_gs_file_from_file_like_object = self._wrap_network_call(GCS._write_gs_file_from_file_like_object) - self._wrapped_read_gs_file = self._wrap_network_call(GCS._read_gs_file) - self._wrapped_read_binary_gs_file = self._wrap_network_call(GCS._read_binary_gs_file) - self._wrapped_read_gs_file_to_file = self._wrap_network_call(GCS._read_gs_file_to_file) - 
self._wrapped_delete_gs_file = self._wrap_network_call(GCS._delete_gs_file) - self._wrapped_delete_gs_files = self._wrap_network_call(GCS._delete_gs_files) - self._wrapped_copy_gs_file = self._wrap_network_call(GCS._copy_gs_file) - self._wrapped_list_all_blobs_with_prefix = self._wrap_network_call(GCS._list_all_blobs_with_prefix) - self._wrapped_compose_gs_file = self._wrap_network_call(GCS._compose_gs_file) - self._wrapped_get_blob = self._wrap_network_call(GCS._get_blob) - - def shutdown(self, wait: bool = True): - self.blocking_pool.shutdown(wait) - - async def get_etag(self, uri: str): - return await retry_transient_errors(self._wrap_network_call(GCS._get_etag), self, uri) - - async def write_gs_file_from_string(self, uri: str, string: str, *args, **kwargs): - return await retry_transient_errors(self._wrapped_write_gs_file_from_string, - self, uri, string, *args, **kwargs) - - async def write_gs_file_from_file_like_object(self, uri: str, file: IO, *args, start=None, end=None, **kwargs): - return await retry_transient_errors(self._wrapped_write_gs_file_from_file_like_object, - self, uri, file, start, end, *args, **kwargs) - - async def write_gs_file_from_file(self, uri: str, file_name: str, *args, start=None, end=None, **kwargs): - with open(file_name, 'r') as file: - await self.write_gs_file_from_file_like_object(uri, file, *args, start=start, end=end, **kwargs) - - async def read_gs_file(self, uri: str, *args, **kwargs): - return await retry_transient_errors(self._wrapped_read_gs_file, - self, uri, *args, **kwargs) - - async def read_binary_gs_file(self, uri: str, *args, **kwargs): - return await retry_transient_errors(self._wrapped_read_binary_gs_file, - self, uri, *args, **kwargs) - - async def read_gs_file_to_file(self, uri: str, file_name, offset, *args, **kwargs): - return await retry_transient_errors(self._wrapped_read_gs_file_to_file, - self, uri, file_name, offset, *args, **kwargs) - - async def delete_gs_file(self, uri: str): - return await retry_transient_errors(self._wrapped_delete_gs_file, - self, uri) - - async def delete_gs_files(self, uri_prefix: str): - return await retry_transient_errors(self._wrapped_delete_gs_files, - self, uri_prefix) - - async def copy_gs_file(self, src: str, dest: str, *args, **kwargs): - return await retry_transient_errors(self._wrapped_copy_gs_file, - self, src, dest, *args, **kwargs) - - async def compose_gs_file(self, sources: str, dest: str, *args, **kwargs): - return await retry_transient_errors(self._wrapped_compose_gs_file, - self, sources, dest, *args, **kwargs) - - async def list_all_blobs_with_prefix(self, uri: str, max_results: Optional[int] = None): - return await retry_transient_errors(self._wrapped_list_all_blobs_with_prefix, - self, uri, max_results=max_results) - - async def get_blob(self, uri: str): - return await retry_transient_errors(self._wrapped_get_blob, - self, uri) - - def _wrap_network_call(self, fun: Callable) -> Callable: - @wraps(fun) - async def wrapped(*args, **kwargs): - return await blocking_to_async(self.blocking_pool, - fun, - *args, - **kwargs) - return wrapped - - def _get_etag(self, uri: str): - b = self._get_blob(uri) - b.reload() - return b.etag - - def _write_gs_file_from_string(self, uri: str, string: str, *args, **kwargs): - b = self._get_blob(uri) - b.metadata = {'Cache-Control': 'no-cache'} - b.upload_from_string(string, *args, **kwargs) - - def _write_gs_file_from_file_like_object(self, uri: str, file: IO, *args, **kwargs): - b = self._get_blob(uri) - b.metadata = {'Cache-Control': 'no-cache'} - 
b.upload_from_file(file, *args, **kwargs) - - def _read_gs_file(self, uri: str, *args, **kwargs): - content = self._read_binary_gs_file(uri, *args, **kwargs) - return content.decode('utf-8') - - def _read_binary_gs_file(self, uri: str, *args, **kwargs): - b = self._get_blob(uri) - b.metadata = {'Cache-Control': 'no-cache'} - content = b.download_as_string(*args, **kwargs) - return content - - def _read_gs_file_to_file(self, uri: str, file_name: str, offset: int, *args, **kwargs): - with open(file_name, 'r+b') as file: - file.seek(offset) - b = self._get_blob(uri) - b.metadata = {'Cache-Control': 'no-cache'} - b.download_to_file(file, *args, **kwargs) - - def _delete_gs_files(self, uri: str): - for blob in self._list_all_blobs_with_prefix(uri): - try: - blob.delete() - except google.api_core.exceptions.NotFound: - continue - - def _delete_gs_file(self, uri: str): - b = self._get_blob(uri) - try: - b.delete() - except google.api_core.exceptions.NotFound: - return - - def _copy_gs_file(self, src: str, dest: str, *args, **kwargs): - src_bucket, src_path = GCS._parse_uri(src) - src_bucket = self.gcs_client.bucket(src_bucket) - dest_bucket, dest_path = GCS._parse_uri(dest) - dest_bucket = self.gcs_client.bucket(dest_bucket) - src_blob = src_bucket.blob(src_path) - src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_path, *args, **kwargs) - - def _list_all_blobs_with_prefix(self, uri: str, max_results: Optional[int] = None): - b = self._get_blob(uri) - return iter(b.bucket.list_blobs(prefix=b.name, max_results=max_results)) - - def _compose_gs_file(self, sources: List[str], dest: str, *args, **kwargs): - assert sources - sources = [self._get_blob(src) for src in sources] - dest_blob = self._get_blob(dest) - dest_blob.compose(sources, *args, **kwargs) - - def _get_blob(self, uri: str) -> Blob: - bucket, path = GCS._parse_uri(uri) - bucket = self.gcs_client.bucket(bucket) - return bucket.blob(path) diff --git a/hail/python/test/hailtop/test_fs.py b/hail/python/test/hailtop/test_fs.py index c0ca915acec..bc6bd2dc91e 100644 --- a/hail/python/test/hailtop/test_fs.py +++ b/hail/python/test/hailtop/test_fs.py @@ -111,6 +111,36 @@ async def test_open_from(filesystem): assert r == b'cde' +@pytest.mark.asyncio +async def test_open_nonexistent_file(filesystem): + sema, fs, base = filesystem + + file = f'{base}foo' + + try: + async with await fs.open(file) as f: + await f.read() + except FileNotFoundError: + pass + else: + assert False + + +@pytest.mark.asyncio +async def test_open_from_nonexistent_file(filesystem): + sema, fs, base = filesystem + + file = f'{base}foo' + + try: + async with await fs.open_from(file, 2) as f: + await f.read() + except FileNotFoundError: + pass + else: + assert False + + @pytest.mark.asyncio async def test_read_from(filesystem): sema, fs, base = filesystem diff --git a/memory/memory/client.py b/memory/memory/client.py index f2b4f2201ff..2a110e19db3 100644 --- a/memory/memory/client.py +++ b/memory/memory/client.py @@ -1,9 +1,9 @@ import aiohttp -import concurrent + +from hailtop.aiogoogle.client.storage_client import GoogleStorageAsyncFS from hailtop.auth import service_auth_headers from hailtop.config import get_deploy_config -from hailtop.google_storage import GCS from hailtop.httpx import client_session from hailtop.utils import request_retry_transient_errors @@ -18,9 +18,11 @@ def __init__(self, gcs_project=None, fs=None, deploy_config=None, session=None, self.url = self._deploy_config.base_url('memory') self.objects_url = f'{self.url}/api/v1alpha/objects' 
self._session = session + if fs is None: - fs = GCS(blocking_pool=concurrent.futures.ThreadPoolExecutor(), project=gcs_project) + fs = GoogleStorageAsyncFS(project=gcs_project) self._fs = fs + self._headers = {} if headers: self._headers.update(headers) @@ -49,7 +51,7 @@ async def read_file(self, filename): data = await self._get_file_if_exists(filename) if data is not None: return data - return await self._fs.read_binary_gs_file(filename) + return await self._fs.read(filename) async def write_file(self, filename, data): params = {'q': filename} @@ -61,3 +63,5 @@ async def write_file(self, filename, data): async def close(self): await self._session.close() self._session = None + await self._fs.close() + self._fs = None diff --git a/memory/memory/memory.py b/memory/memory/memory.py index 123d0f7f4f2..d09fa856bce 100644 --- a/memory/memory/memory.py +++ b/memory/memory/memory.py @@ -1,9 +1,8 @@ import aioredis import asyncio import base64 -import concurrent -import json import logging +import json import os import uvloop import signal @@ -12,8 +11,10 @@ from prometheus_async.aio.web import server_stats # type: ignore from typing import Set +from hailtop.aiotools import AsyncFS +from hailtop.aiogoogle.client.storage_client import GoogleStorageAsyncFS +from hailtop.aiogoogle.auth.credentials import Credentials from hailtop.config import get_deploy_config -from hailtop.google_storage import GCS from hailtop.hail_logging import AccessLogger from hailtop.tls import internal_server_ssl_context from hailtop.utils import AsyncWorkerPool, retry_transient_errors, dump_all_stacktraces @@ -59,7 +60,7 @@ async def write_object(request, userdata): files = request.app['files_in_progress'] files.add(file_key) - await persist_in_gcs(userinfo['fs'], files, file_key, filepath, data) + await persist(userinfo['fs'], files, file_key, filepath, data) await cache_file(request.app['redis_pool'], files, file_key, filepath, data) return web.Response(status=200) @@ -72,8 +73,9 @@ async def get_or_add_user(app, userdata): gsa_key_secret = await retry_transient_errors( k8s_client.read_namespaced_secret, userdata['gsa_key_secret_name'], DEFAULT_NAMESPACE, _request_timeout=5.0 ) - gsa_key = base64.b64decode(gsa_key_secret.data['key.json']).decode() - users[username] = {'fs': GCS(blocking_pool=app['thread_pool'], key=json.loads(gsa_key))} + gsa_key = json.loads(base64.b64decode(gsa_key_secret.data['key.json']).decode()) + credentials = Credentials.from_credentials_data(gsa_key) + users[username] = {'fs': GoogleStorageAsyncFS(credentials=credentials)} return users[username] @@ -81,7 +83,7 @@ def make_redis_key(username, filepath): return f'{ username }_{ filepath }' -async def get_file_or_none(app, username, fs, filepath): +async def get_file_or_none(app, username, fs: AsyncFS, filepath): file_key = make_redis_key(username, filepath) redis_pool: aioredis.ConnectionsPool = app['redis_pool'] @@ -101,10 +103,10 @@ async def get_file_or_none(app, username, fs, filepath): return None -async def load_file(redis, files, file_key, fs, filepath): +async def load_file(redis, files, file_key, fs: AsyncFS, filepath): try: log.info(f"memory: {file_key}: reading.") - data = await fs.read_binary_gs_file(filepath) + data = await fs.read(filepath) log.info(f"memory: {file_key}: read {filepath}") except Exception as e: files.remove(file_key) @@ -113,17 +115,17 @@ async def load_file(redis, files, file_key, fs, filepath): await cache_file(redis, files, file_key, filepath, data) -async def persist_in_gcs(fs: GCS, files: Set[str], file_key: str, 
filepath: str, data: str): +async def persist(fs: AsyncFS, files: Set[str], file_key: str, filepath: str, data: bytes): try: log.info(f"memory: {file_key}: persisting.") - await fs.write_gs_file_from_string(filepath, data) + await fs.write(filepath, data) log.info(f"memory: {file_key}: persisted {filepath}") except Exception as e: files.remove(file_key) raise e -async def cache_file(redis: aioredis.ConnectionsPool, files: Set[str], file_key: str, filepath: str, data: str): +async def cache_file(redis: aioredis.ConnectionsPool, files: Set[str], file_key: str, filepath: str, data: bytes): try: await redis.execute('HMSET', file_key, 'body', data) log.info(f"memory: {file_key}: stored {filepath}") @@ -132,7 +134,6 @@ async def cache_file(redis: aioredis.ConnectionsPool, files: Set[str], file_key: async def on_startup(app): - app['thread_pool'] = concurrent.futures.ThreadPoolExecutor() app['worker_pool'] = AsyncWorkerPool(parallelism=100, queue_size=10) app['files_in_progress'] = set() app['users'] = {} @@ -144,16 +145,20 @@ async def on_startup(app): async def on_cleanup(app): try: - app['thread_pool'].shutdown() + app['worker_pool'].shutdown() finally: try: - app['worker_pool'].shutdown() + app['redis_pool'].close() finally: try: - app['redis_pool'].close() - finally: del app['k8s_client'] await asyncio.gather(*(t for t in asyncio.all_tasks() if t is not asyncio.current_task())) + finally: + for items in app['users'].values(): + try: + await items['fs'].close() + except: + pass def run(): diff --git a/memory/test/test_memory.py b/memory/test/test_memory.py index 2c5efe613de..457d869ee1f 100644 --- a/memory/test/test_memory.py +++ b/memory/test/test_memory.py @@ -1,11 +1,11 @@ -import concurrent -import os +import asyncio import unittest +import os import uuid from memory.client import MemoryClient +from hailtop.aiogoogle.client.storage_client import GoogleStorageAsyncFS from hailtop.config import get_user_config -from hailtop.google_storage import GCS from hailtop.utils import async_to_blocking @@ -33,17 +33,20 @@ def setUp(self): token = uuid.uuid4() self.test_path = f'gs://{bucket_name}/memory-tests/{token}' - self.fs = GCS(concurrent.futures.ThreadPoolExecutor(), project=os.environ['PROJECT']) + self.fs = GoogleStorageAsyncFS(project=os.environ['PROJECT']) self.client = BlockingMemoryClient(fs=self.fs) self.temp_files = set() def tearDown(self): - async_to_blocking(self.fs.delete_gs_files(self.test_path)) + async_to_blocking(self.fs.rmtree(None, self.test_path)) self.client.close() - def add_temp_file_from_string(self, name: str, str_value: str): + async def add_temp_file_from_string(self, name: str, str_value: bytes): handle = f'{self.test_path}/{name}' - self.fs._write_gs_file_from_string(handle, str_value) + + async with await self.fs.create(handle) as f: + await f.write(str_value) + return handle def test_non_existent(self): @@ -51,10 +54,14 @@ def test_non_existent(self): self.assertIsNone(self.client._get_file_if_exists(f'{self.test_path}/nonexistent')) def test_small_write_around(self): + async def read(url): + async with await self.fs.open(url) as f: + return await f.read() + cases = [('empty_file', b''), ('null', b'\0'), ('small', b'hello world')] for file, data in cases: - handle = self.add_temp_file_from_string(file, data) - expected = self.fs._read_binary_gs_file(handle) + handle = async_to_blocking(self.add_temp_file_from_string(file, data)) + expected = async_to_blocking(read(handle)) self.assertEqual(expected, data) i = 0 cached = self.client._get_file_if_exists(handle)
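
Reviewer note (not part of the patch): below is a minimal usage sketch of the new
FileStore wiring introduced above, mirroring on_startup() and on_cleanup() in
front_end.py and driver/main.py. It assumes the module is importable as
batch.file_store, that '/gsa-key/key.json' is a mounted service-account key, and
that the bucket name, instance id, and job/attempt ids are placeholders.

    import asyncio

    import hailtop.aiogoogle as aiogoogle
    from batch.file_store import FileStore  # assumed import path for batch/batch/file_store.py

    async def example():
        # Build the async GCS filesystem from a service-account key, as on_startup() does.
        credentials = aiogoogle.auth.credentials.Credentials.from_file('/gsa-key/key.json')
        fs = aiogoogle.GoogleStorageAsyncFS(credentials=credentials)
        file_store = FileStore(fs, 'example-batch-logs-bucket', 'example-instance-id')  # placeholders

        try:
            # Status (and log) payloads are str; FileStore encodes them to UTF-8 bytes
            # on write and decodes them back to str on read.
            await file_store.write_status_file(1, 1, 'attempt-0', '{"state": "Success"}')
            print(await file_store.read_status_file(1, 1, 'attempt-0'))
        finally:
            # Closing the FileStore closes the underlying AsyncFS, as on_cleanup() now does.
            await file_store.close()

    asyncio.run(example())

Unlike the old LogStore, all I/O now flows through a single AsyncFS handle that must
be closed explicitly, which is why on_cleanup() in both the driver and the front end
gained a file_store.close() step.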
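
A second sketch covers the missing-object handling this patch standardizes: the GCS,
Azure, and S3 backends now map their backend-specific not-found errors to
FileNotFoundError, and AsyncFS gains an exists() helper built on statfile().
read_if_present() is only an illustrative helper, the URL is a placeholder, and
constructing GoogleStorageAsyncFS() with no arguments is assumed to fall back to
application default credentials.

    import asyncio

    import hailtop.aiogoogle as aiogoogle

    async def read_if_present(fs, url):
        # Probe first with the new AsyncFS.exists() ...
        if not await fs.exists(url):
            return None
        # ... then read; a concurrent delete would still surface as the
        # normalized FileNotFoundError, so catch that as well.
        try:
            return (await fs.read(url)).decode('utf-8')
        except FileNotFoundError:
            return None

    async def main():
        fs = aiogoogle.GoogleStorageAsyncFS()  # assumed: application default credentials
        try:
            print(await read_if_present(fs, 'gs://example-bucket/maybe-missing.json'))
        finally:
            await fs.close()

    asyncio.run(main())

This is the same pattern benchmark.py adopts in get_benchmarks(), where a
google.api_core.exceptions.NotFound catch became a FileNotFoundError catch.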