Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sept 23 upstream 3 (1c28203) #304

Merged
merged 10 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions auth/auth/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import random
from typing import Any, Awaitable, Callable, Dict, List
from typing import Any, Awaitable, Callable, Dict, List, Optional

import aiohttp
import kubernetes_asyncio.client
Expand All @@ -16,6 +16,8 @@
from gear.cloud_config import get_gcp_config, get_global_config
from hailtop import aiotools, httpx
from hailtop import batch_client as bc
from hailtop.aiocloud.aioazure import AzureGraphClient
from hailtop.aiocloud.aiogoogle import GoogleIAmClient
from hailtop.utils import secret_alnum_string, time_msecs

log = logging.getLogger('auth.driver')
Expand Down Expand Up @@ -140,10 +142,14 @@ async def delete(self):


class GSAResource:
def __init__(self, iam_client, gsa_email=None):
def __init__(self, iam_client: GoogleIAmClient, gsa_email: Optional[str] = None):
self.iam_client = iam_client
self.gsa_email = gsa_email

async def get_unique_id(self) -> str:
service_account = await self.iam_client.get(f'/serviceAccounts/{self.gsa_email}')
return service_account['uniqueId']

async def create(self, username):
assert self.gsa_email is None

Expand All @@ -164,7 +170,7 @@ async def create(self, username):

async def _delete(self, gsa_email):
try:
await self.iam_client.delete(f'/serviceAccounts/{gsa_email}/keys')
await self.iam_client.delete(f'/serviceAccounts/{gsa_email}')
except aiohttp.ClientResponseError as e:
if e.status == 404:
pass
Expand All @@ -179,10 +185,17 @@ async def delete(self):


class AzureServicePrincipalResource:
def __init__(self, graph_client, app_obj_id=None):
def __init__(self, graph_client: AzureGraphClient, app_obj_id: Optional[str] = None):
self.graph_client = graph_client
self.app_obj_id = app_obj_id

async def get_service_principal_object_id(self) -> str:
assert self.app_obj_id
app = await self.graph_client.get(f'/applications/{self.app_obj_id}')
app_id = app['appId']
service_principal = await self.graph_client.get(f"/servicePrincipals(appId='{app_id}')")
return service_principal['id']

async def create(self, username):
assert self.app_obj_id is None

Expand Down Expand Up @@ -496,6 +509,28 @@ async def delete_user(app, user):
)


async def resolve_identity_uid(app, hail_identity):
id_client = app['identity_client']
db = app['db']

if CLOUD == 'gcp':
gsa = GSAResource(id_client, hail_identity)
hail_identity_uid = await gsa.get_unique_id()
else:
assert CLOUD == 'azure'
sp = AzureServicePrincipalResource(id_client, hail_identity)
hail_identity_uid = await sp.get_service_principal_object_id()

await db.just_execute(
'''
UPDATE users
SET hail_identity_uid = %s
WHERE hail_identity = %s
''',
(hail_identity_uid, hail_identity),
)


async def update_users(app):
log.info('in update_users')

Expand All @@ -511,6 +546,12 @@ async def update_users(app):
for user in deleting_users:
await delete_user(app, user)

users_without_hail_identity_uid = [
x async for x in db.execute_and_fetchall('SELECT * FROM users WHERE hail_identity_uid IS NULL')
]
for user in users_without_hail_identity_uid:
await resolve_identity_uid(app, user['hail_identity'])

return True


Expand Down
1 change: 1 addition & 0 deletions auth/sql/add-hail-identity-uid.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE `users` ADD COLUMN `hail_identity_uid` VARCHAR(300) DEFAULT NULL;
1 change: 1 addition & 0 deletions auth/sql/estimated-current.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CREATE TABLE `users` (
`tokens_secret_name` varchar(255) DEFAULT NULL,
-- identity
`hail_identity` varchar(255) DEFAULT NULL,
`hail_identity_uid` VARCHAR(255) DEFAULT NULL,
`hail_credentials_secret_name` varchar(255) DEFAULT NULL,
-- namespace, for developers
`namespace_name` varchar(255) DEFAULT NULL,
Expand Down
9 changes: 7 additions & 2 deletions batch/batch/driver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@
from ..file_store import FileStore
from ..globals import HTTP_CLIENT_MAX_SIZE
from ..inst_coll_config import InstanceCollectionConfigs, PoolConfig
from ..utils import authorization_token, batch_only, json_to_value, query_billing_projects
from ..utils import (
authorization_token,
batch_only,
json_to_value,
query_billing_projects_with_cost,
)
from .canceller import Canceller
from .driver import CloudDriver
from .instance_collection import InstanceCollectionManager, JobPrivateInstanceManager, Pool
Expand Down Expand Up @@ -1185,7 +1190,7 @@ async def _cancel_batch(app, batch_id):
async def monitor_billing_limits(app):
db: Database = app['db']

records = await query_billing_projects(db)
records = await query_billing_projects_with_cost(db)
for record in records:
limit = record['limit']
accrued_cost = record['accrued_cost']
Expand Down
24 changes: 18 additions & 6 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@
from ..inst_coll_config import InstanceCollectionConfigs
from ..resource_usage import ResourceUsageMonitor
from ..spec_writer import SpecWriter
from ..utils import query_billing_projects, regions_to_bits_rep, unavailable_if_frozen
from ..utils import (
query_billing_projects_with_cost,
query_billing_projects_without_cost,
regions_to_bits_rep,
unavailable_if_frozen,
)
from .query import CURRENT_QUERY_VERSION, build_batch_jobs_query
from .validate import ValidationError, validate_and_clean_jobs, validate_batch, validate_batch_update

Expand Down Expand Up @@ -2416,9 +2421,16 @@ async def ui_get_billing_limits(request, userdata):
else:
user = None

billing_projects = await query_billing_projects(db, user=user)
billing_projects = await query_billing_projects_with_cost(db, user=user)

open_billing_projects = [bp for bp in billing_projects if bp['status'] == 'open']
closed_billing_projects = [bp for bp in billing_projects if bp['status'] == 'closed']

page_context = {'billing_projects': billing_projects, 'is_developer': userdata['is_developer']}
page_context = {
'open_billing_projects': open_billing_projects,
'closed_billing_projects': closed_billing_projects,
'is_developer': userdata['is_developer'],
}
return await render_template('batch', request, userdata, 'billing_limits.html', page_context)


Expand Down Expand Up @@ -2615,7 +2627,7 @@ async def ui_get_billing(request, userdata):
@catch_ui_error_in_dev
async def ui_get_billing_projects(request, userdata):
db: Database = request.app['db']
billing_projects = await query_billing_projects(db)
billing_projects = await query_billing_projects_without_cost(db)
page_context = {
'billing_projects': [{**p, 'size': len(p['users'])} for p in billing_projects if p['status'] == 'open'],
'closed_projects': [p for p in billing_projects if p['status'] == 'closed'],
Expand All @@ -2633,7 +2645,7 @@ async def get_billing_projects(request, userdata):
else:
user = None

billing_projects = await query_billing_projects(db, user=user)
billing_projects = await query_billing_projects_with_cost(db, user=user)
return json_response(billing_projects)


Expand All @@ -2648,7 +2660,7 @@ async def get_billing_project(request, userdata):
else:
user = None

billing_projects = await query_billing_projects(db, user=user, billing_project=billing_project)
billing_projects = await query_billing_projects_with_cost(db, user=user, billing_project=billing_project)

if not billing_projects:
raise web.HTTPForbidden(reason=f'Unknown Hail Batch billing project {billing_project}.')
Expand Down
42 changes: 40 additions & 2 deletions batch/batch/front_end/templates/billing_limits.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
{% block title %}Billing Limits{% endblock %}
{% block content %}
<h1>Billing Project Limits</h1>
{% if open_billing_projects %}
<div class='flex-col' style="overflow: auto;">
<table class="data-table" id="billing_limits">
<h2>Open Projects</h2>
<table class="data-table" id="open-billing-limits">
<thead>
<tr>
<th>Billing Project</th>
Expand All @@ -12,7 +14,7 @@ <h1>Billing Project Limits</h1>
</tr>
</thead>
<tbody>
{% for row in billing_projects %}
{% for row in open_billing_projects %}
<tr>
<td>{{ row['billing_project'] }}</td>
<td>{{ row['accrued_cost'] }}</td>
Expand All @@ -34,4 +36,40 @@ <h1>Billing Project Limits</h1>
</tbody>
</table>
</div>
{% endif %}
{% if closed_billing_projects %}
<div class='flex-col' style="overflow: auto;">
<h2>Closed Projects</h2>
<table class="data-table" id="closed-billing-limits">
<thead>
<tr>
<th>Billing Project</th>
<th>Accrued Cost</th>
<th>Limit</th>
</tr>
</thead>
<tbody>
{% for row in closed_billing_projects %}
<tr>
<td>{{ row['billing_project'] }}</td>
<td>{{ row['accrued_cost'] }}</td>
{% if is_developer %}
<td>
<form action="{{ base_path }}/billing_limits/{{ row['billing_project'] }}/edit" method="POST">
<input type="hidden" name="_csrf" value="{{ csrf_token }}">
<input type="text" required name="limit" value="{{ row['limit'] }}">
<button>
Edit
</button>
</form>
</td>
{% else %}
<td>{{ row['limit'] }}</td>
{% endif %}
</tr>
{% endfor %}
</tbody>
</table>
</div>
{% endif %}
{% endblock %}
56 changes: 46 additions & 10 deletions batch/batch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,9 @@ def __repr__(self):
return f'global {self._global_counter}'


async def query_billing_projects(db, user=None, billing_project=None):
args = []

async def query_billing_projects_with_cost(db, user=None, billing_project=None):
where_conditions = ["billing_projects.`status` != 'deleted'"]
args = []

if user:
where_conditions.append("JSON_CONTAINS(users, JSON_QUOTE(%s))")
Expand Down Expand Up @@ -161,14 +160,51 @@ async def query_billing_projects(db, user=None, billing_project=None):
LOCK IN SHARE MODE;
'''

def record_to_dict(record):
if record['users'] is None:
record['users'] = []
else:
record['users'] = json.loads(record['users'])
return record
billing_projects = []
async for record in db.execute_and_fetchall(sql, tuple(args)):
record['users'] = json.loads(record['users']) if record['users'] is not None else []
billing_projects.append(record)

return billing_projects


async def query_billing_projects_without_cost(db, user=None, billing_project=None):
where_conditions = ["billing_projects.`status` != 'deleted'"]
args = []

if user:
where_conditions.append("JSON_CONTAINS(users, JSON_QUOTE(%s))")
args.append(user)

if billing_project:
where_conditions.append('billing_projects.name_cs = %s')
args.append(billing_project)

if where_conditions:
where_condition = f'WHERE {" AND ".join(where_conditions)}'
else:
where_condition = ''

sql = f'''
SELECT billing_projects.name as billing_project,
billing_projects.`status` as `status`,
users, `limit`
FROM billing_projects
LEFT JOIN LATERAL (
SELECT billing_project, JSON_ARRAYAGG(`user_cs`) as users
FROM billing_project_users
WHERE billing_project_users.billing_project = billing_projects.name
GROUP BY billing_project_users.billing_project
LOCK IN SHARE MODE
) AS t ON TRUE
{where_condition}
LOCK IN SHARE MODE;
'''

billing_projects = [record_to_dict(record) async for record in db.execute_and_fetchall(sql, tuple(args))]
billing_projects = []
async for record in db.execute_and_fetchall(sql, tuple(args)):
record['users'] = json.loads(record['users']) if record['users'] is not None else []
billing_projects.append(record)

return billing_projects

Expand Down
Loading