Skip to content

Commit

Permalink
add news fragment
Browse files Browse the repository at this point in the history
  • Loading branch information
kyujin-cho committed Dec 20, 2024
1 parent d044f4b commit 9bc0661
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 45 deletions.
1 change: 1 addition & 0 deletions changes/3277.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support model service auto scaling
28 changes: 28 additions & 0 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MutableMapping,
Sequence,
)
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -1656,3 +1657,30 @@ def generate_sql_info_for_gql_connection(
"Set 'first' or 'last' to a smaller integer."
)
return ret


# NOTE(review): inheriting from Decimal alongside TypeDecorator looks unintentional;
# the base list is kept as-is for compatibility — confirm before removing it.
class DecimalType(TypeDecorator, Decimal):
    """
    Database type adaptor for Decimal.

    Values are persisted through a VARCHAR column as their ``str()``
    representation and converted back to ``Decimal`` when rows are loaded.
    """

    impl = sa.VARCHAR
    cache_ok = True

    def process_bind_param(
        self,
        value: Optional[Decimal],
        dialect: Dialect,
    ) -> Optional[str]:
        """
        Convert a Decimal into the string stored in the database.

        Only ``None`` is mapped to NULL.  A plain truthiness test must not be
        used here because ``Decimal(0)`` is falsy and a legitimate zero would
        be silently written as NULL.
        """
        if value is None:
            return None
        return str(value)

    def process_result_value(
        self,
        value: Optional[str],
        dialect: Dialect,
    ) -> Optional[Decimal]:
        """
        Convert the stored string back into a Decimal.

        NULL and the empty string both yield ``None``; any other string is
        passed to the ``Decimal`` constructor.
        """
        if value is None or value == "":
            return None
        return Decimal(value)

    @property
    def python_type(self) -> type[Decimal]:
        """Return the Python type this column type produces."""
        return Decimal
24 changes: 18 additions & 6 deletions src/ai/backend/manager/models/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .base import (
GUID,
Base,
DecimalType,
EndpointIDColumn,
EnumValueType,
ForeignKeyIDColumn,
Expand Down Expand Up @@ -119,8 +120,8 @@ class AutoScalingMetricSource(StrEnum):
class AutoScalingMetricComparator(StrEnum):
LESS_THAN = "lt"
LESS_THAN_OR_EQUAL = "le"
GREATHER_THAN = "gt"
GREATHER_THAN_OR_EQUAL = "ge"
GREATER_THAN = "gt"
GREATER_THAN_OR_EQUAL = "ge"


class EndpointRow(Base):
Expand Down Expand Up @@ -584,9 +585,7 @@ class EndpointAutoScalingRuleRow(Base):
id = IDColumn()
metric_source = sa.Column("metric_source", StrEnumType(AutoScalingMetricSource), nullable=False)
metric_name = sa.Column("metric_name", sa.Text(), nullable=False)
threshold = sa.Column(
"threshold", sa.Text(), nullable=False
) # FIXME: How can I put Decimal here?
threshold = sa.Column("threshold", DecimalType(), nullable=False)
comparator = sa.Column("comparator", StrEnumType(AutoScalingMetricComparator), nullable=False)
step_size = sa.Column("step_size", sa.Integer(), nullable=False)
cooldown_seconds = sa.Column("cooldown_seconds", sa.Integer(), nullable=False, default=300)
Expand Down Expand Up @@ -618,10 +617,23 @@ async def list(
) -> list["EndpointAutoScalingRuleRow"]:
query = sa.select(EndpointAutoScalingRuleRow)
if load_endpoint:
query = query.options(selectinload(EndpointAutoScalingRuleRow.tokens))
query = query.options(selectinload(EndpointAutoScalingRuleRow.endpoint_row))
result = await session.execute(query)
return result.scalars().all()

@classmethod
async def get(
    cls, session: AsyncSession, id: uuid.UUID, load_endpoint=False
) -> "EndpointAutoScalingRuleRow":
    """
    Fetch a single auto scaling rule row by its ID.

    When ``load_endpoint`` is true, the ``endpoint_row`` relationship is
    loaded with ``selectinload``.

    Raises ``ObjectNotFound`` when the query yields no row.
    """
    stmt = sa.select(EndpointAutoScalingRuleRow).filter(EndpointAutoScalingRuleRow.id == id)
    if load_endpoint:
        stmt = stmt.options(selectinload(EndpointAutoScalingRuleRow.endpoint_row))
    rule_row = (await session.execute(stmt)).scalar()
    if rule_row:
        return rule_row
    raise ObjectNotFound("endpoint_auto_scaling_rule")

def __init__(
self,
id: uuid.UUID,
Expand Down
6 changes: 5 additions & 1 deletion src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,9 @@ class Queries(graphene.ObjectType):
)

endpoint_auto_scaling_rule_nodes = PaginatedConnectionField(
EndpointAutoScalingRuleConnection, description="Added in 24.12.0."
EndpointAutoScalingRuleConnection,
endpoint=graphene.String(required=True),
description="Added in 24.12.0.",
)

@staticmethod
Expand Down Expand Up @@ -2657,6 +2659,7 @@ async def resolve_endpoint_auto_scaling_rule_node(
async def resolve_endpoint_auto_scaling_rule_nodes(
root: Any,
info: graphene.ResolveInfo,
endpoint: str,
*,
filter: str | None = None,
order: str | None = None,
Expand All @@ -2668,6 +2671,7 @@ async def resolve_endpoint_auto_scaling_rule_nodes(
) -> ConnectionResolverResult:
return await EndpointAutoScalingRuleNode.get_connection(
info,
endpoint,
filter_expr=filter,
order_expr=order,
offset=offset,
Expand Down
91 changes: 72 additions & 19 deletions src/ai/backend/manager/models/gql_models/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Mapping, Self

import graphene
import sqlalchemy as sa
from dateutil.parser import parse as dtparse
from graphene.types.datetime import DateTime as GQLDateTime
from graphql import Undefined
Expand Down Expand Up @@ -85,7 +84,7 @@ class Meta:

@classmethod
def from_row(
cls, graph_ctx: GraphQueryContext, row: EndpointAutoScalingRuleRow
cls, graph_ctx: "GraphQueryContext", row: EndpointAutoScalingRuleRow
) -> "EndpointAutoScalingRuleNode":
return EndpointAutoScalingRuleNode(
id=row.id,
Expand All @@ -102,21 +101,40 @@ def from_row(
)

@classmethod
async def get_node(cls, info: graphene.ResolveInfo, id: str) -> "EndpointAutoScalingRuleNode":
async def get_node(
cls, info: graphene.ResolveInfo, rule_id: str
) -> "EndpointAutoScalingRuleNode":
graph_ctx: GraphQueryContext = info.context

_, rule = AsyncNode.resolve_global_id(info, id)
query = sa.select(EndpointAutoScalingRuleRow).where(EndpointAutoScalingRuleRow.id == rule)
_, raw_rule_id = AsyncNode.resolve_global_id(info, rule_id)
if not raw_rule_id:
raw_rule_id = rule_id
try:
_rule_id = uuid.UUID(raw_rule_id)
except ValueError:
raise ObjectNotFound("endpoint_auto_scaling_rule")

async with graph_ctx.db.begin_readonly_session() as db_session:
rule_row = await db_session.scalar(query)
if rule_row is None:
raise ValueError(f"Rule not found (id: {rule})")
rule_row = await EndpointAutoScalingRuleRow.get(
db_session, _rule_id, load_endpoint=True
)
match graph_ctx.user["role"]:
case UserRole.SUPERADMIN:
pass
case UserRole.ADMIN:
if rule_row.endpoint_row.domain != graph_ctx.user["domain_name"]:
raise GenericForbidden
case UserRole.USER:
if rule_row.endpoint_row.created_user != graph_ctx.user["uuid"]:
raise GenericForbidden

return cls.from_row(graph_ctx, rule_row)

@classmethod
async def get_connection(
cls,
info: graphene.ResolveInfo,
endpoint: str,
*,
filter_expr: str | None = None,
order_expr: str | None = None,
Expand Down Expand Up @@ -156,7 +174,31 @@ async def get_connection(
before=before,
last=last,
)

async with graph_ctx.db.begin_readonly_session() as db_session:
_, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint)
if not raw_endpoint_id:
raw_endpoint_id = endpoint
try:
_endpoint_id = uuid.UUID(raw_endpoint_id)
except ValueError:
raise ObjectNotFound("endpoint")
try:
row = await EndpointRow.get(db_session, _endpoint_id)
except NoResultFound:
raise ObjectNotFound(object_name="endpoint")

match graph_ctx.user["role"]:
case UserRole.SUPERADMIN:
pass
case UserRole.ADMIN:
if row.endpoint_row.domain != graph_ctx.user["domain_name"]:
raise GenericForbidden
case UserRole.USER:
if row.endpoint_row.created_user != graph_ctx.user["uuid"]:
raise GenericForbidden

query = query.filter(EndpointAutoScalingRuleRow.endpoint == _endpoint_id)
group_rows = (await db_session.scalars(query)).all()
result = [cls.from_row(graph_ctx, row) for row in group_rows]
total_cnt = await db_session.scalar(cnt_query)
Expand Down Expand Up @@ -209,20 +251,24 @@ class CreateEndpointAutoScalingRuleNode(graphene.Mutation):
allowed_roles = (UserRole.USER, UserRole.ADMIN, UserRole.SUPERADMIN)

class Arguments:
endpoint_id = graphene.String(required=True)
endpoint = graphene.String(required=True)
props = EndpointAutoScalingRuleInput(required=True)

ok = graphene.Boolean()
msg = graphene.String()
rule = graphene.Field(lambda: EndpointAutoScalingRuleNode, required=False)

@classmethod
async def mutate(
cls,
root,
info: graphene.ResolveInfo,
endpoint_id: str,
endpoint: str,
props: EndpointAutoScalingRuleInput,
) -> "CreateEndpointAutoScalingRuleNode":
_, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint_id)
_, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint)
if not raw_endpoint_id:
raw_endpoint_id = endpoint_id
raw_endpoint_id = endpoint

try:
_endpoint_id = uuid.UUID(raw_endpoint_id)
Expand All @@ -248,13 +294,13 @@ async def mutate(

try:
_source = AutoScalingMetricSource[props.metric_source]
except ValueError:
except (KeyError, ValueError):
raise InvalidAPIParameters(
f"Unsupported AutoScalingMetricSource {props.metric_source}"
)
try:
_comparator = AutoScalingMetricComparator[props.comparator]
except ValueError:
except (KeyError, ValueError):
raise InvalidAPIParameters(
f"Unsupported AutoScalingMetricComparator {props.comparator}"
)
Expand All @@ -267,7 +313,7 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode:
created_rule = await row.create_auto_scaling_rule(
db_session,
_source,
props.name,
props.metric_name,
_threshold,
_comparator,
props.step_size,
Expand All @@ -276,7 +322,7 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode:
return CreateEndpointAutoScalingRuleNode(
ok=True,
msg="Auto scaling rule created",
network=EndpointAutoScalingRuleNode.from_row(info.context, created_rule),
rule=EndpointAutoScalingRuleNode.from_row(info.context, created_rule),
)

return await gql_mutation_wrapper(CreateEndpointAutoScalingRuleNode, _do_mutate)
Expand All @@ -289,6 +335,10 @@ class Arguments:
id = graphene.String(required=True)
props = ModifyEndpointAutoScalingRuleInput(required=True)

ok = graphene.Boolean()
msg = graphene.String()
rule = graphene.Field(lambda: EndpointAutoScalingRuleNode, required=False)

@classmethod
async def mutate(
cls,
Expand Down Expand Up @@ -327,12 +377,12 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode:
if (_newval := props.metric_source) and _newval is not Undefined:
try:
row.metric_source = AutoScalingMetricSource[_newval]
except ValueError:
except (KeyError, ValueError):
raise InvalidAPIParameters(f"Unsupported AutoScalingMetricSource {_newval}")
if (_newval := props.comparator) and _newval is not Undefined:
try:
row.comparator = AutoScalingMetricComparator[_newval]
except ValueError:
except (KeyError, ValueError):
raise InvalidAPIParameters(
f"Unsupported AutoScalingMetricComparator {_newval}"
)
Expand All @@ -349,7 +399,7 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode:
return ModifyEndpointAutoScalingRuleNode(
ok=True,
msg="Auto scaling rule updated",
network=EndpointAutoScalingRuleNode.from_row(info.context, row),
rule=EndpointAutoScalingRuleNode.from_row(info.context, row),
)

return await gql_mutation_wrapper(ModifyEndpointAutoScalingRuleNode, _do_mutate)
Expand All @@ -361,6 +411,9 @@ class DeleteEndpointAutoScalingRuleNode(graphene.Mutation):
class Arguments:
id = graphene.String(required=True)

ok = graphene.Boolean()
msg = graphene.String()

@classmethod
async def mutate(
cls,
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def get_used_days(self, local_tz: tzfile) -> Optional[int]:
async def bulk_load_by_session_id(
session: SASession, session_ids: list[uuid.UUID]
) -> list["KernelRow"]:
query = sa.select(KernelRow).where(KernelRow.session.in_(session_ids))
query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
return (await session.execute(query)).scalars().all()

@staticmethod
Expand Down
Loading

0 comments on commit 9bc0661

Please sign in to comment.