From 449d5d1ccb29fe6a45bffd2435a922769655a98b Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Wed, 24 Mar 2021 16:24:39 +0900 Subject: [PATCH] fix: excessive DB connection establishment delay and optimize GQL queries * refs MagicStack/asyncpg#530: apply "jit: off" option to DB connections - It is specified in `ai.backend.manager.models.base.pgsql_connect_opts` * Reuse the same single connection for all GraphQL resolvers and mutation methods --- src/ai/backend/gateway/admin.py | 62 +- src/ai/backend/gateway/auth.py | 42 +- src/ai/backend/gateway/server.py | 8 +- src/ai/backend/manager/models/agent.py | 133 ++-- src/ai/backend/manager/models/base.py | 76 +-- src/ai/backend/manager/models/gql.py | 6 +- src/ai/backend/manager/models/kernel.py | 312 +++++----- src/ai/backend/manager/models/keypair.py | 298 +++++---- src/ai/backend/manager/models/user.py | 744 +++++++++++------------ tests/conftest.py | 26 +- 10 files changed, 850 insertions(+), 857 deletions(-) diff --git a/src/ai/backend/gateway/admin.py b/src/ai/backend/gateway/admin.py index 2e658e6c2..5ac69a601 100644 --- a/src/ai/backend/gateway/admin.py +++ b/src/ai/backend/gateway/admin.py @@ -64,36 +64,38 @@ async def handle_gql(request: web.Request, params: Any) -> web.Response: app_ctx: PrivateContext = request.app['admin.context'] manager_status = await root_ctx.shared_config.get_manager_status() known_slot_types = await root_ctx.shared_config.get_resource_slots() - gql_ctx = GraphQueryContext( - dataloader_manager=DataLoaderManager(), - local_config=root_ctx.local_config, - shared_config=root_ctx.shared_config, - etcd=root_ctx.shared_config.etcd, - user=request['user'], - access_key=request['keypair']['access_key'], - dbpool=root_ctx.dbpool, - redis_stat=root_ctx.redis_stat, - redis_image=root_ctx.redis_image, - manager_status=manager_status, - known_slot_types=known_slot_types, - background_task_manager=root_ctx.background_task_manager, - storage_manager=root_ctx.storage_manager, - registry=root_ctx.registry, - ) - result = app_ctx.gql_schema.execute( - params['query'], - app_ctx.gql_executor, - variable_values=params['variables'], - operation_name=params['operation_name'], - context_value=gql_ctx, - middleware=[ - GQLLoggingMiddleware(), - GQLMutationUnfrozenRequiredMiddleware(), - GQLMutationPrivilegeCheckMiddleware(), - ], - return_promise=True) - if inspect.isawaitable(result): - result = await result + async with root_ctx.dbpool.connect() as db_conn, db_conn.begin(): + gql_ctx = GraphQueryContext( + dataloader_manager=DataLoaderManager(), + local_config=root_ctx.local_config, + shared_config=root_ctx.shared_config, + etcd=root_ctx.shared_config.etcd, + user=request['user'], + access_key=request['keypair']['access_key'], + dbpool=root_ctx.dbpool, + db_conn=db_conn, + redis_stat=root_ctx.redis_stat, + redis_image=root_ctx.redis_image, + manager_status=manager_status, + known_slot_types=known_slot_types, + background_task_manager=root_ctx.background_task_manager, + storage_manager=root_ctx.storage_manager, + registry=root_ctx.registry, + ) + result = app_ctx.gql_schema.execute( + params['query'], + app_ctx.gql_executor, + variable_values=params['variables'], + operation_name=params['operation_name'], + context_value=gql_ctx, + middleware=[ + GQLLoggingMiddleware(), + GQLMutationUnfrozenRequiredMiddleware(), + GQLMutationPrivilegeCheckMiddleware(), + ], + return_promise=True) + if inspect.isawaitable(result): + result = await result if result.errors: errors = [] for e in result.errors: diff --git a/src/ai/backend/gateway/auth.py 
b/src/ai/backend/gateway/auth.py index 470ef8f4b..e72e28f68 100644 --- a/src/ai/backend/gateway/auth.py +++ b/src/ai/backend/gateway/auth.py @@ -429,27 +429,27 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse: .where(keypairs.c.access_key == access_key) ) await conn.execute(query) - request['is_authorized'] = True - request['keypair'] = { - col.name: row[f'keypairs_{col.name}'] - for col in keypairs.c - if col.name != 'secret_key' - } - request['keypair']['resource_policy'] = { - col.name: row[f'keypair_resource_policies_{col.name}'] - for col in keypair_resource_policies.c - } - request['user'] = { - col.name: row[f'users_{col.name}'] - for col in users.c - if col.name not in ('password', 'description', 'created_at') - } - request['user']['id'] = row['keypairs_user_id'] # legacy - # if request['role'] in ['admin', 'superadmin']: - if row['keypairs_is_admin']: - request['is_admin'] = True - if request['user']['role'] == 'superadmin': - request['is_superadmin'] = True + request['is_authorized'] = True + request['keypair'] = { + col.name: row[f'keypairs_{col.name}'] + for col in keypairs.c + if col.name != 'secret_key' + } + request['keypair']['resource_policy'] = { + col.name: row[f'keypair_resource_policies_{col.name}'] + for col in keypair_resource_policies.c + } + request['user'] = { + col.name: row[f'users_{col.name}'] + for col in users.c + if col.name not in ('password', 'description', 'created_at') + } + request['user']['id'] = row['keypairs_user_id'] # legacy + # if request['role'] in ['admin', 'superadmin']: + if row['keypairs_is_admin']: + request['is_admin'] = True + if request['user']['role'] == 'superadmin': + request['is_superadmin'] = True # No matter if authenticated or not, pass-through to the handler. # (if it's required, auth_required decorator will handle the situation.) 
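As a minimal standalone sketch (not part of the diff), this is what the two changes amount to. The DSN, the `agents` count query, and the resolver/handler names below are illustrative placeholders; only `pgsql_connect_opts`, the `create_async_engine()` keyword arguments, and the single `connect()`/`begin()` block shared by the resolvers mirror the patch:

```python
import asyncio
import functools
import json

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine

# asyncpg forwards "server_settings" to PostgreSQL when each pooled connection
# is established; disabling the JIT avoids the slow first queries reported in
# MagicStack/asyncpg#530.
pgsql_connect_opts = {'server_settings': {'jit': 'off'}}

engine = create_async_engine(
    'postgresql+asyncpg://user:secret@localhost/testing',  # placeholder DSN
    connect_args=pgsql_connect_opts,
    pool_size=8,
    max_overflow=64,
    json_serializer=functools.partial(json.dumps),
)

async def resolve_agent_count(db_conn: AsyncConnection) -> int:
    # Resolvers now run on the connection handed to them via the query
    # context instead of checking out their own connection from the pool.
    result = await db_conn.execute(sa.text('SELECT count(*) FROM agents'))
    return result.scalar()

async def handle_gql() -> int:
    # One connection and one transaction wrap the whole GraphQL execution,
    # as in the new handle_gql() block in admin.py above.
    async with engine.connect() as db_conn, db_conn.begin():
        return await resolve_agent_count(db_conn)

if __name__ == '__main__':
    print(asyncio.run(handle_gql()))
```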
diff --git a/src/ai/backend/gateway/server.py b/src/ai/backend/gateway/server.py index a9a475290..f3f063ee7 100644 --- a/src/ai/backend/gateway/server.py +++ b/src/ai/backend/gateway/server.py @@ -50,6 +50,7 @@ from ..manager.background import BackgroundTaskManager from ..manager.exceptions import InvalidArgument from ..manager.idle import create_idle_checkers +from ..manager.models.base import pgsql_connect_opts from ..manager.models.storage import StorageSessionManager from ..manager.plugin.webapp import WebappPluginContext from ..manager.registry import AgentRegistry @@ -324,11 +325,10 @@ async def database_ctx(root_ctx: RootContext) -> AsyncIterator[None]: url = f"postgresql+asyncpg://{urlquote(username)}:{urlquote(password)}@{address}/{urlquote(dbname)}" root_ctx.dbpool = create_async_engine( url, - echo=False, - # echo=bool(root_ctx.local_config['logging']['level'] == 'DEBUG'), + echo=bool(root_ctx.local_config['logging']['level'] == 'DEBUG'), + connect_args=pgsql_connect_opts, pool_size=8, - pool_recycle=120, - # timeout=60, + max_overflow=64, json_serializer=functools.partial(json.dumps, cls=ExtendedJSONEncoder) ) yield diff --git a/src/ai/backend/manager/models/agent.py b/src/ai/backend/manager/models/agent.py index 07a6875c7..22f091f7f 100644 --- a/src/ai/backend/manager/models/agent.py +++ b/src/ai/backend/manager/models/agent.py @@ -4,7 +4,6 @@ from typing import ( Any, Mapping, - Optional, Sequence, TYPE_CHECKING, ) @@ -184,100 +183,98 @@ async def resolve_hardware_metadata( @staticmethod async def load_count( - ctx: GraphQueryContext, *, + graph_ctx: GraphQueryContext, *, scaling_group: str = None, raw_status: str = None, ) -> int: - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([sa.func.count(agents.c.id)]) - .select_from(agents) - ) - if scaling_group is not None: - query = query.where(agents.c.scaling_group == scaling_group) - if raw_status is not None: - status = AgentStatus[raw_status] - query = query.where(agents.c.status == status) - result = await conn.execute(query) - return result.scalar() + query = ( + sa.select([sa.func.count(agents.c.id)]) + .select_from(agents) + ) + if scaling_group is not None: + query = query.where(agents.c.scaling_group == scaling_group) + if raw_status is not None: + status = AgentStatus[raw_status] + query = query.where(agents.c.status == status) + result = await graph_ctx.db_conn.execute(query) + return result.scalar() @classmethod async def load_slice( cls, - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, limit: int, offset: int, *, scaling_group: str = None, raw_status: str = None, order_key: str = None, order_asc: bool = True, ) -> Sequence[Agent]: - async with ctx.dbpool.connect() as conn: - # TODO: optimization for pagination using subquery, join - if order_key is None: - _ordering = agents.c.id - else: - _order_func = sa.asc if order_asc else sa.desc - _ordering = _order_func(getattr(agents.c, order_key)) - query = ( - sa.select([agents]) - .select_from(agents) - .order_by(_ordering) - .limit(limit) - .offset(offset) - ) - if scaling_group is not None: - query = query.where(agents.c.scaling_group == scaling_group) - if raw_status is not None: - status = AgentStatus[raw_status] - query = query.where(agents.c.status == status) - return [ - cls.from_row(ctx, row) async for row in (await conn.stream(query)) - ] + # TODO: optimization for pagination using subquery, join + if order_key is None: + _ordering = agents.c.id + else: + _order_func = sa.asc if order_asc else sa.desc + _ordering = 
_order_func(getattr(agents.c, order_key)) + query = ( + sa.select([agents]) + .select_from(agents) + .order_by(_ordering) + .limit(limit) + .offset(offset) + ) + if scaling_group is not None: + query = query.where(agents.c.scaling_group == scaling_group) + if raw_status is not None: + status = AgentStatus[raw_status] + query = query.where(agents.c.status == status) + return [ + cls.from_row(graph_ctx, row) + async for row in (await graph_ctx.db_conn.stream(query)) + ] @classmethod async def load_all( cls, - ctx: GraphQueryContext, *, + graph_ctx: GraphQueryContext, *, scaling_group: str = None, raw_status: str = None, ) -> Sequence[Agent]: - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([agents]) - .select_from(agents) - ) - if scaling_group is not None: - query = query.where(agents.c.scaling_group == scaling_group) - if raw_status is not None: - status = AgentStatus[raw_status] - query = query.where(agents.c.status == status) - return [ - cls.from_row(ctx, row) async for row in (await conn.stream(query)) - ] + query = ( + sa.select([agents]) + .select_from(agents) + ) + if scaling_group is not None: + query = query.where(agents.c.scaling_group == scaling_group) + if raw_status is not None: + status = AgentStatus[raw_status] + query = query.where(agents.c.status == status) + return [ + cls.from_row(graph_ctx, row) + async for row in (await graph_ctx.db_conn.stream(query)) + ] @classmethod async def batch_load( cls, - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, agent_ids: Sequence[AgentId], *, raw_status: str = None, - ) -> Sequence[Optional[Agent]]: - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([agents]) - .select_from(agents) - .where(agents.c.id.in_(agent_ids)) - .order_by( - agents.c.id - ) - ) - if raw_status is not None: - status = AgentStatus[raw_status] - query = query.where(agents.c.status == status) - return await batch_result( - ctx, conn, query, cls, - agent_ids, lambda row: row['id'], + ) -> Sequence[Agent | None]: + query = ( + sa.select([agents]) + .select_from(agents) + .where(agents.c.id.in_(agent_ids)) + .order_by( + agents.c.id ) + ) + if raw_status is not None: + status = AgentStatus[raw_status] + query = query.where(agents.c.status == status) + return await batch_result( + graph_ctx, graph_ctx.db_conn, query, cls, + agent_ids, lambda row: row['id'], + ) class AgentList(graphene.ObjectType): diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py index c6d76a9d1..6e76a7efa 100644 --- a/src/ai/backend/manager/models/base.py +++ b/src/ai/backend/manager/models/base.py @@ -78,6 +78,8 @@ } metadata = sa.MetaData(naming_convention=convention) +pgsql_connect_opts = {'server_settings': {'jit': 'off'}} + # helper functions def zero_if_none(val): @@ -367,8 +369,8 @@ def from_row( async def batch_result( - ctx: GraphQueryContext, - conn: SAConnection, + graph_ctx: GraphQueryContext, + db_conn: SAConnection, query: sa.sql.Select, obj_type: Type[_GenericSQLBasedGQLObject], key_list: Iterable[_Key], @@ -381,14 +383,14 @@ async def batch_result( objs_per_key = collections.OrderedDict() for key in key_list: objs_per_key[key] = None - async for row in (await conn.stream(query)): - objs_per_key[key_getter(row)] = obj_type.from_row(ctx, row) + async for row in (await db_conn.stream(query)): + objs_per_key[key_getter(row)] = obj_type.from_row(graph_ctx, row) return [*objs_per_key.values()] async def batch_multiresult( - ctx: GraphQueryContext, - conn: SAConnection, + graph_ctx: GraphQueryContext, + 
db_conn: SAConnection, query: sa.sql.Select, obj_type: Type[_GenericSQLBasedGQLObject], key_list: Iterable[_Key], @@ -401,9 +403,9 @@ async def batch_multiresult( objs_per_key = collections.OrderedDict() for key in key_list: objs_per_key[key] = list() - async for row in (await conn.stream(query)): + async for row in (await db_conn.stream(query)): objs_per_key[key_getter(row)].append( - obj_type.from_row(ctx, row) + obj_type.from_row(graph_ctx, row) ) return [*objs_per_key.values()] @@ -567,47 +569,45 @@ async def wrapped(cls, root, info: graphene.ResolveInfo, *args, **kwargs) -> Any async def simple_db_mutate( result_cls: Type[ResultType], - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, mutation_query: sa.sql.Update | sa.sql.Insert, ) -> ResultType: - async with ctx.dbpool.connect() as conn, conn.begin(): - try: - result = await conn.execute(mutation_query) - if result.rowcount > 0: - return result_cls(True, 'success') - else: - return result_cls(False, 'no matching record') - except sa.exc.IntegrityError as e: - return result_cls(False, f'integrity error: {e}') - except (asyncio.CancelledError, asyncio.TimeoutError): - raise - except Exception as e: - return result_cls(False, f'unexpected error: {e}') + try: + result = await graph_ctx.db_conn.execute(mutation_query) + if result.rowcount > 0: + return result_cls(True, 'success') + else: + return result_cls(False, 'no matching record') + except sa.exc.IntegrityError as e: + return result_cls(False, f'integrity error: {e}') + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return result_cls(False, f'unexpected error: {e}') async def simple_db_mutate_returning_item( result_cls: Type[ResultType], - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, mutation_query: sa.sql.Update | sa.sql.Insert, *, item_query: sa.sql.Select, item_cls: Type[ItemType], ) -> ResultType: - async with ctx.dbpool.connect() as conn, conn.begin(): - try: - result = await conn.execute(mutation_query) - if result.rowcount > 0: - result = await conn.execute(item_query) - item = result.first() - return result_cls(True, 'success', item_cls.from_row(ctx, item)) - else: - return result_cls(False, 'no matching record', None) - except sa.exc.IntegrityError as e: - return result_cls(False, f'integrity error: {e}', None) - except (asyncio.CancelledError, asyncio.TimeoutError): - raise - except Exception as e: - return result_cls(False, f'unexpected error: {e}', None) + try: + result = await graph_ctx.db_conn.execute(mutation_query) + if result.rowcount > 0: + result = await graph_ctx.db_conn.execute(item_query) + item = result.first() + return result_cls(True, 'success', item_cls.from_row(graph_ctx, item)) + else: + return result_cls(False, 'no matching record', None) + except sa.exc.IntegrityError as e: + return result_cls(False, f'integrity error: {e}', None) + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return result_cls(False, f'unexpected error: {e}', None) def set_if_set(src: object, target: MutableMapping[str, Any], name: str, *, clean_func=None) -> None: diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 1696ae2d3..4e2015fba 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -9,7 +9,10 @@ if TYPE_CHECKING: from aioredis import Redis from graphql.execution.executors.asyncio import AsyncioExecutor - from sqlalchemy.ext.asyncio import AsyncEngine as SAEngine + from 
sqlalchemy.ext.asyncio import ( + AsyncEngine as SAEngine, + AsyncConnection as SAConnection, + ) from ai.backend.common.etcd import AsyncEtcd from ai.backend.common.types import ( @@ -131,6 +134,7 @@ class GraphQueryContext: user: Mapping[str, Any] # TODO: express using typed dict access_key: str dbpool: SAEngine + db_conn: SAConnection redis_stat: Redis redis_image: Redis manager_status: ManagerStatus diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index 25610a2fc..791475546 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -354,9 +354,9 @@ async def match_session_ids( if for_update: match_sid_by_session_id = match_sid_by_session_id.with_for_update() for match_query in [ - match_sid_by_id, match_sid_by_session_id, match_sid_by_name, + match_sid_by_id, ]: result = await db_connection.execute(match_query) rows = result.fetchall() @@ -543,22 +543,21 @@ async def load_count( group_id: uuid.UUID = None, access_key: str = None, ) -> int: - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([sa.func.count(kernels.c.id)]) - .select_from(kernels) - .where(kernels.c.session_id == session_id) - ) - if cluster_role is not None: - query = query.where(kernels.c.cluster_role == cluster_role) - if domain_name is not None: - query = query.where(kernels.c.domain_name == domain_name) - if group_id is not None: - query = query.where(kernels.c.group_id == group_id) - if access_key is not None: - query = query.where(kernels.c.access_key == access_key) - result = await conn.execute(query) - return result.scalar() + query = ( + sa.select([sa.func.count(kernels.c.id)]) + .select_from(kernels) + .where(kernels.c.session_id == session_id) + ) + if cluster_role is not None: + query = query.where(kernels.c.cluster_role == cluster_role) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + result = await ctx.db_conn.execute(query) + return result.scalar() @classmethod async def load_slice( @@ -575,29 +574,28 @@ async def load_slice( order_key: str = None, order_asc: bool = True, ) -> Sequence[Optional[ComputeContainer]]: - async with ctx.dbpool.connect() as conn: - if order_key is None: - _ordering = DEFAULT_SESSION_ORDERING - else: - _order_func = sa.asc if order_asc else sa.desc - _ordering = [_order_func(getattr(kernels.c, order_key))] - query = ( - sa.select([kernels]) - .select_from(kernels) - .where(kernels.c.session_id == session_id) - .order_by(*_ordering) - .limit(limit) - .offset(offset) - ) - if cluster_role is not None: - query = query.where(kernels.c.cluster_role == cluster_role) - if domain_name is not None: - query = query.where(kernels.c.domain_name == domain_name) - if group_id is not None: - query = query.where(kernels.c.group_id == group_id) - if access_key is not None: - query = query.where(kernels.c.access_key == access_key) - return [cls.from_row(ctx, r) async for r in (await conn.stream(query))] + if order_key is None: + _ordering = DEFAULT_SESSION_ORDERING + else: + _order_func = sa.asc if order_asc else sa.desc + _ordering = [_order_func(getattr(kernels.c, order_key))] + query = ( + sa.select([kernels]) + .select_from(kernels) + .where(kernels.c.session_id == session_id) + .order_by(*_ordering) + .limit(limit) + .offset(offset) + ) + if cluster_role is not None: + query = 
query.where(kernels.c.cluster_role == cluster_role) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + return [cls.from_row(ctx, r) async for r in (await ctx.db_conn.stream(query))] @classmethod async def batch_load_by_session( @@ -605,17 +603,16 @@ async def batch_load_by_session( ctx: GraphQueryContext, session_ids: Sequence[SessionId], ) -> Sequence[Sequence[ComputeContainer]]: - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([kernels]) - .select_from(kernels) - # TODO: use "owner session ID" when we implement multi-container session - .where(kernels.c.session_id.in_(session_ids)) - ) - return await batch_multiresult( - ctx, conn, query, cls, - session_ids, lambda row: row['session_id'], - ) + query = ( + sa.select([kernels]) + .select_from(kernels) + # TODO: use "owner session ID" when we implement multi-container session + .where(kernels.c.session_id.in_(session_ids)) + ) + return await batch_multiresult( + ctx, ctx.db_conn, query, cls, + session_ids, lambda row: row['session_id'], + ) @classmethod async def batch_load_detail( @@ -626,26 +623,25 @@ async def batch_load_detail( domain_name: str = None, access_key: AccessKey = None, ) -> Sequence[Optional[ComputeContainer]]: - async with ctx.dbpool.connect() as conn: - j = ( - kernels - .join(groups, groups.c.id == kernels.c.group_id) - .join(users, users.c.uuid == kernels.c.user_uuid) - ) - query = ( - sa.select([kernels]) - .select_from(j) - .where( - (kernels.c.id.in_(container_ids)) - )) - if domain_name is not None: - query = query.where(kernels.c.domain_name == domain_name) - if access_key is not None: - query = query.where(kernels.c.access_key == access_key) - return await batch_result( - ctx, conn, query, cls, - container_ids, lambda row: row['id'], - ) + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([kernels]) + .select_from(j) + .where( + (kernels.c.id.in_(container_ids)) + )) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + return await batch_result( + ctx, ctx.db_conn, query, cls, + container_ids, lambda row: row['id'], + ) class ComputeSession(graphene.ObjectType): @@ -804,22 +800,21 @@ async def load_count( status_list = [KernelStatus[s] for s in status.split(',')] elif isinstance(status, KernelStatus): status_list = [status] - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([sa.func.count(kernels.c.id)]) - .select_from(kernels) - .where(kernels.c.cluster_role == DEFAULT_ROLE) - ) - if domain_name is not None: - query = query.where(kernels.c.domain_name == domain_name) - if group_id is not None: - query = query.where(kernels.c.group_id == group_id) - if access_key is not None: - query = query.where(kernels.c.access_key == access_key) - if status is not None: - query = query.where(kernels.c.status.in_(status_list)) - result = await conn.execute(query) - return result.scalar() + query = ( + sa.select([sa.func.count(kernels.c.id)]) + .select_from(kernels) + .where(kernels.c.cluster_role == DEFAULT_ROLE) + ) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = 
query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if status is not None: + query = query.where(kernels.c.status.in_(status_list)) + result = await ctx.db_conn.execute(query) + return result.scalar() @classmethod async def load_slice( @@ -839,38 +834,37 @@ async def load_slice( status_list = [KernelStatus[s] for s in status.split(',')] elif isinstance(status, KernelStatus): status_list = [status] - async with ctx.dbpool.connect() as conn: - if order_key is None: - _ordering = DEFAULT_SESSION_ORDERING - else: - _order_func = sa.asc if order_asc else sa.desc - _ordering = [_order_func(getattr(kernels.c, order_key))] - j = ( - kernels - .join(groups, groups.c.id == kernels.c.group_id) - .join(users, users.c.uuid == kernels.c.user_uuid) - ) - query = ( - sa.select([ - kernels, - groups.c.name.label('group_name'), - users.c.email, - ]) - .select_from(j) - .where(kernels.c.cluster_role == DEFAULT_ROLE) - .order_by(*_ordering) - .limit(limit) - .offset(offset) - ) - if domain_name is not None: - query = query.where(kernels.c.domain_name == domain_name) - if group_id is not None: - query = query.where(kernels.c.group_id == group_id) - if access_key is not None: - query = query.where(kernels.c.access_key == access_key) - if status is not None: - query = query.where(kernels.c.status.in_(status_list)) - return [cls.from_row(ctx, r) async for r in (await conn.stream(query))] + if order_key is None: + _ordering = DEFAULT_SESSION_ORDERING + else: + _order_func = sa.asc if order_asc else sa.desc + _ordering = [_order_func(getattr(kernels.c, order_key))] + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([ + kernels, + groups.c.name.label('group_name'), + users.c.email, + ]) + .select_from(j) + .where(kernels.c.cluster_role == DEFAULT_ROLE) + .order_by(*_ordering) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if status is not None: + query = query.where(kernels.c.status.in_(status_list)) + return [cls.from_row(ctx, r) async for r in (await ctx.db_conn.stream(query))] @classmethod async def batch_load_by_dependency( @@ -878,23 +872,22 @@ async def batch_load_by_dependency( ctx: GraphQueryContext, session_ids: Sequence[SessionId], ) -> Sequence[Sequence[ComputeSession]]: - async with ctx.dbpool.connect() as conn: - j = sa.join( - kernels, session_dependencies, - kernels.c.session_id == session_dependencies.c.depends_on, - ) - query = ( - sa.select([kernels]) - .select_from(j) - .where( - (kernels.c.cluster_role == DEFAULT_ROLE) & - (session_dependencies.c.session_id.in_(session_ids)) - ) - ) - return await batch_multiresult( - ctx, conn, query, cls, - session_ids, lambda row: row['id'], + j = sa.join( + kernels, session_dependencies, + kernels.c.session_id == session_dependencies.c.depends_on, + ) + query = ( + sa.select([kernels]) + .select_from(j) + .where( + (kernels.c.cluster_role == DEFAULT_ROLE) & + (session_dependencies.c.session_id.in_(session_ids)) ) + ) + return await batch_multiresult( + ctx, ctx.db_conn, query, cls, + session_ids, lambda row: row['id'], + ) @classmethod async def batch_load_detail( @@ -905,31 +898,30 @@ async def batch_load_detail( domain_name: str = 
None, access_key: str = None, ) -> Sequence[ComputeSession | None]: - async with ctx.dbpool.connect() as conn: - j = ( - kernels - .join(groups, groups.c.id == kernels.c.group_id) - .join(users, users.c.uuid == kernels.c.user_uuid) - ) - query = ( - sa.select([ - kernels, - groups.c.name.label('group_name'), - users.c.email, - ]) - .select_from(j) - .where( - (kernels.c.cluster_role == DEFAULT_ROLE) & - (kernels.c.id.in_(session_ids)) - )) - if domain_name is not None: - query = query.where(kernels.c.domain_name == domain_name) - if access_key is not None: - query = query.where(kernels.c.access_key == access_key) - return await batch_result( - ctx, conn, query, cls, - session_ids, lambda row: row['id'], - ) + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([ + kernels, + groups.c.name.label('group_name'), + users.c.email, + ]) + .select_from(j) + .where( + (kernels.c.cluster_role == DEFAULT_ROLE) & + (kernels.c.id.in_(session_ids)) + )) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + return await batch_result( + ctx, ctx.db_conn, query, cls, + session_ids, lambda row: row['id'], + ) class ComputeContainerList(graphene.ObjectType): diff --git a/src/ai/backend/manager/models/keypair.py b/src/ai/backend/manager/models/keypair.py index a73075415..ef32feb98 100644 --- a/src/ai/backend/manager/models/keypair.py +++ b/src/ai/backend/manager/models/keypair.py @@ -207,61 +207,59 @@ async def resolve_compute_sessions(self, info: graphene.ResolveInfo, raw_status: @classmethod async def load_all( cls, - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, *, domain_name: str = None, is_active: bool = None, limit: int = None, ) -> Sequence[KeyPair]: from .user import users - async with ctx.dbpool.connect() as conn: - j = sa.join( - keypairs, users, - keypairs.c.user == users.c.uuid, - ) - query = ( - sa.select([keypairs]) - .select_from(j) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if is_active is not None: - query = query.where(keypairs.c.is_active == is_active) - if limit is not None: - query = query.limit(limit) - return [ - obj async for row in (await conn.stream(query)) - if (obj := cls.from_row(ctx, row)) is not None - ] + j = sa.join( + keypairs, users, + keypairs.c.user == users.c.uuid, + ) + query = ( + sa.select([keypairs]) + .select_from(j) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + if limit is not None: + query = query.limit(limit) + return [ + obj async for row in (await graph_ctx.db_conn.stream(query)) + if (obj := cls.from_row(graph_ctx, row)) is not None + ] @staticmethod async def load_count( - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, *, domain_name: str = None, email: str = None, is_active: bool = None, ) -> int: from .user import users - async with ctx.dbpool.connect() as conn: - j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) - query = ( - sa.select([sa.func.count(keypairs.c.access_key)]) - .select_from(j) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if email is not None: - query = query.where(keypairs.c.user_id == email) - if is_active is not None: - query = query.where(keypairs.c.is_active == 
is_active) - result = await conn.execute(query) - return result.scalar() + j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) + query = ( + sa.select([sa.func.count(keypairs.c.access_key)]) + .select_from(j) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if email is not None: + query = query.where(keypairs.c.user_id == email) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + result = await graph_ctx.db_conn.execute(query) + return result.scalar() @classmethod async def load_slice( cls, - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, limit: int, offset: int, *, @@ -272,85 +270,82 @@ async def load_slice( order_asc: bool = True, ) -> Sequence[KeyPair]: from .user import users - async with ctx.dbpool.connect() as conn: - if order_key is None: - _ordering = sa.desc(keypairs.c.created_at) - else: - _order_func = sa.asc if order_asc else sa.desc - _ordering = _order_func(getattr(keypairs.c, order_key)) - j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) - query = ( - sa.select([keypairs]) - .select_from(j) - .order_by(_ordering) - .limit(limit) - .offset(offset) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if email is not None: - query = query.where(keypairs.c.user_id == email) - if is_active is not None: - query = query.where(keypairs.c.is_active == is_active) - return [ - obj async for row in (await conn.stream(query)) - if (obj := cls.from_row(ctx, row)) is not None - ] + if order_key is None: + _ordering = sa.desc(keypairs.c.created_at) + else: + _order_func = sa.asc if order_asc else sa.desc + _ordering = _order_func(getattr(keypairs.c, order_key)) + j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) + query = ( + sa.select([keypairs]) + .select_from(j) + .order_by(_ordering) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if email is not None: + query = query.where(keypairs.c.user_id == email) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + return [ + obj async for row in (await graph_ctx.db_conn.stream(query)) + if (obj := cls.from_row(graph_ctx, row)) is not None + ] @classmethod async def batch_load_by_email( cls, - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, user_ids: Sequence[uuid.UUID], *, domain_name: str = None, is_active: bool = None, ) -> Sequence[Sequence[Optional[KeyPair]]]: from .user import users - async with ctx.dbpool.connect() as conn: - j = sa.join( - keypairs, users, - keypairs.c.user == users.c.uuid, - ) - query = ( - sa.select([keypairs]) - .select_from(j) - .where(keypairs.c.user_id.in_(user_ids)) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if is_active is not None: - query = query.where(keypairs.c.is_active == is_active) - return await batch_multiresult( - ctx, conn, query, cls, - user_ids, lambda row: row['user_id'], - ) + j = sa.join( + keypairs, users, + keypairs.c.user == users.c.uuid, + ) + query = ( + sa.select([keypairs]) + .select_from(j) + .where(keypairs.c.user_id.in_(user_ids)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + return await batch_multiresult( + graph_ctx, graph_ctx.db_conn, query, cls, + user_ids, lambda row: row['user_id'], + ) @classmethod async def 
batch_load_by_ak( cls, - ctx: GraphQueryContext, + graph_ctx: GraphQueryContext, access_keys: Sequence[AccessKey], *, domain_name: str = None, ) -> Sequence[Optional[KeyPair]]: - async with ctx.dbpool.connect() as conn: - from .user import users - j = sa.join( - keypairs, users, - keypairs.c.user == users.c.uuid, - ) - query = ( - sa.select([keypairs]) - .select_from(j) - .where(keypairs.c.access_key.in_(access_keys)) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - return await batch_result( - ctx, conn, query, cls, - access_keys, lambda row: row['access_key'], - ) + from .user import users + j = sa.join( + keypairs, users, + keypairs.c.user == users.c.uuid, + ) + query = ( + sa.select([keypairs]) + .select_from(j) + .where(keypairs.c.access_key.in_(access_keys)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + return await batch_result( + graph_ctx, graph_ctx.db_conn, query, cls, + access_keys, lambda row: row['access_key'], + ) class KeyPairList(graphene.ObjectType): @@ -399,61 +394,60 @@ async def mutate( user_id: uuid.UUID, props: KeyPairInput, ) -> CreateKeyPair: - ctx: GraphQueryContext = info.context - async with ctx.dbpool.connect() as conn, conn.begin(): - # Check if user exists with requested email (user_id for legacy). - from .user import users # noqa - query = ( - sa.select([users.c.uuid]) - .select_from(users) - .where(users.c.email == user_id) - ) - try: - result = await conn.execute(query) - user_uuid = result.scalar() - if user_uuid is None: - return cls(ok=False, msg=f'User not found: {user_id}', keypair=None) - except sa.exc.IntegrityError as e: - return cls(ok=False, msg=f'integrity error: {e}', keypair=None) - except (asyncio.CancelledError, asyncio.TimeoutError): - raise - except Exception as e: - return cls(ok=False, msg=f'unexpected error: {e}', keypair=None) - - # Create keypair. - ak, sk = generate_keypair() - pubkey, privkey = generate_ssh_keypair() - data = { - 'user_id': user_id, - 'access_key': ak, - 'secret_key': sk, - 'is_active': bool(props.is_active), - 'is_admin': bool(props.is_admin), - 'resource_policy': props.resource_policy, - 'concurrency_used': 0, - 'rate_limit': props.rate_limit, - 'num_queries': 0, - 'user': user_uuid, - 'ssh_public_key': pubkey, - 'ssh_private_key': privkey, - } - insert_query = (keypairs.insert().values(data)) - try: - result = await conn.execute(insert_query) - if result.rowcount > 0: - # Read the created key data from DB. - checkq = keypairs.select().where(keypairs.c.access_key == ak) - result = await conn.execute(checkq) - o = KeyPair.from_row(info.context, result.first()) - return cls(ok=True, msg='success', keypair=o) - else: - return cls(ok=False, msg='failed to create keypair', keypair=None) - except sa.exc.IntegrityError as e: - return cls(ok=False, msg=f'integrity error: {e}', keypair=None) - except (asyncio.CancelledError, asyncio.TimeoutError): - raise - except Exception as e: - return cls(ok=False, msg=f'unexpected error: {e}', keypair=None) + graph_ctx: GraphQueryContext = info.context + # Check if user exists with requested email (user_id for legacy). 
+ from .user import users # noqa + query = ( + sa.select([users.c.uuid]) + .select_from(users) + .where(users.c.email == user_id) + ) + try: + result = await graph_ctx.db_conn.execute(query) + user_uuid = result.scalar() + if user_uuid is None: + return cls(ok=False, msg=f'User not found: {user_id}', keypair=None) + except sa.exc.IntegrityError as e: + return cls(ok=False, msg=f'integrity error: {e}', keypair=None) + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return cls(ok=False, msg=f'unexpected error: {e}', keypair=None) + + # Create keypair. + ak, sk = generate_keypair() + pubkey, privkey = generate_ssh_keypair() + data = { + 'user_id': user_id, + 'access_key': ak, + 'secret_key': sk, + 'is_active': bool(props.is_active), + 'is_admin': bool(props.is_admin), + 'resource_policy': props.resource_policy, + 'concurrency_used': 0, + 'rate_limit': props.rate_limit, + 'num_queries': 0, + 'user': user_uuid, + 'ssh_public_key': pubkey, + 'ssh_private_key': privkey, + } + insert_query = (keypairs.insert().values(data)) + try: + result = await graph_ctx.db_conn.execute(insert_query) + if result.rowcount > 0: + # Read the created key data from DB. + checkq = keypairs.select().where(keypairs.c.access_key == ak) + result = await graph_ctx.db_conn.execute(checkq) + o = KeyPair.from_row(info.context, result.first()) + return cls(ok=True, msg='success', keypair=o) + else: + return cls(ok=False, msg='failed to create keypair', keypair=None) + except sa.exc.IntegrityError as e: + return cls(ok=False, msg=f'integrity error: {e}', keypair=None) + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return cls(ok=False, msg=f'unexpected error: {e}', keypair=None) class ModifyKeyPair(graphene.Mutation): diff --git a/src/ai/backend/manager/models/user.py b/src/ai/backend/manager/models/user.py index d380accaf..05f869d80 100644 --- a/src/ai/backend/manager/models/user.py +++ b/src/ai/backend/manager/models/user.py @@ -215,32 +215,31 @@ async def load_all( """ Load user's information. Group names associated with the user are also returned. 
""" - async with ctx.dbpool.connect() as conn: - if group_id is not None: - from .group import association_groups_users as agus - j = (users.join(agus, agus.c.user_id == users.c.uuid)) - query = ( - sa.select([users]) - .select_from(j) - .where(agus.c.group_id == group_id) - ) - else: - query = ( - sa.select([users]) - .select_from(users) - ) - if ctx.user['role'] != UserRole.SUPERADMIN: - query = query.where(users.c.domain_name == ctx.user['domain_name']) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if status is not None: - query = query.where(users.c.status == UserStatus(status)) - elif is_active is not None: # consider is_active field only if status is empty - _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES - query = query.where(users.c.status.in_(_statuses)) - if limit is not None: - query = query.limit(limit) - return [cls.from_row(ctx, row) async for row in (await conn.stream(query))] + if group_id is not None: + from .group import association_groups_users as agus + j = (users.join(agus, agus.c.user_id == users.c.uuid)) + query = ( + sa.select([users]) + .select_from(j) + .where(agus.c.group_id == group_id) + ) + else: + query = ( + sa.select([users]) + .select_from(users) + ) + if ctx.user['role'] != UserRole.SUPERADMIN: + query = query.where(users.c.domain_name == ctx.user['domain_name']) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + if limit is not None: + query = query.limit(limit) + return [cls.from_row(ctx, row) async for row in (await ctx.db_conn.stream(query))] @staticmethod async def load_count( @@ -251,29 +250,28 @@ async def load_count( is_active: bool = None, status: str = None, ) -> int: - async with ctx.dbpool.connect() as conn: - if group_id is not None: - from .group import association_groups_users as agus - j = (users.join(agus, agus.c.user_id == users.c.uuid)) - query = ( - sa.select([sa.func.count(users.c.uuid)]) - .select_from(j) - .where(agus.c.group_id == group_id) - ) - else: - query = ( - sa.select([sa.func.count(users.c.uuid)]) - .select_from(users) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if status is not None: - query = query.where(users.c.status == UserStatus(status)) - elif is_active is not None: # consider is_active field only if status is empty - _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES - query = query.where(users.c.status.in_(_statuses)) - result = await conn.execute(query) - return result.scalar() + if group_id is not None: + from .group import association_groups_users as agus + j = (users.join(agus, agus.c.user_id == users.c.uuid)) + query = ( + sa.select([sa.func.count(users.c.uuid)]) + .select_from(j) + .where(agus.c.group_id == group_id) + ) + else: + query = ( + sa.select([sa.func.count(users.c.uuid)]) + .select_from(users) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query 
= query.where(users.c.status.in_(_statuses)) + result = await ctx.db_conn.execute(query) + return result.scalar() @classmethod async def load_slice( @@ -289,41 +287,40 @@ async def load_slice( order_key: str = None, order_asc: bool = True, ) -> Sequence[User]: - async with ctx.dbpool.connect() as conn: - if order_key is None: - _ordering = sa.desc(users.c.created_at) - else: - _order_func = sa.asc if order_asc else sa.desc - _ordering = _order_func(getattr(users.c, order_key)) - if group_id is not None: - from .group import association_groups_users as agus - j = (users.join(agus, agus.c.user_id == users.c.uuid)) - query = ( - sa.select([users]) - .select_from(j) - .where(agus.c.group_id == group_id) - .order_by(_ordering) - .limit(limit) - .offset(offset) - ) - else: - query = ( - sa.select([users]) - .select_from(users) - .order_by(_ordering) - .limit(limit) - .offset(offset) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if status is not None: - query = query.where(users.c.status == UserStatus(status)) - elif is_active is not None: # consider is_active field only if status is empty - _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES - query = query.where(users.c.status.in_(_statuses)) - return [ - cls.from_row(ctx, row) async for row in (await conn.stream(query)) - ] + if order_key is None: + _ordering = sa.desc(users.c.created_at) + else: + _order_func = sa.asc if order_asc else sa.desc + _ordering = _order_func(getattr(users.c, order_key)) + if group_id is not None: + from .group import association_groups_users as agus + j = (users.join(agus, agus.c.user_id == users.c.uuid)) + query = ( + sa.select([users]) + .select_from(j) + .where(agus.c.group_id == group_id) + .order_by(_ordering) + .limit(limit) + .offset(offset) + ) + else: + query = ( + sa.select([users]) + .select_from(users) + .order_by(_ordering) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + return [ + cls.from_row(ctx, row) async for row in (await ctx.db_conn.stream(query)) + ] @classmethod async def batch_load_by_email( @@ -337,23 +334,22 @@ async def batch_load_by_email( ) -> Sequence[Optional[User]]: if not emails: return [] - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([users]) - .select_from(users) - .where(users.c.email.in_(emails)) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if status is not None: - query = query.where(users.c.status == UserStatus(status)) - elif is_active is not None: # consider is_active field only if status is empty - _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES - query = query.where(users.c.status.in_(_statuses)) - return await batch_result( - ctx, conn, query, cls, - emails, lambda row: row['email'], - ) + query = ( + sa.select([users]) + .select_from(users) + .where(users.c.email.in_(emails)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses 
= ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + return await batch_result( + ctx, ctx.db_conn, query, cls, + emails, lambda row: row['email'], + ) @classmethod async def batch_load_by_uuid( @@ -367,23 +363,22 @@ async def batch_load_by_uuid( ) -> Sequence[Optional[User]]: if not user_ids: return [] - async with ctx.dbpool.connect() as conn: - query = ( - sa.select([users]) - .select_from(users) - .where(users.c.uuid.in_(user_ids)) - ) - if domain_name is not None: - query = query.where(users.c.domain_name == domain_name) - if status is not None: - query = query.where(users.c.status == UserStatus(status)) - elif is_active is not None: # consider is_active field only if status is empty - _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES - query = query.where(users.c.status.in_(_statuses)) - return await batch_result( - ctx, conn, query, cls, - user_ids, lambda row: row['uuid'], - ) + query = ( + sa.select([users]) + .select_from(users) + .where(users.c.uuid.in_(user_ids)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + return await batch_result( + ctx, ctx.db_conn, query, cls, + user_ids, lambda row: row['uuid'], + ) class UserList(graphene.ObjectType): @@ -447,77 +442,76 @@ async def mutate( props: UserInput, ) -> CreateUser: graph_ctx: GraphQueryContext = info.context - async with graph_ctx.dbpool.connect() as conn, conn.begin(): - username = props.username if props.username else email - if props.status is None and props.is_active is not None: - _status = UserStatus.ACTIVE if props.is_active else UserStatus.INACTIVE + username = props.username if props.username else email + if props.status is None and props.is_active is not None: + _status = UserStatus.ACTIVE if props.is_active else UserStatus.INACTIVE + else: + _status = UserStatus(props.status) + data = { + 'username': username, + 'email': email, + 'password': props.password, + 'need_password_change': props.need_password_change, + 'full_name': props.full_name, + 'description': props.description, + 'status': _status, + 'status_info': 'admin-requested', # user mutation is only for admin + 'domain_name': props.domain_name, + 'role': UserRole(props.role), + } + try: + query = (users.insert().values(data)) + result = await graph_ctx.db_conn.execute(query) + if result.rowcount > 0: + # Read the created user data from DB. + checkq = users.select().where(users.c.email == email) + result = await graph_ctx.db_conn.execute(checkq) + o = User.from_row(info.context, result.first()) + + # Create user's first access_key and secret_key. 
+ from .keypair import generate_keypair, generate_ssh_keypair, keypairs + ak, sk = generate_keypair() + pubkey, privkey = generate_ssh_keypair() + is_admin = True if data['role'] in [UserRole.SUPERADMIN, UserRole.ADMIN] else False + kp_data = { + 'user_id': email, + 'access_key': ak, + 'secret_key': sk, + 'is_active': True if _status == UserStatus.ACTIVE else False, + 'is_admin': is_admin, + 'resource_policy': 'default', + 'concurrency_used': 0, + 'rate_limit': 10000, + 'num_queries': 0, + 'user': o.uuid, + 'ssh_public_key': pubkey, + 'ssh_private_key': privkey, + } + query = (keypairs.insert().values(kp_data)) + await graph_ctx.db_conn.execute(query) + + # Add user to groups if group_ids parameter is provided. + from .group import association_groups_users, groups + if props.group_ids: + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == props.domain_name) + .where(groups.c.id.in_(props.group_ids))) + result = await graph_ctx.db_conn.execute(query) + grps = result.fetchall() + if grps: + values = [{'user_id': o.uuid, 'group_id': grp.id} for grp in grps] + query = association_groups_users.insert().values(values) + await graph_ctx.db_conn.execute(query) + return cls(ok=True, msg='success', user=o) else: - _status = UserStatus(props.status) - data = { - 'username': username, - 'email': email, - 'password': props.password, - 'need_password_change': props.need_password_change, - 'full_name': props.full_name, - 'description': props.description, - 'status': _status, - 'status_info': 'admin-requested', # user mutation is only for admin - 'domain_name': props.domain_name, - 'role': UserRole(props.role), - } - try: - query = (users.insert().values(data)) - result = await conn.execute(query) - if result.rowcount > 0: - # Read the created user data from DB. - checkq = users.select().where(users.c.email == email) - result = await conn.execute(checkq) - o = User.from_row(info.context, result.first()) - - # Create user's first access_key and secret_key. - from .keypair import generate_keypair, generate_ssh_keypair, keypairs - ak, sk = generate_keypair() - pubkey, privkey = generate_ssh_keypair() - is_admin = True if data['role'] in [UserRole.SUPERADMIN, UserRole.ADMIN] else False - kp_data = { - 'user_id': email, - 'access_key': ak, - 'secret_key': sk, - 'is_active': True if _status == UserStatus.ACTIVE else False, - 'is_admin': is_admin, - 'resource_policy': 'default', - 'concurrency_used': 0, - 'rate_limit': 10000, - 'num_queries': 0, - 'user': o.uuid, - 'ssh_public_key': pubkey, - 'ssh_private_key': privkey, - } - query = (keypairs.insert().values(kp_data)) - await conn.execute(query) - - # Add user to groups if group_ids parameter is provided. 
- from .group import association_groups_users, groups - if props.group_ids: - query = (sa.select([groups.c.id]) - .select_from(groups) - .where(groups.c.domain_name == props.domain_name) - .where(groups.c.id.in_(props.group_ids))) - result = await conn.execute(query) - grps = result.fetchall() - if grps: - values = [{'user_id': o.uuid, 'group_id': grp.id} for grp in grps] - query = association_groups_users.insert().values(values) - await conn.execute(query) - return cls(ok=True, msg='success', user=o) - else: - return cls(ok=False, msg='failed to create user', user=None) - except sa.exc.IntegrityError as e: - return cls(ok=False, msg=f'integrity error: {e}', user=None) - except (asyncio.CancelledError, asyncio.TimeoutError): - raise - except Exception as e: - return cls(ok=False, msg=f'unexpected error: {e}', user=None) + return cls(ok=False, msg='failed to create user', user=None) + except sa.exc.IntegrityError as e: + return cls(ok=False, msg=f'integrity error: {e}', user=None) + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return cls(ok=False, msg=f'unexpected error: {e}', user=None) class ModifyUser(graphene.Mutation): @@ -540,131 +534,129 @@ async def mutate( email: str, props: ModifyUserInput, ) -> ModifyUser: - ctx: GraphQueryContext = info.context - async with ctx.dbpool.connect() as conn, conn.begin(): - - data: Dict[str, Any] = {} - set_if_set(props, data, 'username') - set_if_set(props, data, 'password') - set_if_set(props, data, 'need_password_change') - set_if_set(props, data, 'full_name') - set_if_set(props, data, 'description') - set_if_set(props, data, 'status') - set_if_set(props, data, 'domain_name') - set_if_set(props, data, 'role') - - if 'role' in data: - data['role'] = UserRole(data['role']) - - if data.get('status') is None and props.is_active is not None: - _status = 'active' if props.is_active else 'inactive' - data['status'] = _status - if 'status' in data and data['status'] is not None: - data['status'] = UserStatus(data['status']) - - if not data and not props.group_ids: - return cls(ok=False, msg='nothing to update', user=None) - - try: - # Get previous domain name of the user. - query = (sa.select([users.c.domain_name, users.c.role, users.c.status]) - .select_from(users) - .where(users.c.email == email)) - result = await conn.execute(query) - row = result.first() - prev_domain_name = row.domain_name - prev_role = row.role + graph_ctx: GraphQueryContext = info.context + data: Dict[str, Any] = {} + set_if_set(props, data, 'username') + set_if_set(props, data, 'password') + set_if_set(props, data, 'need_password_change') + set_if_set(props, data, 'full_name') + set_if_set(props, data, 'description') + set_if_set(props, data, 'status') + set_if_set(props, data, 'domain_name') + set_if_set(props, data, 'role') + + if 'role' in data: + data['role'] = UserRole(data['role']) + + if data.get('status') is None and props.is_active is not None: + _status = 'active' if props.is_active else 'inactive' + data['status'] = _status + if 'status' in data and data['status'] is not None: + data['status'] = UserStatus(data['status']) + + if not data and not props.group_ids: + return cls(ok=False, msg='nothing to update', user=None) - if 'status' in data and row.status != data['status']: - data['status_info'] = 'admin-requested' # user mutation is only for admin + try: + # Get previous domain name of the user. 
+ query = (sa.select([users.c.domain_name, users.c.role, users.c.status]) + .select_from(users) + .where(users.c.email == email)) + result = await graph_ctx.db_conn.execute(query) + row = result.first() + prev_domain_name = row.domain_name + prev_role = row.role + + if 'status' in data and row.status != data['status']: + data['status_info'] = 'admin-requested' # user mutation is only for admin + + # Update user. + query = (users.update().values(data).where(users.c.email == email)) + result = await graph_ctx.db_conn.execute(query) + if result.rowcount > 0: + checkq = users.select().where(users.c.email == email) + result = await graph_ctx.db_conn.execute(checkq) + o = User.from_row(graph_ctx, result.first()) + else: + return cls(ok=False, msg='no such user', user=None) - # Update user. - query = (users.update().values(data).where(users.c.email == email)) - result = await conn.execute(query) - if result.rowcount > 0: - checkq = users.select().where(users.c.email == email) - result = await conn.execute(checkq) - o = User.from_row(ctx, result.first()) + # Update keypair if user's role is updated. + # NOTE: This assumes that the user has only one keypair. + if 'role' in data and data['role'] != prev_role: + from ai.backend.manager.models import keypairs + query = (sa.select([keypairs.c.user, + keypairs.c.is_active, + keypairs.c.is_admin]) + .select_from(keypairs) + .where(keypairs.c.user == o.uuid) + .order_by(sa.desc(keypairs.c.is_admin)) + .order_by(sa.desc(keypairs.c.is_active))) + result = await graph_ctx.db_conn.execute(query) + if data['role'] in [UserRole.SUPERADMIN, UserRole.ADMIN]: + # User becomes an admin. Set the keypair as an active admin. + kp = result.first() + kp_data = dict() + if not kp.is_admin: + kp_data['is_admin'] = True + if not kp.is_active: + kp_data['is_active'] = True + if len(kp_data.keys()) > 0: + query = (keypairs.update() + .values(kp_data) + .where(keypairs.c.user == o.uuid)) + await graph_ctx.db_conn.execute(query) else: - return cls(ok=False, msg='no such user', user=None) - - # Update keypair if user's role is updated. - # NOTE: This assumes that user have only one keypair. - if 'role' in data and data['role'] != prev_role: - from ai.backend.manager.models import keypairs - query = (sa.select([keypairs.c.user, - keypairs.c.is_active, - keypairs.c.is_admin]) - .select_from(keypairs) - .where(keypairs.c.user == o.uuid) - .order_by(sa.desc(keypairs.c.is_admin)) - .order_by(sa.desc(keypairs.c.is_active))) - result = await conn.execute(query) - if data['role'] in [UserRole.SUPERADMIN, UserRole.ADMIN]: - # User's becomes admin. Set the keypair as active admin. - kp = result.first() + # User becomes non-admin. Make the keypair non-admin as well. + # If there are multiple admin keypairs, inactivate them. + rows = result.fetchall() + cnt = 0 + for row in rows: kp_data = dict() - if not kp.is_admin: - kp_data['is_admin'] = True - if not kp.is_active: - kp_data['is_active'] = True + if cnt == 0: + kp_data['is_admin'] = False + elif row.is_admin and row.is_active: + kp_data['is_active'] = False if len(kp_data.keys()) > 0: query = (keypairs.update() - .values(kp_data) - .where(keypairs.c.user == o.uuid)) - await conn.execute(query) - else: - # User becomes non-admin. Make the keypair non-admin as well. - # If there are multiple admin keypairs, inactivate them.
- rows = result.fetchall() - cnt = 0 - for row in rows: - kp_data = dict() - if cnt == 0: - kp_data['is_admin'] = False - elif row.is_admin and row.is_active: - kp_data['is_active'] = False - if len(kp_data.keys()) > 0: - query = (keypairs.update() - .values(kp_data) - .where(keypairs.c.user == row.user)) - await conn.execute(query) - cnt += 1 - - # If domain is changed and no group is associated, clear previous domain's group. - if prev_domain_name != o.domain_name and not props.group_ids: - from .group import association_groups_users, groups - query = (association_groups_users - .delete() - .where(association_groups_users.c.user_id == o.uuid)) - await conn.execute(query) - - # Update user's group if group_ids parameter is provided. - if props.group_ids and o is not None: - from .group import association_groups_users, groups # noqa - # Clear previous groups associated with the user. - query = (association_groups_users - .delete() - .where(association_groups_users.c.user_id == o.uuid)) - await conn.execute(query) - # Add user to new groups. - query = (sa.select([groups.c.id]) - .select_from(groups) - .where(groups.c.domain_name == o.domain_name) - .where(groups.c.id.in_(props.group_ids))) - result = await conn.execute(query) - grps = result.fetchall() - if grps: - values = [{'user_id': o.uuid, 'group_id': grp.id} for grp in grps] - query = association_groups_users.insert().values(values) - await conn.execute(query) - return cls(ok=True, msg='success', user=o) - except sa.exc.IntegrityError as e: - return cls(ok=False, msg=f'integrity error: {e}', user=None) - except (asyncio.CancelledError, asyncio.TimeoutError): - raise - except Exception as e: - return cls(ok=False, msg=f'unexpected error: {e}', user=None) + .values(kp_data) + .where(keypairs.c.user == row.user)) + await graph_ctx.db_conn.execute(query) + cnt += 1 + + # If domain is changed and no group is associated, clear previous domain's group. + if prev_domain_name != o.domain_name and not props.group_ids: + from .group import association_groups_users, groups + query = (association_groups_users + .delete() + .where(association_groups_users.c.user_id == o.uuid)) + await graph_ctx.db_conn.execute(query) + + # Update user's group if group_ids parameter is provided. + if props.group_ids and o is not None: + from .group import association_groups_users, groups # noqa + # Clear previous groups associated with the user. + query = (association_groups_users + .delete() + .where(association_groups_users.c.user_id == o.uuid)) + await graph_ctx.db_conn.execute(query) + # Add user to new groups. + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == o.domain_name) + .where(groups.c.id.in_(props.group_ids))) + result = await graph_ctx.db_conn.execute(query) + grps = result.fetchall() + if grps: + values = [{'user_id': o.uuid, 'group_id': grp.id} for grp in grps] + query = association_groups_users.insert().values(values) + await graph_ctx.db_conn.execute(query) + return cls(ok=True, msg='success', user=o) + except sa.exc.IntegrityError as e: + return cls(ok=False, msg=f'integrity error: {e}', user=None) + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return cls(ok=False, msg=f'unexpected error: {e}', user=None) class DeleteUser(graphene.Mutation): @@ -690,34 +682,33 @@ async def mutate( email: str, ) -> DeleteUser: graph_ctx: GraphQueryContext = info.context - async with graph_ctx.dbpool.connect() as conn, conn.begin(): - try: - # Make all user keypairs inactive. 
-                from ai.backend.manager.models import keypairs
-                query = (
-                    keypairs.update()
-                    .values(is_active=False)
-                    .where(keypairs.c.user_id == email)
-                )
-                await conn.execute(query)
-                # Mark user as deleted.
-                query = (
-                    users.update()
-                    .values(status=UserStatus.DELETED,
-                            status_info='admin-requested')
-                    .where(users.c.email == email)
-                )
-                result = await conn.execute(query)
-                if result.rowcount > 0:
-                    return cls(ok=True, msg='success')
-                else:
-                    return cls(ok=False, msg='no such user')
-            except sa.exc.IntegrityError as e:
-                return cls(ok=False, msg=f'integrity error: {e}')
-            except (asyncio.CancelledError, asyncio.TimeoutError):
-                raise
-            except Exception as e:
-                return cls(ok=False, msg=f'unexpected error: {e}')
+        try:
+            # Make all user keypairs inactive.
+            from ai.backend.manager.models import keypairs
+            query = (
+                keypairs.update()
+                .values(is_active=False)
+                .where(keypairs.c.user_id == email)
+            )
+            await graph_ctx.db_conn.execute(query)
+            # Mark user as deleted.
+            query = (
+                users.update()
+                .values(status=UserStatus.DELETED,
+                        status_info='admin-requested')
+                .where(users.c.email == email)
+            )
+            result = await graph_ctx.db_conn.execute(query)
+            if result.rowcount > 0:
+                return cls(ok=True, msg='success')
+            else:
+                return cls(ok=False, msg='no such user')
+        except sa.exc.IntegrityError as e:
+            return cls(ok=False, msg=f'integrity error: {e}')
+        except (asyncio.CancelledError, asyncio.TimeoutError):
+            raise
+        except Exception as e:
+            return cls(ok=False, msg=f'unexpected error: {e}')


class PurgeUser(graphene.Mutation):

@@ -752,50 +743,49 @@ async def mutate(
         email: str,
         props: PurgeUserInput,
     ) -> PurgeUser:
-        ctx: GraphQueryContext = info.context
-        async with ctx.dbpool.connect() as conn, conn.begin():
-            try:
-                query = (
-                    sa.select([users.c.uuid])
-                    .select_from(users)
-                    .where(users.c.email == email)
+        graph_ctx: GraphQueryContext = info.context
+        try:
+            query = (
+                sa.select([users.c.uuid])
+                .select_from(users)
+                .where(users.c.email == email)
+            )
+            user_uuid = await graph_ctx.db_conn.scalar(query)
+            log.info('completely deleting user {0}...', email)
+
+            if await cls.user_vfolder_mounted_to_active_kernels(graph_ctx.db_conn, user_uuid):
+                raise RuntimeError('Some of user\'s virtual folders are mounted to active kernels. '
+                                   'Terminate those kernels first.')
+            if await cls.user_has_active_kernels(graph_ctx.db_conn, user_uuid):
+                raise RuntimeError('User has some active kernels. Terminate them first.')
+
+            if not props.purge_shared_vfolders:
+                await cls.migrate_shared_vfolders(
+                    graph_ctx.db_conn,
+                    deleted_user_uuid=user_uuid,
+                    target_user_uuid=graph_ctx.user['uuid'],
+                    target_user_email=graph_ctx.user['email'],
+                )
                )
-                user_uuid = await conn.scalar(query)
-                log.info('completly deleting user {0}...', email)
-
-                if await cls.user_vfolder_mounted_to_active_kernels(conn, user_uuid):
-                    raise RuntimeError('Some of user\'s virtual folders are mounted to active kernels. '
-                                       'Terminate those kernels first.')
-                if await cls.user_has_active_kernels(conn, user_uuid):
-                    raise RuntimeError('User has some active kernels. Terminate them first.')
-
-                if not props.purge_shared_vfolders:
-                    await cls.migrate_shared_vfolders(
-                        conn,
-                        deleted_user_uuid=user_uuid,
-                        target_user_uuid=ctx.user['uuid'],
-                        target_user_email=ctx.user['email'],
-                    )
-                await cls.delete_vfolders(conn, user_uuid, ctx.storage_manager)
-                await cls.delete_kernels(conn, user_uuid)
-                await cls.delete_keypairs(conn, user_uuid)
+            await cls.delete_vfolders(graph_ctx.db_conn, user_uuid, graph_ctx.storage_manager)
+            await cls.delete_kernels(graph_ctx.db_conn, user_uuid)
+            await cls.delete_keypairs(graph_ctx.db_conn, user_uuid)

-                query = (
-                    users.delete()
-                    .where(users.c.email == email)
-                )
-                result = await conn.execute(query)
-                if result.rowcount > 0:
-                    log.info('user is deleted: {0}', email)
-                    return cls(ok=True, msg='success')
-                else:
-                    return cls(ok=False, msg='no such user')
-            except sa.exc.IntegrityError as e:
-                return cls(ok=False, msg=f'integrity error: {e}')
-            except (asyncio.CancelledError, asyncio.TimeoutError):
-                raise
-            except Exception as e:
-                return cls(ok=False, msg=f'unexpected error: {e}')
+            query = (
+                users.delete()
+                .where(users.c.email == email)
+            )
+            result = await graph_ctx.db_conn.execute(query)
+            if result.rowcount > 0:
+                log.info('user is deleted: {0}', email)
+                return cls(ok=True, msg='success')
+            else:
+                return cls(ok=False, msg='no such user')
+        except sa.exc.IntegrityError as e:
+            return cls(ok=False, msg=f'integrity error: {e}')
+        except (asyncio.CancelledError, asyncio.TimeoutError):
+            raise
+        except Exception as e:
+            return cls(ok=False, msg=f'unexpected error: {e}')

    @classmethod
    async def migrate_shared_vfolders(
diff --git a/tests/conftest.py b/tests/conftest.py
index 43a260a0c..19968b762 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -38,7 +38,7 @@ from ai.backend.gateway.types import (
     CleanupContext,
 )
-from ai.backend.manager.models.base import populate_fixture
+from ai.backend.manager.models.base import populate_fixture, pgsql_connect_opts
 from ai.backend.manager.models import (
     scaling_groups,
     agents,
@@ -184,7 +184,11 @@ def database(request, local_config, test_db):
     db_url = f'postgresql+asyncpg://{db_user}@{db_addr}/testing'

     async def init_db():
-        engine = sa.ext.asyncio.create_async_engine(db_url, isolation_level="AUTOCOMMIT")
+        engine = sa.ext.asyncio.create_async_engine(
+            db_url,
+            connect_args=pgsql_connect_opts,
+            isolation_level="AUTOCOMMIT",
+        )
         async with engine.connect() as conn:
             await conn.execute(sa.text(f'CREATE DATABASE "{test_db}";'))
         await engine.dispose()
@@ -192,7 +196,11 @@ async def init_db():
     asyncio.run(init_db())

     async def finalize_db():
-        engine = sa.ext.asyncio.create_async_engine(db_url, isolation_level="AUTOCOMMIT")
+        engine = sa.ext.asyncio.create_async_engine(
+            db_url,
+            connect_args=pgsql_connect_opts,
+            isolation_level="AUTOCOMMIT",
+        )
         async with engine.connect() as conn:
             await conn.execute(sa.text(f'REVOKE CONNECT ON DATABASE "{test_db}" FROM public;'))
             await conn.execute(sa.text('SELECT pg_terminate_backend(pid) FROM pg_stat_activity '
@@ -244,9 +252,12 @@ def database_fixture(local_config, test_db, database):
     ))

     async def init_fixture():
-        engine: SAEngine = sa.ext.asyncio.create_async_engine(db_url)
+        engine: SAEngine = sa.ext.asyncio.create_async_engine(
+            db_url,
+            connect_args=pgsql_connect_opts,
+        )
         try:
-            await populate_fixture(engine, fixtures, ignore_unique_violation=True)
+            await populate_fixture(engine, fixtures)
         finally:
             await engine.dispose()

@@ -255,7 +266,10 @@ async def init_fixture():
     yield

     async def clean_fixture():
-        engine: SAEngine = sa.ext.asyncio.create_async_engine(db_url)
+        engine: SAEngine = sa.ext.asyncio.create_async_engine(
+            db_url,
+            connect_args=pgsql_connect_opts,
+        )
         try:
             async with engine.begin() as conn:
                 await conn.execute((vfolders.delete()))
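
For reference, `pgsql_connect_opts` is imported in the conftest.py hunks above from
`ai.backend.manager.models.base`; its definition is outside these hunks. A minimal
sketch of its expected shape follows. The key names and the DSN below are assumptions
inferred from how it is passed to `create_async_engine()` via `connect_args`, and the
real constant may carry additional server settings beyond `jit`:

    # Illustrative sketch only; `pgsql_connect_opts` is reconstructed here, not quoted.
    from sqlalchemy.ext.asyncio import create_async_engine

    # asyncpg forwards `server_settings` to PostgreSQL as per-session GUC overrides.
    # Turning the query JIT off avoids the multi-second stall that PostgreSQL 12+
    # can add to asyncpg's first type-introspection query on a fresh connection
    # (see MagicStack/asyncpg#530).
    pgsql_connect_opts = {
        'server_settings': {
            'jit': 'off',
        },
    }

    engine = create_async_engine(
        'postgresql+asyncpg://postgres@localhost/testing',  # placeholder DSN
        connect_args=pgsql_connect_opts,
    )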