diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py index dc943128b44..91f05e36f40 100644 --- a/batch/batch/front_end/front_end.py +++ b/batch/batch/front_end/front_end.py @@ -11,7 +11,7 @@ import traceback from functools import wraps from numbers import Number -from typing import Dict, Optional, Tuple, Union +from typing import Awaitable, Callable, Dict, Optional, Tuple, TypeVar, Union import aiohttp import aiohttp_session @@ -25,6 +25,7 @@ from aiohttp import web from plotly.subplots import make_subplots from prometheus_async.aio.web import server_stats # type: ignore +from typing_extensions import ParamSpec from gear import ( AuthClient, @@ -113,6 +114,10 @@ BATCH_JOB_DEFAULT_PREEMPTIBLE = True +T = TypeVar('T') +P = ParamSpec('P') + + def rest_authenticated_developers_or_auth_only(fun): @auth.rest_authenticated_users_only @wraps(fun) @@ -189,6 +194,12 @@ async def wrapped(request, userdata, *args, **kwargs): return wrap +def cast_query_param_to_int(param: Optional[str]) -> Optional[int]: + if param is not None: + return int(param) + return None + + @routes.get('/healthcheck') async def get_healthcheck(request): # pylint: disable=W0613 return web.Response() @@ -210,7 +221,9 @@ async def rest_get_supported_regions(request, userdata): # pylint: disable=unus return json_response(list(request.app['regions'].keys())) -async def _handle_ui_error(session, f, *args, **kwargs): +async def _handle_ui_error( + session: aiohttp_session.Session, f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs +) -> T: try: return await f(*args, **kwargs) except KeyError as e: @@ -227,17 +240,17 @@ async def _handle_ui_error(session, f, *args, **kwargs): raise -async def _handle_api_error(f, *args, **kwargs): +async def _handle_api_error(f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> Optional[T]: try: return await f(*args, **kwargs) except BatchOperationAlreadyCompletedError as e: log.info(e.message) - return + return None except BatchUserError as e: raise e.http_response() -async def _query_batch_jobs(request, batch_id: int, version: int, q: str, last_job_id: Optional[int]): +async def _query_batch_jobs(request: web.Request, batch_id: int, version: int, q: str, last_job_id: Optional[int]): db: Database = request.app['db'] if version == 1: sql, sql_args = parse_batch_jobs_query_v1(batch_id, q, last_job_id) @@ -279,24 +292,22 @@ async def _get_jobs(request, batch_id: int, version: int, q: str, last_job_id: O @routes.get('/api/v1alpha/batches/{batch_id}/jobs') @rest_billing_project_users_only @add_metadata_to_request -async def get_jobs_v1(request, userdata, batch_id): # pylint: disable=unused-argument +async def get_jobs_v1(request: web.Request, userdata: dict, batch_id: int): # pylint: disable=unused-argument q = request.query.get('q', '') - last_job_id = request.query.get('last_job_id') - if last_job_id is not None: - last_job_id = int(last_job_id) + last_job_id = cast_query_param_to_int(request.query.get('last_job_id')) resp = await _handle_api_error(_get_jobs, request, batch_id, 1, q, last_job_id) + assert resp is not None return json_response(resp) @routes.get('/api/v2alpha/batches/{batch_id}/jobs') @rest_billing_project_users_only @add_metadata_to_request -async def get_jobs_v2(request, userdata, batch_id): # pylint: disable=unused-argument +async def get_jobs_v2(request: web.Request, userdata: dict, batch_id: int): # pylint: disable=unused-argument q = request.query.get('q', '') - last_job_id = request.query.get('last_job_id') - if last_job_id is not None: - last_job_id = int(last_job_id) + last_job_id = cast_query_param_to_int(request.query.get('last_job_id')) resp = await _handle_api_error(_get_jobs, request, batch_id, 2, q, last_job_id) + assert resp is not None return json_response(resp) @@ -634,8 +645,10 @@ async def _query_batches(request, user: str, q: str, version: int, last_batch_id async def get_batches_v1(request, userdata): # pylint: disable=unused-argument user = userdata['username'] q = request.query.get('q', f'user:{user}') - last_batch_id = request.query.get('last_batch_id') - batches, last_batch_id = await _handle_api_error(_query_batches, request, user, q, 1, last_batch_id) + last_batch_id = cast_query_param_to_int(request.query.get('last_batch_id')) + result = await _handle_api_error(_query_batches, request, user, q, 1, last_batch_id) + assert result is not None + batches, last_batch_id = result body = {'batches': batches} if last_batch_id is not None: body['last_batch_id'] = last_batch_id @@ -648,8 +661,10 @@ async def get_batches_v1(request, userdata): # pylint: disable=unused-argument async def get_batches_v2(request, userdata): # pylint: disable=unused-argument user = userdata['username'] q = request.query.get('q', f'user = {user}') - last_batch_id = request.query.get('last_batch_id') - batches, last_batch_id = await _handle_api_error(_query_batches, request, user, q, 2, last_batch_id) + last_batch_id = cast_query_param_to_int(request.query.get('last_batch_id')) + result = await _handle_api_error(_query_batches, request, user, q, 2, last_batch_id) + assert result is not None + batches, last_batch_id = result body = {'batches': batches} if last_batch_id is not None: body['last_batch_id'] = last_batch_id @@ -1631,9 +1646,7 @@ async def ui_batch(request, userdata, batch_id): batch = await _get_batch(app, batch_id) q = request.query.get('q', '') - last_job_id = request.query.get('last_job_id') - if last_job_id is not None: - last_job_id = int(last_job_id) + last_job_id = cast_query_param_to_int(request.query.get('last_job_id')) try: jobs, last_job_id = await _query_batch_jobs(request, batch_id, CURRENT_QUERY_VERSION, q, last_job_id) @@ -1697,17 +1710,15 @@ async def ui_delete_batch(request, userdata, batch_id): # pylint: disable=unuse @routes.get('/batches', name='batches') @auth.web_authenticated_users_only() @catch_ui_error_in_dev -async def ui_batches(request, userdata): +async def ui_batches(request: web.Request, userdata: dict): session = await aiohttp_session.get_session(request) user = userdata['username'] q = request.query.get('q', f'user:{user}') - last_batch_id = request.query.get('last_batch_id') - if last_batch_id is not None: - last_batch_id = int(last_batch_id) + last_batch_id = cast_query_param_to_int(request.query.get('last_batch_id')) try: - batches, last_batch_id = await _handle_ui_error( - session, _query_batches, request, user, q, CURRENT_QUERY_VERSION, last_batch_id - ) + result = await _handle_ui_error(session, _query_batches, request, user, q, CURRENT_QUERY_VERSION, last_batch_id) + assert result is not None + batches, last_batch_id = result except asyncio.CancelledError: raise except Exception: diff --git a/batch/test/test_batch.py b/batch/test/test_batch.py index c32362ee76a..54fe6e97ad9 100644 --- a/batch/test/test_batch.py +++ b/batch/test/test_batch.py @@ -481,24 +481,50 @@ def assert_batch_ids(expected: Set[int], q=None): assert_batch_ids( {b1.id, b2.id}, - ''' + f''' start_time >= 2023-02-24T17:15:25Z end_time < 3000-02-24T17:15:25Z +tag = {tag} ''', ) assert_batch_ids( set(), - ''' + f''' start_time >= 2023-02-24T17:15:25Z end_time == 2023-02-24T17:15:25Z +tag = {tag} ''', ) - assert_batch_ids(set(), 'duration > 50000') - assert_batch_ids(set(), 'cost > 1000') - assert_batch_ids({b1.id}, f'batch_id = {b1.id}') - assert_batch_ids({b1.id}, f'batch_id == {b1.id}') + assert_batch_ids( + set(), + f''' +duration > 50000 +tag = {tag} +''', + ) + assert_batch_ids( + set(), + f''' +cost > 1000 +tag = {tag} +''', + ) + assert_batch_ids( + {b1.id}, + f''' +batch_id = {b1.id} +tag = {tag} +''', + ) + assert_batch_ids( + {b1.id}, + f''' +batch_id == {b1.id} +tag = {tag} +''', + ) with pytest.raises(httpx.ClientResponseError, match='could not parse term'): assert_batch_ids(batch_id_test_universe, 'batch_id >= 1 abcde')