Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Commit

Permalink
fix: excessive DB connection establishment delay and optimize GQL queries
Browse files Browse the repository at this point in the history

* refs MagicStack/asyncpg#530: apply "jit: off" option to DB connections
  - It is specified in `ai.backend.manager.models.base.pgsql_connect_opts`
* Reuse the same single connection for all GraphQL resolvers and mutation methods
  • Loading branch information
achimnol committed Mar 24, 2021
1 parent 8bfb9f7 commit 449d5d1
Show file tree
Hide file tree
Showing 10 changed files with 850 additions and 857 deletions.
62 changes: 32 additions & 30 deletions src/ai/backend/gateway/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,38 @@ async def handle_gql(request: web.Request, params: Any) -> web.Response:
app_ctx: PrivateContext = request.app['admin.context']
manager_status = await root_ctx.shared_config.get_manager_status()
known_slot_types = await root_ctx.shared_config.get_resource_slots()
gql_ctx = GraphQueryContext(
dataloader_manager=DataLoaderManager(),
local_config=root_ctx.local_config,
shared_config=root_ctx.shared_config,
etcd=root_ctx.shared_config.etcd,
user=request['user'],
access_key=request['keypair']['access_key'],
dbpool=root_ctx.dbpool,
redis_stat=root_ctx.redis_stat,
redis_image=root_ctx.redis_image,
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.background_task_manager,
storage_manager=root_ctx.storage_manager,
registry=root_ctx.registry,
)
result = app_ctx.gql_schema.execute(
params['query'],
app_ctx.gql_executor,
variable_values=params['variables'],
operation_name=params['operation_name'],
context_value=gql_ctx,
middleware=[
GQLLoggingMiddleware(),
GQLMutationUnfrozenRequiredMiddleware(),
GQLMutationPrivilegeCheckMiddleware(),
],
return_promise=True)
if inspect.isawaitable(result):
result = await result
async with root_ctx.dbpool.connect() as db_conn, db_conn.begin():
gql_ctx = GraphQueryContext(
dataloader_manager=DataLoaderManager(),
local_config=root_ctx.local_config,
shared_config=root_ctx.shared_config,
etcd=root_ctx.shared_config.etcd,
user=request['user'],
access_key=request['keypair']['access_key'],
dbpool=root_ctx.dbpool,
db_conn=db_conn,
redis_stat=root_ctx.redis_stat,
redis_image=root_ctx.redis_image,
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.background_task_manager,
storage_manager=root_ctx.storage_manager,
registry=root_ctx.registry,
)
result = app_ctx.gql_schema.execute(
params['query'],
app_ctx.gql_executor,
variable_values=params['variables'],
operation_name=params['operation_name'],
context_value=gql_ctx,
middleware=[
GQLLoggingMiddleware(),
GQLMutationUnfrozenRequiredMiddleware(),
GQLMutationPrivilegeCheckMiddleware(),
],
return_promise=True)
if inspect.isawaitable(result):
result = await result
if result.errors:
errors = []
for e in result.errors:
Expand Down
42 changes: 21 additions & 21 deletions src/ai/backend/gateway/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,27 +429,27 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
.where(keypairs.c.access_key == access_key)
)
await conn.execute(query)
request['is_authorized'] = True
request['keypair'] = {
col.name: row[f'keypairs_{col.name}']
for col in keypairs.c
if col.name != 'secret_key'
}
request['keypair']['resource_policy'] = {
col.name: row[f'keypair_resource_policies_{col.name}']
for col in keypair_resource_policies.c
}
request['user'] = {
col.name: row[f'users_{col.name}']
for col in users.c
if col.name not in ('password', 'description', 'created_at')
}
request['user']['id'] = row['keypairs_user_id'] # legacy
# if request['role'] in ['admin', 'superadmin']:
if row['keypairs_is_admin']:
request['is_admin'] = True
if request['user']['role'] == 'superadmin':
request['is_superadmin'] = True
request['is_authorized'] = True
request['keypair'] = {
col.name: row[f'keypairs_{col.name}']
for col in keypairs.c
if col.name != 'secret_key'
}
request['keypair']['resource_policy'] = {
col.name: row[f'keypair_resource_policies_{col.name}']
for col in keypair_resource_policies.c
}
request['user'] = {
col.name: row[f'users_{col.name}']
for col in users.c
if col.name not in ('password', 'description', 'created_at')
}
request['user']['id'] = row['keypairs_user_id'] # legacy
# if request['role'] in ['admin', 'superadmin']:
if row['keypairs_is_admin']:
request['is_admin'] = True
if request['user']['role'] == 'superadmin':
request['is_superadmin'] = True

# No matter if authenticated or not, pass-through to the handler.
# (if it's required, auth_required decorator will handle the situation.)
Expand Down
8 changes: 4 additions & 4 deletions src/ai/backend/gateway/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ..manager.background import BackgroundTaskManager
from ..manager.exceptions import InvalidArgument
from ..manager.idle import create_idle_checkers
from ..manager.models.base import pgsql_connect_opts
from ..manager.models.storage import StorageSessionManager
from ..manager.plugin.webapp import WebappPluginContext
from ..manager.registry import AgentRegistry
Expand Down Expand Up @@ -324,11 +325,10 @@ async def database_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
url = f"postgresql+asyncpg://{urlquote(username)}:{urlquote(password)}@{address}/{urlquote(dbname)}"
root_ctx.dbpool = create_async_engine(
url,
echo=False,
# echo=bool(root_ctx.local_config['logging']['level'] == 'DEBUG'),
echo=bool(root_ctx.local_config['logging']['level'] == 'DEBUG'),
connect_args=pgsql_connect_opts,
pool_size=8,
pool_recycle=120,
# timeout=60,
max_overflow=64,
json_serializer=functools.partial(json.dumps, cls=ExtendedJSONEncoder)
)
yield
Expand Down
133 changes: 65 additions & 68 deletions src/ai/backend/manager/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import (
Any,
Mapping,
Optional,
Sequence,
TYPE_CHECKING,
)
Expand Down Expand Up @@ -184,100 +183,98 @@ async def resolve_hardware_metadata(

@staticmethod
async def load_count(
    graph_ctx: GraphQueryContext, *,
    scaling_group: str = None,
    raw_status: str = None,
) -> int:
    """
    Return the number of agent rows, optionally filtered by scaling group
    and/or agent status.

    :param graph_ctx: the per-request GraphQL context; its ``db_conn`` is the
        shared connection opened once per GraphQL request.
    :param scaling_group: if given, count only agents in this scaling group.
    :param raw_status: if given, the name of an ``AgentStatus`` member to
        filter by (raises ``KeyError`` for unknown names).
    """
    query = (
        sa.select([sa.func.count(agents.c.id)])
        .select_from(agents)
    )
    if scaling_group is not None:
        query = query.where(agents.c.scaling_group == scaling_group)
    if raw_status is not None:
        status = AgentStatus[raw_status]
        query = query.where(agents.c.status == status)
    # Reuse the request-scoped connection instead of acquiring a new one
    # from the pool for every resolver call.
    result = await graph_ctx.db_conn.execute(query)
    return result.scalar()

@classmethod
async def load_slice(
    cls,
    graph_ctx: GraphQueryContext,
    limit: int, offset: int, *,
    scaling_group: str = None,
    raw_status: str = None,
    order_key: str = None,
    order_asc: bool = True,
) -> Sequence[Agent]:
    """
    Return a paginated slice of agents.

    :param graph_ctx: the per-request GraphQL context providing ``db_conn``.
    :param limit: maximum number of rows to return.
    :param offset: number of rows to skip before the slice starts.
    :param scaling_group: if given, restrict to agents in this scaling group.
    :param raw_status: if given, the name of an ``AgentStatus`` member to
        filter by.
    :param order_key: column name of ``agents`` to order by; defaults to the
        agent ID when omitted.
    :param order_asc: ascending order when true, descending otherwise.
    """
    # TODO: optimization for pagination using subquery, join
    if order_key is None:
        _ordering = agents.c.id
    else:
        _order_func = sa.asc if order_asc else sa.desc
        _ordering = _order_func(getattr(agents.c, order_key))
    query = (
        sa.select([agents])
        .select_from(agents)
        .order_by(_ordering)
        .limit(limit)
        .offset(offset)
    )
    if scaling_group is not None:
        query = query.where(agents.c.scaling_group == scaling_group)
    if raw_status is not None:
        status = AgentStatus[raw_status]
        query = query.where(agents.c.status == status)
    # Stream rows over the shared request-scoped connection.
    return [
        cls.from_row(graph_ctx, row)
        async for row in (await graph_ctx.db_conn.stream(query))
    ]

@classmethod
async def load_all(
    cls,
    graph_ctx: GraphQueryContext, *,
    scaling_group: str = None,
    raw_status: str = None,
) -> Sequence[Agent]:
    """
    Return all agents, optionally filtered by scaling group and/or status.

    :param graph_ctx: the per-request GraphQL context providing ``db_conn``.
    :param scaling_group: if given, restrict to agents in this scaling group.
    :param raw_status: if given, the name of an ``AgentStatus`` member to
        filter by.
    """
    query = (
        sa.select([agents])
        .select_from(agents)
    )
    if scaling_group is not None:
        query = query.where(agents.c.scaling_group == scaling_group)
    if raw_status is not None:
        status = AgentStatus[raw_status]
        query = query.where(agents.c.status == status)
    # Stream rows over the shared request-scoped connection instead of
    # opening a fresh pool connection per resolver.
    return [
        cls.from_row(graph_ctx, row)
        async for row in (await graph_ctx.db_conn.stream(query))
    ]

@classmethod
async def batch_load(
    cls,
    graph_ctx: GraphQueryContext,
    agent_ids: Sequence[AgentId], *,
    raw_status: str = None,
) -> Sequence[Agent | None]:
    """
    Batch-load agents by their IDs for the dataloader, preserving the order
    of *agent_ids*; missing IDs yield ``None`` in the result (handled by
    ``batch_result``).

    :param graph_ctx: the per-request GraphQL context providing ``db_conn``.
    :param agent_ids: the agent IDs to load.
    :param raw_status: if given, the name of an ``AgentStatus`` member to
        filter by.
    """
    query = (
        sa.select([agents])
        .select_from(agents)
        .where(agents.c.id.in_(agent_ids))
        .order_by(
            agents.c.id
        )
    )
    if raw_status is not None:
        status = AgentStatus[raw_status]
        query = query.where(agents.c.status == status)
    # batch_result aligns rows back to the requested key order using the
    # given key extractor, over the shared request-scoped connection.
    return await batch_result(
        graph_ctx, graph_ctx.db_conn, query, cls,
        agent_ids, lambda row: row['id'],
    )


class AgentList(graphene.ObjectType):
Expand Down
Loading

0 comments on commit 449d5d1

Please sign in to comment.