diff --git a/batch/batch/exceptions.py b/batch/batch/exceptions.py index 43a52451655..4ce2ff78aab 100644 --- a/batch/batch/exceptions.py +++ b/batch/batch/exceptions.py @@ -37,6 +37,11 @@ def __init__(self, batch_id): super().__init__(f'Batch {batch_id} does not exist.', 'error') +class NonExistentUserError(BatchUserError): + def __init__(self, user): + super().__init__(f'User {user} does not exist.', 'error') + + class OpenBatchError(BatchUserError): def __init__(self, batch_id): super().__init__(f'Batch {batch_id} is open.', 'error') diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py index 2f6b0a66ad9..a9740788ef4 100644 --- a/batch/batch/front_end/front_end.py +++ b/batch/batch/front_end/front_end.py @@ -41,6 +41,7 @@ setup_aiohttp_session, transaction, ) +from gear.auth import get_session_id, impersonate_user from gear.clients import get_cloud_async_fs from gear.database import CallError from gear.profiling import install_profiler_if_requested @@ -81,6 +82,7 @@ ClosedBillingProjectError, InvalidBillingLimitError, NonExistentBillingProjectError, + NonExistentUserError, QueryError, ) from ..file_store import FileStore @@ -2556,7 +2558,17 @@ async def api_get_billing_projects_remove_user(request: web.Request) -> web.Resp return json_response({'billing_project': billing_project, 'user': user}) -async def _add_user_to_billing_project(db, billing_project, user): +async def _add_user_to_billing_project(request: web.Request, db: Database, billing_project: str, user: str): + try: + session_id = await get_session_id(request) + assert session_id is not None + url = deploy_config.url('auth', f'/api/v1alpha/users/{user}') + await impersonate_user(session_id, request.app['client_session'], url) + except aiohttp.ClientResponseError as e: + if e.status == 404: + raise NonExistentUserError(user) from e + raise + @transaction(db) async def insert(tx): # we want to be case-insensitive here to avoid duplicates with existing records @@ -2588,6 +2600,7 @@ async def insert(tx): raise BatchOperationAlreadyCompletedError( f'User {user} is already member of billing project {billing_project}.', 'info' ) + await tx.execute_insertone( ''' INSERT INTO billing_project_users(billing_project, user, user_cs) @@ -2605,13 +2618,13 @@ async def insert(tx): async def post_billing_projects_add_user(request: web.Request, _) -> NoReturn: db: Database = request.app['db'] post = await request.post() - user = post['user'] + user = str(post['user']) billing_project = request.match_info['billing_project'] session = await aiohttp_session.get_session(request) try: - await _handle_ui_error(session, _add_user_to_billing_project, db, billing_project, user) + await _handle_ui_error(session, _add_user_to_billing_project, request, db, billing_project, user) set_message(session, f'Added user {user} to billing project {billing_project}.', 'info') # type: ignore finally: raise web.HTTPFound(deploy_config.external_url('batch', '/billing_projects')) # pylint: disable=lost-exception @@ -2624,7 +2637,7 @@ async def api_billing_projects_add_user(request: web.Request) -> web.Response: user = request.match_info['user'] billing_project = request.match_info['billing_project'] - await _handle_api_error(_add_user_to_billing_project, db, billing_project, user) + await _handle_api_error(_add_user_to_billing_project, request, db, billing_project, user) return json_response({'billing_project': billing_project, 'user': user}) diff --git a/batch/test/test_accounts.py b/batch/test/test_accounts.py index 941f2c64b93..d08fd5702f7 100644 --- a/batch/test/test_accounts.py +++ b/batch/test/test_accounts.py @@ -8,7 +8,7 @@ import pytest from hailtop import httpx -from hailtop.auth import session_id_encode_to_str +from hailtop.auth import async_get_user, session_id_encode_to_str from hailtop.batch_client.aioclient import Batch, BatchClient from hailtop.utils import secret_alnum_string from hailtop.utils.rich_progress_bar import BatchProgressBar @@ -279,6 +279,15 @@ async def test_add_and_delete_user(dev_client: BatchClient, new_billing_project: assert r['user'] not in bp['users'] +async def test_error_adding_nonexistent_user(dev_client: BatchClient, new_billing_project: str): + with pytest.raises(httpx.ClientResponseError) as e_info: + with pytest.raises(httpx.ClientResponseError) as e_user: + await async_get_user('foobar') + assert e_user.value.status == 401 + await dev_client.add_user('foobar', new_billing_project) + assert e_info.value.status == 403 + + async def test_edit_billing_limit_dev(dev_client: BatchClient, new_billing_project: str): project = new_billing_project r = await dev_client.add_user('test', project) diff --git a/gear/gear/auth.py b/gear/gear/auth.py index cae28c43751..0d91e51b6b0 100644 --- a/gear/gear/auth.py +++ b/gear/gear/auth.py @@ -120,10 +120,14 @@ async def _fetch_userdata_from_auth_service(session_id_and_session: Tuple[str, h async def impersonate_user_and_get_info(session_id: str, client_session: httpx.ClientSession): - headers = {'Authorization': f'Bearer {session_id}'} userinfo_url = deploy_config.url('auth', '/api/v1alpha/userinfo') + return await impersonate_user(session_id, client_session, userinfo_url) + + +async def impersonate_user(session_id: str, client_session: httpx.ClientSession, url: str): + headers = {'Authorization': f'Bearer {session_id}'} try: - return await retry_transient_errors(client_session.get_read_json, userinfo_url, headers=headers) + return await retry_transient_errors(client_session.get_read_json, url, headers=headers) except aiohttp.ClientResponseError as err: if err.status == 401: return None