Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Commit

Permalink
feat: Upgrade SQLAlchemy v1.4 for native asyncio support (#406)
Browse files Browse the repository at this point in the history
* fix: a long-standing transaction error
  - It is now reproducible reliably with the new SQLAlchemy + asyncpg combination!
  - Also fixed a hidden type conversion bug in AgentRegistry.set_kernel_status() due to
    a comma typo....
* fix/test: Update codes for population of DB fixtures
  - Eliminate manual primary key checks for duplicate entries by utilizing
    PostgreSQL's "on conflict" (upsert) support.
  - Fix up using special characters in database passwords by correctly escaping them
    using `urllib.parse.quote_plus()`.
* fix: Do not rely on `rowcount` for SELECT queries
  - rowcount in SQLAlchemy does NOT represent the number of fetched
    rows for SELECT queries, in contrast to the Python's standard DB API.
* fix: excessive DB connection establishment delay and optimize GQL queries
  - refs MagicStack/asyncpg#530: apply "jit: off" option to DB connections
  - It is specified in `ai.backend.manager.models.base.pgsql_connect_opts`
  - Reuse the same single connection for all GraphQL resolvers and mutation methods
* fix: consistently use urlquote for db passwords
* fix: all DB connections are now transactions
* refactor: Finally, rename "dbpool" to "db"
  • Loading branch information
achimnol authored Mar 24, 2021
1 parent b929ba5 commit cd7479a
Show file tree
Hide file tree
Showing 52 changed files with 1,880 additions and 1,787 deletions.
1 change: 1 addition & 0 deletions changes/406.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Upgrade SQLAlchemy to v1.4 for native asyncio support and better transaction/concurrency handling
2 changes: 1 addition & 1 deletion config/ci.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ drivers = ["console"]
[logging.pkg-ns]
"" = "WARNING"
"aiotools" = "INFO"
"aiopg" = "WARNING"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"

[logging.console]
colored = true
Expand Down
2 changes: 1 addition & 1 deletion config/halfstack.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ drivers = ["console"]
[logging.pkg-ns]
"" = "WARNING"
"aiotools" = "INFO"
"aiopg" = "WARNING"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"

[logging.console]
colored = true
Expand Down
2 changes: 1 addition & 1 deletion config/sample.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ ssl-verify = true
[logging.pkg-ns]
"" = "WARNING"
"aiotools" = "INFO"
"aiopg" = "WARNING"
"aiohttp" = "INFO"
"ai.backend" = "INFO"
"alembic" = "INFO"
"sqlalchemy" = "WARNING"


[debug]
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ install_requires =
aiohttp_cors~=0.7
aiohttp_sse~=2.0
aiojobs~=0.3.0
aiopg~=1.1.0
aioredis~=1.3.1
aioredlock~=0.7.0
aiotools~=1.2.1
alembic~=1.4.3
async_timeout~=3.0
asyncache>=0.1.1
asyncpg>=0.22.0
attrs>=20.3
boltons~=20.2.1
callosum~=0.9.7
Expand All @@ -62,7 +62,7 @@ install_requires =
python-snappy~=0.6.0
PyYAML~=5.4.1
pyzmq~=22.0.3
SQLAlchemy~=1.3.20
SQLAlchemy~=1.4.2
uvloop~=0.15.1
setproctitle~=1.2.2
tabulate~=0.8.6
Expand Down
62 changes: 32 additions & 30 deletions src/ai/backend/gateway/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,38 @@ async def handle_gql(request: web.Request, params: Any) -> web.Response:
app_ctx: PrivateContext = request.app['admin.context']
manager_status = await root_ctx.shared_config.get_manager_status()
known_slot_types = await root_ctx.shared_config.get_resource_slots()
gql_ctx = GraphQueryContext(
dataloader_manager=DataLoaderManager(),
local_config=root_ctx.local_config,
shared_config=root_ctx.shared_config,
etcd=root_ctx.shared_config.etcd,
user=request['user'],
access_key=request['keypair']['access_key'],
dbpool=root_ctx.dbpool,
redis_stat=root_ctx.redis_stat,
redis_image=root_ctx.redis_image,
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.background_task_manager,
storage_manager=root_ctx.storage_manager,
registry=root_ctx.registry,
)
result = app_ctx.gql_schema.execute(
params['query'],
app_ctx.gql_executor,
variable_values=params['variables'],
operation_name=params['operation_name'],
context_value=gql_ctx,
middleware=[
GQLLoggingMiddleware(),
GQLMutationUnfrozenRequiredMiddleware(),
GQLMutationPrivilegeCheckMiddleware(),
],
return_promise=True)
if inspect.isawaitable(result):
result = await result
async with root_ctx.db.begin() as db_conn:
gql_ctx = GraphQueryContext(
dataloader_manager=DataLoaderManager(),
local_config=root_ctx.local_config,
shared_config=root_ctx.shared_config,
etcd=root_ctx.shared_config.etcd,
user=request['user'],
access_key=request['keypair']['access_key'],
db=root_ctx.db,
db_conn=db_conn,
redis_stat=root_ctx.redis_stat,
redis_image=root_ctx.redis_image,
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.background_task_manager,
storage_manager=root_ctx.storage_manager,
registry=root_ctx.registry,
)
result = app_ctx.gql_schema.execute(
params['query'],
app_ctx.gql_executor,
variable_values=params['variables'],
operation_name=params['operation_name'],
context_value=gql_ctx,
middleware=[
GQLLoggingMiddleware(),
GQLMutationUnfrozenRequiredMiddleware(),
GQLMutationPrivilegeCheckMiddleware(),
],
return_promise=True)
if inspect.isawaitable(result):
result = await result
if result.errors:
errors = []
for e in result.errors:
Expand Down
80 changes: 40 additions & 40 deletions src/ai/backend/gateway/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
params = _extract_auth_params(request)
if params:
sign_method, access_key, signature = params
async with root_ctx.dbpool.acquire() as conn, conn.begin():
async with root_ctx.db.begin() as conn:
j = (
keypairs
.join(users, keypairs.c.user == users.c.uuid)
Expand All @@ -414,7 +414,7 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
)
)
result = await conn.execute(query)
row = await result.first()
row = result.first()
if row is None:
raise AuthorizationFailed('Access key not found')
my_signature = \
Expand All @@ -429,27 +429,27 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
.where(keypairs.c.access_key == access_key)
)
await conn.execute(query)
request['is_authorized'] = True
request['keypair'] = {
col.name: row[f'keypairs_{col.name}']
for col in keypairs.c
if col.name != 'secret_key'
}
request['keypair']['resource_policy'] = {
col.name: row[f'keypair_resource_policies_{col.name}']
for col in keypair_resource_policies.c
}
request['user'] = {
col.name: row[f'users_{col.name}']
for col in users.c
if col.name not in ('password', 'description', 'created_at')
}
request['user']['id'] = row['keypairs_user_id'] # legacy
# if request['role'] in ['admin', 'superadmin']:
if row['keypairs_is_admin']:
request['is_admin'] = True
if request['user']['role'] == 'superadmin':
request['is_superadmin'] = True
request['is_authorized'] = True
request['keypair'] = {
col.name: row[f'keypairs_{col.name}']
for col in keypairs.c
if col.name != 'secret_key'
}
request['keypair']['resource_policy'] = {
col.name: row[f'keypair_resource_policies_{col.name}']
for col in keypair_resource_policies.c
}
request['user'] = {
col.name: row[f'users_{col.name}']
for col in users.c
if col.name not in ('password', 'description', 'created_at')
}
request['user']['id'] = row['keypairs_user_id'] # legacy
# if request['role'] in ['admin', 'superadmin']:
if row['keypairs_is_admin']:
request['is_admin'] = True
if request['user']['role'] == 'superadmin':
request['is_superadmin'] = True

# No matter if authenticated or not, pass-through to the handler.
# (if it's required, auth_required decorator will handle the situation.)
Expand Down Expand Up @@ -531,9 +531,9 @@ async def get_role(request: web.Request, params: Any) -> web.Response:
(association_groups_users.c.user_id == request['user']['uuid'])
)
)
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
result = await conn.execute(query)
row = await result.first()
row = result.first()
if row is None:
raise GenericNotFound('No such user group or '
'you are not the member of the group.')
Expand Down Expand Up @@ -563,12 +563,12 @@ async def authorize(request: web.Request, params: Any) -> web.Response:

# [Hooking point for AUTHORIZE with the FIRST_COMPLETED requirement]
# The hook handlers should accept the whole ``params`` dict, and optional
# ``dbpool`` parameter (if the hook needs to query to database).
# ``db`` parameter (if the hook needs to query to database).
# They should return a corresponding Backend.AI user object after performing
# their own authentication steps, like LDAP authentication, etc.
hook_result = await root_ctx.hook_plugin_ctx.dispatch(
'AUTHORIZE',
(params, root_ctx.dbpool),
(params, root_ctx.db),
return_when=FIRST_COMPLETED,
)
if hook_result.status != PASSED:
Expand All @@ -579,7 +579,7 @@ async def authorize(request: web.Request, params: Any) -> web.Response:
else:
# No AUTHORIZE hook is defined (proceed with normal login)
user = await check_credential(
root_ctx.dbpool,
root_ctx.db,
params['domain'], params['username'], params['password']
)
if user is None:
Expand All @@ -588,7 +588,7 @@ async def authorize(request: web.Request, params: Any) -> web.Response:
raise AuthorizationFailed('This account needs email verification.')
if user.get('status') in INACTIVE_USER_STATUSES:
raise AuthorizationFailed('User credential mismatch.')
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
query = (sa.select([keypairs.c.access_key, keypairs.c.secret_key])
.select_from(keypairs)
.where(
Expand All @@ -597,7 +597,7 @@ async def authorize(request: web.Request, params: Any) -> web.Response:
)
.order_by(sa.desc(keypairs.c.is_admin)))
result = await conn.execute(query)
keypair = await result.first()
keypair = result.first()
if keypair is None:
raise AuthorizationFailed('No API keypairs found.')
return web.json_response({
Expand Down Expand Up @@ -640,13 +640,13 @@ async def signup(request: web.Request, params: Any) -> web.Response:
# Merge the hook results as a single map.
user_data_overriden = ChainMap(*cast(Mapping, hook_result.result))

async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
# Check if email already exists.
query = (sa.select([users])
.select_from(users)
.where((users.c.email == params['email'])))
result = await conn.execute(query)
row = await result.first()
row = result.first()
if row is not None:
raise GenericBadRequest('Email already exists')

Expand All @@ -673,7 +673,7 @@ async def signup(request: web.Request, params: Any) -> web.Response:
if result.rowcount > 0:
checkq = users.select().where(users.c.email == params['email'])
result = await conn.execute(checkq)
user = await result.first()
user = result.first()
# Create user's first access_key and secret_key.
ak, sk = _gen_keypair()
resource_policy = (
Expand Down Expand Up @@ -701,7 +701,7 @@ async def signup(request: web.Request, params: Any) -> web.Response:
.where(groups.c.domain_name == params['domain'])
.where(groups.c.name == group_name))
result = await conn.execute(query)
grp = await result.first()
grp = result.first()
if grp is not None:
values = [{'user_id': user.uuid, 'group_id': grp.id}]
query = association_groups_users.insert().values(values)
Expand Down Expand Up @@ -741,11 +741,11 @@ async def signout(request: web.Request, params: Any) -> web.Response:
if request['user']['email'] != params['email']:
raise GenericForbidden('Not the account owner')
result = await check_credential(
root_ctx.dbpool,
root_ctx.db,
domain_name, params['email'], params['password'])
if result is None:
raise GenericBadRequest('Invalid email and/or password')
async with root_ctx.dbpool.acquire() as conn, conn.begin():
async with root_ctx.db.begin() as conn:
# Inactivate the user.
query = (
users.update()
Expand Down Expand Up @@ -779,7 +779,7 @@ async def update_password(request: web.Request, params: Any) -> web.Response:
log_args = (domain_name, email)
log.info(log_fmt, *log_args)

user = await check_credential(root_ctx.dbpool, domain_name, email, params['old_password'])
user = await check_credential(root_ctx.db, domain_name, email, params['old_password'])
if user is None:
log.info(log_fmt + ': old password mismtach', *log_args)
raise AuthorizationFailed('Old password mismatch')
Expand All @@ -801,7 +801,7 @@ async def update_password(request: web.Request, params: Any) -> web.Response:
hook_result.reason = hook_result.reason or 'invalid password format'
raise RejectedByHook.from_hook_result(hook_result)

async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
# Update user password.
data = {
'password': params['new_password'],
Expand All @@ -821,7 +821,7 @@ async def get_ssh_keypair(request: web.Request) -> web.Response:
log_fmt = 'AUTH.GET_SSH_KEYPAIR(d:{}, ak:{})'
log_args = (domain_name, access_key)
log.info(log_fmt, *log_args)
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
# Get SSH public key. Return partial string from the public key just for checking.
query = (
sa.select([keypairs.c.ssh_public_key])
Expand All @@ -840,7 +840,7 @@ async def refresh_ssh_keypair(request: web.Request) -> web.Response:
log_args = (domain_name, access_key)
log.info(log_fmt, *log_args)
root_ctx: RootContext = request.app['_root.context']
async with root_ctx.dbpool.acquire() as conn:
async with root_ctx.db.begin() as conn:
pubkey, privkey = generate_ssh_keypair()
data = {
'ssh_public_key': pubkey,
Expand Down
Loading

0 comments on commit cd7479a

Please sign in to comment.