Skip to content

Commit

Permalink
[batch] Check whether user exists before adding to billing project (#…
Browse files Browse the repository at this point in the history
…13945)

Fixes #13858.

<img width="589" alt="Screenshot 2023-10-30 at 12 23 06 PM"
src="https://github.com/hail-is/hail/assets/1693348/5ad26813-5534-488c-8029-f2607ba72033">
  • Loading branch information
jigold authored Oct 31, 2023
1 parent 701bce0 commit 40a3467
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 7 deletions.
5 changes: 5 additions & 0 deletions batch/batch/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def __init__(self, batch_id):
super().__init__(f'Batch {batch_id} does not exist.', 'error')


class NonExistentUserError(BatchUserError):
def __init__(self, user):
super().__init__(f'User {user} does not exist.', 'error')


class OpenBatchError(BatchUserError):
def __init__(self, batch_id):
super().__init__(f'Batch {batch_id} is open.', 'error')
Expand Down
21 changes: 17 additions & 4 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
setup_aiohttp_session,
transaction,
)
from gear.auth import get_session_id, impersonate_user
from gear.clients import get_cloud_async_fs
from gear.database import CallError
from gear.profiling import install_profiler_if_requested
Expand Down Expand Up @@ -81,6 +82,7 @@
ClosedBillingProjectError,
InvalidBillingLimitError,
NonExistentBillingProjectError,
NonExistentUserError,
QueryError,
)
from ..file_store import FileStore
Expand Down Expand Up @@ -2556,7 +2558,17 @@ async def api_get_billing_projects_remove_user(request: web.Request) -> web.Resp
return json_response({'billing_project': billing_project, 'user': user})


async def _add_user_to_billing_project(db, billing_project, user):
async def _add_user_to_billing_project(request: web.Request, db: Database, billing_project: str, user: str):
try:
session_id = await get_session_id(request)
assert session_id is not None
url = deploy_config.url('auth', f'/api/v1alpha/users/{user}')
await impersonate_user(session_id, request.app['client_session'], url)
except aiohttp.ClientResponseError as e:
if e.status == 404:
raise NonExistentUserError(user) from e
raise

@transaction(db)
async def insert(tx):
# we want to be case-insensitive here to avoid duplicates with existing records
Expand Down Expand Up @@ -2588,6 +2600,7 @@ async def insert(tx):
raise BatchOperationAlreadyCompletedError(
f'User {user} is already member of billing project {billing_project}.', 'info'
)

await tx.execute_insertone(
'''
INSERT INTO billing_project_users(billing_project, user, user_cs)
Expand All @@ -2605,13 +2618,13 @@ async def insert(tx):
async def post_billing_projects_add_user(request: web.Request, _) -> NoReturn:
db: Database = request.app['db']
post = await request.post()
user = post['user']
user = str(post['user'])
billing_project = request.match_info['billing_project']

session = await aiohttp_session.get_session(request)

try:
await _handle_ui_error(session, _add_user_to_billing_project, db, billing_project, user)
await _handle_ui_error(session, _add_user_to_billing_project, request, db, billing_project, user)
set_message(session, f'Added user {user} to billing project {billing_project}.', 'info') # type: ignore
finally:
raise web.HTTPFound(deploy_config.external_url('batch', '/billing_projects')) # pylint: disable=lost-exception
Expand All @@ -2624,7 +2637,7 @@ async def api_billing_projects_add_user(request: web.Request) -> web.Response:
user = request.match_info['user']
billing_project = request.match_info['billing_project']

await _handle_api_error(_add_user_to_billing_project, db, billing_project, user)
await _handle_api_error(_add_user_to_billing_project, request, db, billing_project, user)
return json_response({'billing_project': billing_project, 'user': user})


Expand Down
11 changes: 10 additions & 1 deletion batch/test/test_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

from hailtop import httpx
from hailtop.auth import session_id_encode_to_str
from hailtop.auth import async_get_user, session_id_encode_to_str
from hailtop.batch_client.aioclient import Batch, BatchClient
from hailtop.utils import secret_alnum_string
from hailtop.utils.rich_progress_bar import BatchProgressBar
Expand Down Expand Up @@ -279,6 +279,15 @@ async def test_add_and_delete_user(dev_client: BatchClient, new_billing_project:
assert r['user'] not in bp['users']


async def test_error_adding_nonexistent_user(dev_client: BatchClient, new_billing_project: str):
with pytest.raises(httpx.ClientResponseError) as e_info:
with pytest.raises(httpx.ClientResponseError) as e_user:
await async_get_user('foobar')
assert e_user.value.status == 401
await dev_client.add_user('foobar', new_billing_project)
assert e_info.value.status == 403


async def test_edit_billing_limit_dev(dev_client: BatchClient, new_billing_project: str):
project = new_billing_project
r = await dev_client.add_user('test', project)
Expand Down
8 changes: 6 additions & 2 deletions gear/gear/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,14 @@ async def _fetch_userdata_from_auth_service(session_id_and_session: Tuple[str, h


async def impersonate_user_and_get_info(session_id: str, client_session: httpx.ClientSession):
headers = {'Authorization': f'Bearer {session_id}'}
userinfo_url = deploy_config.url('auth', '/api/v1alpha/userinfo')
return await impersonate_user(session_id, client_session, userinfo_url)


async def impersonate_user(session_id: str, client_session: httpx.ClientSession, url: str):
headers = {'Authorization': f'Bearer {session_id}'}
try:
return await retry_transient_errors(client_session.get_read_json, userinfo_url, headers=headers)
return await retry_transient_errors(client_session.get_read_json, url, headers=headers)
except aiohttp.ClientResponseError as err:
if err.status == 401:
return None
Expand Down

0 comments on commit 40a3467

Please sign in to comment.