Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[batch] Fix list batches query and test #13237

Merged
merged 3 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import traceback
from functools import wraps
from numbers import Number
from typing import Dict, Optional, Tuple, Union
from typing import Awaitable, Callable, Dict, Optional, Tuple, TypeVar, Union

import aiohttp
import aiohttp_session
Expand All @@ -25,6 +25,7 @@
from aiohttp import web
from plotly.subplots import make_subplots
from prometheus_async.aio.web import server_stats # type: ignore
from typing_extensions import ParamSpec

from gear import (
AuthClient,
Expand Down Expand Up @@ -113,6 +114,10 @@
BATCH_JOB_DEFAULT_PREEMPTIBLE = True


T = TypeVar('T')
P = ParamSpec('P')


def rest_authenticated_developers_or_auth_only(fun):
@auth.rest_authenticated_users_only
@wraps(fun)
Expand Down Expand Up @@ -189,6 +194,12 @@ async def wrapped(request, userdata, *args, **kwargs):
return wrap


def cast_query_param_to_int(param: Optional[str]) -> Optional[int]:
    """Convert an optional query-string value to an integer.

    Args:
        param: Raw value of the query parameter, or ``None`` when the
            parameter was not supplied.

    Returns:
        ``None`` when ``param`` is ``None`` or blank (empty or only
        whitespace, e.g. ``?last_job_id=``); otherwise the parsed integer.

    Raises:
        ValueError: If ``param`` is non-blank but is not a valid integer
            literal.
    """
    # A blank value carries no number to parse; treat it the same as an
    # absent parameter instead of letting int('') raise ValueError.
    if param is None or not param.strip():
        return None
    return int(param)


@routes.get('/healthcheck')
async def get_healthcheck(request):  # pylint: disable=W0613
    """Handle ``GET /healthcheck`` by returning a default-constructed ``web.Response``.

    ``request`` is not used in the body, hence the pylint suppression on the
    signature.
    """
    return web.Response()
Expand All @@ -210,7 +221,9 @@ async def rest_get_supported_regions(request, userdata): # pylint: disable=unus
return json_response(list(request.app['regions'].keys()))


async def _handle_ui_error(session, f, *args, **kwargs):
async def _handle_ui_error(
session: aiohttp_session.Session, f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs
) -> T:
try:
return await f(*args, **kwargs)
except KeyError as e:
Expand All @@ -227,17 +240,17 @@ async def _handle_ui_error(session, f, *args, **kwargs):
raise


async def _handle_api_error(f, *args, **kwargs):
async def _handle_api_error(f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> Optional[T]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a separate PR, but these new types and all those assert ... is not None reveal another issue.

This function conflates two distinct pieces of functionality:

  1. Make functions returning no values idempotent
  2. Convert BatchUserError into an HTTPResponse.

I'm a bit unsure whether `BatchOperationAlreadyCompletedError` is the right thing to use. Consider this sequence: I am at the delete-billing-project form, I press submit, I lose my WiFi, and then I refresh the page (clicking OK when the browser asks "do you want to resubmit this form?"). I then get "this billing project is already deleted". I guess that tells me that the first request actually did make it through, but as a user, all that matters is that the BP doesn't exist anymore, right? If I mistype the name I'd get a 404, so I can't accidentally think I deleted a different BP.

Regardless, I think we should have two functions: `_handle_api_error`, which lacks the first `except` clause, and `_handle_idempotence_and_api_error`, which calls `_handle_api_error` wrapped in a try-except.

try:
return await f(*args, **kwargs)
except BatchOperationAlreadyCompletedError as e:
log.info(e.message)
return
return None
except BatchUserError as e:
raise e.http_response()


async def _query_batch_jobs(request, batch_id: int, version: int, q: str, last_job_id: Optional[int]):
async def _query_batch_jobs(request: web.Request, batch_id: int, version: int, q: str, last_job_id: Optional[int]):
db: Database = request.app['db']
if version == 1:
sql, sql_args = parse_batch_jobs_query_v1(batch_id, q, last_job_id)
Expand Down Expand Up @@ -279,24 +292,22 @@ async def _get_jobs(request, batch_id: int, version: int, q: str, last_job_id: O
@routes.get('/api/v1alpha/batches/{batch_id}/jobs')
@rest_billing_project_users_only
@add_metadata_to_request
async def get_jobs_v1(request, userdata, batch_id): # pylint: disable=unused-argument
async def get_jobs_v1(request: web.Request, userdata: dict, batch_id: int): # pylint: disable=unused-argument
q = request.query.get('q', '')
last_job_id = request.query.get('last_job_id')
if last_job_id is not None:
last_job_id = int(last_job_id)
last_job_id = cast_query_param_to_int(request.query.get('last_job_id'))
resp = await _handle_api_error(_get_jobs, request, batch_id, 1, q, last_job_id)
assert resp is not None
return json_response(resp)


@routes.get('/api/v2alpha/batches/{batch_id}/jobs')
@rest_billing_project_users_only
@add_metadata_to_request
async def get_jobs_v2(request, userdata, batch_id): # pylint: disable=unused-argument
async def get_jobs_v2(request: web.Request, userdata: dict, batch_id: int): # pylint: disable=unused-argument
q = request.query.get('q', '')
last_job_id = request.query.get('last_job_id')
if last_job_id is not None:
last_job_id = int(last_job_id)
last_job_id = cast_query_param_to_int(request.query.get('last_job_id'))
resp = await _handle_api_error(_get_jobs, request, batch_id, 2, q, last_job_id)
assert resp is not None
return json_response(resp)


Expand Down Expand Up @@ -634,8 +645,10 @@ async def _query_batches(request, user: str, q: str, version: int, last_batch_id
async def get_batches_v1(request, userdata): # pylint: disable=unused-argument
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should fix the types so that we cannot make this mistake in the future. There are a couple of things we need to do:

  1. Add type annotations here: `request: web.Request` (`userdata` is a little harder to type, but maybe just `Dict[str, str]` for now?)
  2. Fix _handle_api_error to properly use types. It needs to use ParamSpec (see also the ParamSpec docs). This seems to type correctly and also cause mypy errors where it should:
T = TypeVar('T')
P = ParamSpec('P')

async def _handle_api_error(f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> Optional[T]:

user = userdata['username']
q = request.query.get('q', f'user:{user}')
last_batch_id = request.query.get('last_batch_id')
batches, last_batch_id = await _handle_api_error(_query_batches, request, user, q, 1, last_batch_id)
last_batch_id = cast_query_param_to_int(request.query.get('last_batch_id'))
result = await _handle_api_error(_query_batches, request, user, q, 1, last_batch_id)
assert result is not None
batches, last_batch_id = result
body = {'batches': batches}
if last_batch_id is not None:
body['last_batch_id'] = last_batch_id
Expand All @@ -648,8 +661,10 @@ async def get_batches_v1(request, userdata): # pylint: disable=unused-argument
async def get_batches_v2(request, userdata): # pylint: disable=unused-argument
user = userdata['username']
q = request.query.get('q', f'user = {user}')
last_batch_id = request.query.get('last_batch_id')
batches, last_batch_id = await _handle_api_error(_query_batches, request, user, q, 2, last_batch_id)
last_batch_id = cast_query_param_to_int(request.query.get('last_batch_id'))
result = await _handle_api_error(_query_batches, request, user, q, 2, last_batch_id)
assert result is not None
batches, last_batch_id = result
body = {'batches': batches}
if last_batch_id is not None:
body['last_batch_id'] = last_batch_id
Expand Down Expand Up @@ -1631,9 +1646,7 @@ async def ui_batch(request, userdata, batch_id):
batch = await _get_batch(app, batch_id)

q = request.query.get('q', '')
last_job_id = request.query.get('last_job_id')
if last_job_id is not None:
last_job_id = int(last_job_id)
last_job_id = cast_query_param_to_int(request.query.get('last_job_id'))

try:
jobs, last_job_id = await _query_batch_jobs(request, batch_id, CURRENT_QUERY_VERSION, q, last_job_id)
Expand Down Expand Up @@ -1697,17 +1710,15 @@ async def ui_delete_batch(request, userdata, batch_id): # pylint: disable=unuse
@routes.get('/batches', name='batches')
@auth.web_authenticated_users_only()
@catch_ui_error_in_dev
async def ui_batches(request, userdata):
async def ui_batches(request: web.Request, userdata: dict):
session = await aiohttp_session.get_session(request)
user = userdata['username']
q = request.query.get('q', f'user:{user}')
last_batch_id = request.query.get('last_batch_id')
if last_batch_id is not None:
last_batch_id = int(last_batch_id)
last_batch_id = cast_query_param_to_int(request.query.get('last_batch_id'))
try:
batches, last_batch_id = await _handle_ui_error(
session, _query_batches, request, user, q, CURRENT_QUERY_VERSION, last_batch_id
)
result = await _handle_ui_error(session, _query_batches, request, user, q, CURRENT_QUERY_VERSION, last_batch_id)
assert result is not None
batches, last_batch_id = result
except asyncio.CancelledError:
raise
except Exception:
Expand Down
38 changes: 32 additions & 6 deletions batch/test/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,24 +481,50 @@ def assert_batch_ids(expected: Set[int], q=None):

assert_batch_ids(
{b1.id, b2.id},
'''
f'''
start_time >= 2023-02-24T17:15:25Z
end_time < 3000-02-24T17:15:25Z
tag = {tag}
''',
)

assert_batch_ids(
set(),
'''
f'''
start_time >= 2023-02-24T17:15:25Z
end_time == 2023-02-24T17:15:25Z
tag = {tag}
''',
)

assert_batch_ids(set(), 'duration > 50000')
assert_batch_ids(set(), 'cost > 1000')
assert_batch_ids({b1.id}, f'batch_id = {b1.id}')
assert_batch_ids({b1.id}, f'batch_id == {b1.id}')
assert_batch_ids(
set(),
f'''
duration > 50000
tag = {tag}
''',
)
assert_batch_ids(
set(),
f'''
cost > 1000
tag = {tag}
''',
)
assert_batch_ids(
{b1.id},
f'''
batch_id = {b1.id}
tag = {tag}
''',
)
assert_batch_ids(
{b1.id},
f'''
batch_id == {b1.id}
tag = {tag}
''',
)

with pytest.raises(httpx.ClientResponseError, match='could not parse term'):
assert_batch_ids(batch_id_test_universe, 'batch_id >= 1 abcde')
Expand Down