diff --git a/changes/3277.feature.md b/changes/3277.feature.md new file mode 100644 index 0000000000..c3e45b0911 --- /dev/null +++ b/changes/3277.feature.md @@ -0,0 +1 @@ +Support model service auto scaling diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py index b9e8100887..36730e00eb 100644 --- a/src/ai/backend/manager/models/base.py +++ b/src/ai/backend/manager/models/base.py @@ -14,6 +14,7 @@ MutableMapping, Sequence, ) +from decimal import Decimal from typing import ( TYPE_CHECKING, Any, @@ -1656,3 +1657,30 @@ def generate_sql_info_for_gql_connection( "Set 'first' or 'last' to a smaller integer." ) return ret + + +class DecimalType(TypeDecorator, Decimal): + """
    Database type adaptor for Decimal
    """ + + impl = sa.VARCHAR + cache_ok = True + + def process_bind_param( + self, + value: Optional[Decimal], + dialect: Dialect, + ) -> Optional[str]: + return str(value) if value is not None else None + + def process_result_value( + self, + value: Optional[str], + dialect: Dialect, + ) -> Optional[Decimal]: + return Decimal(value) if value else None + + @property + def python_type(self) -> type[Decimal]: + return Decimal diff --git a/src/ai/backend/manager/models/endpoint.py b/src/ai/backend/manager/models/endpoint.py index 89075ab17a..c9b4154b57 100644 --- a/src/ai/backend/manager/models/endpoint.py +++ b/src/ai/backend/manager/models/endpoint.py @@ -54,6 +54,7 @@ from .base import ( GUID, Base, + DecimalType, EndpointIDColumn, EnumValueType, ForeignKeyIDColumn, @@ -119,8 +120,8 @@ class AutoScalingMetricSource(StrEnum): class AutoScalingMetricComparator(StrEnum): LESS_THAN = "lt" LESS_THAN_OR_EQUAL = "le" - GREATHER_THAN = "gt" - GREATHER_THAN_OR_EQUAL = "ge" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL = "ge" class EndpointRow(Base): @@ -584,9 +585,7 @@ class EndpointAutoScalingRuleRow(Base): id = IDColumn() metric_source = sa.Column("metric_source", StrEnumType(AutoScalingMetricSource), nullable=False) metric_name = sa.Column("metric_name", 
sa.Text(), nullable=False) - threshold = sa.Column( - "threshold", sa.Text(), nullable=False - ) # FIXME: How can I put Decimal here? + threshold = sa.Column("threshold", DecimalType(), nullable=False) comparator = sa.Column("comparator", StrEnumType(AutoScalingMetricComparator), nullable=False) step_size = sa.Column("step_size", sa.Integer(), nullable=False) cooldown_seconds = sa.Column("cooldown_seconds", sa.Integer(), nullable=False, default=300) @@ -618,10 +617,23 @@ async def list( ) -> list["EndpointAutoScalingRuleRow"]: query = sa.select(EndpointAutoScalingRuleRow) if load_endpoint: - query = query.options(selectinload(EndpointAutoScalingRuleRow.tokens)) + query = query.options(selectinload(EndpointAutoScalingRuleRow.endpoint_row)) result = await session.execute(query) return result.scalars().all() + @classmethod + async def get( + cls, session: AsyncSession, id: uuid.UUID, load_endpoint=False + ) -> "EndpointAutoScalingRuleRow": + query = sa.select(EndpointAutoScalingRuleRow).filter(EndpointAutoScalingRuleRow.id == id) + if load_endpoint: + query = query.options(selectinload(EndpointAutoScalingRuleRow.endpoint_row)) + result = await session.execute(query) + row = result.scalar() + if not row: + raise ObjectNotFound("endpoint_auto_scaling_rule") + return row + def __init__( self, id: uuid.UUID, diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 7027e74530..2e5f5c2447 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -923,7 +923,9 @@ class Queries(graphene.ObjectType): ) endpoint_auto_scaling_rule_nodes = PaginatedConnectionField( - EndpointAutoScalingRuleConnection, description="Added in 24.12.0." 
+ EndpointAutoScalingRuleConnection, + endpoint=graphene.String(required=True), + description="Added in 24.12.0.", ) @staticmethod @@ -2657,6 +2659,7 @@ async def resolve_endpoint_auto_scaling_rule_node( async def resolve_endpoint_auto_scaling_rule_nodes( root: Any, info: graphene.ResolveInfo, + endpoint: str, *, filter: str | None = None, order: str | None = None, @@ -2668,6 +2671,7 @@ async def resolve_endpoint_auto_scaling_rule_nodes( ) -> ConnectionResolverResult: return await EndpointAutoScalingRuleNode.get_connection( info, + endpoint, filter_expr=filter, order_expr=order, offset=offset, diff --git a/src/ai/backend/manager/models/gql_models/endpoint.py b/src/ai/backend/manager/models/gql_models/endpoint.py index 51e0eeade6..8ccebc2af8 100644 --- a/src/ai/backend/manager/models/gql_models/endpoint.py +++ b/src/ai/backend/manager/models/gql_models/endpoint.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Mapping, Self import graphene -import sqlalchemy as sa from dateutil.parser import parse as dtparse from graphene.types.datetime import DateTime as GQLDateTime from graphql import Undefined @@ -85,7 +84,7 @@ class Meta: @classmethod def from_row( - cls, graph_ctx: GraphQueryContext, row: EndpointAutoScalingRuleRow + cls, graph_ctx: "GraphQueryContext", row: EndpointAutoScalingRuleRow ) -> "EndpointAutoScalingRuleNode": return EndpointAutoScalingRuleNode( id=row.id, @@ -102,21 +101,40 @@ def from_row( ) @classmethod - async def get_node(cls, info: graphene.ResolveInfo, id: str) -> "EndpointAutoScalingRuleNode": + async def get_node( + cls, info: graphene.ResolveInfo, rule_id: str + ) -> "EndpointAutoScalingRuleNode": graph_ctx: GraphQueryContext = info.context - _, rule = AsyncNode.resolve_global_id(info, id) - query = sa.select(EndpointAutoScalingRuleRow).where(EndpointAutoScalingRuleRow.id == rule) + _, raw_rule_id = AsyncNode.resolve_global_id(info, rule_id) + if not raw_rule_id: + raw_rule_id = rule_id + try: + _rule_id = uuid.UUID(raw_rule_id) + except 
ValueError: + raise ObjectNotFound("endpoint_auto_scaling_rule") + async with graph_ctx.db.begin_readonly_session() as db_session: - rule_row = await db_session.scalar(query) - if rule_row is None: - raise ValueError(f"Rule not found (id: {rule})") + rule_row = await EndpointAutoScalingRuleRow.get( + db_session, _rule_id, load_endpoint=True + ) + match graph_ctx.user["role"]: + case UserRole.SUPERADMIN: + pass + case UserRole.ADMIN: + if rule_row.endpoint_row.domain != graph_ctx.user["domain_name"]: + raise GenericForbidden + case UserRole.USER: + if rule_row.endpoint_row.created_user != graph_ctx.user["uuid"]: + raise GenericForbidden + return cls.from_row(graph_ctx, rule_row) @classmethod async def get_connection( cls, info: graphene.ResolveInfo, + endpoint: str, *, filter_expr: str | None = None, order_expr: str | None = None, @@ -156,7 +174,31 @@ async def get_connection( before=before, last=last, ) + async with graph_ctx.db.begin_readonly_session() as db_session: + _, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint) + if not raw_endpoint_id: + raw_endpoint_id = endpoint + try: + _endpoint_id = uuid.UUID(raw_endpoint_id) + except ValueError: + raise ObjectNotFound("endpoint") + try: + row = await EndpointRow.get(db_session, _endpoint_id) + except NoResultFound: + raise ObjectNotFound(object_name="endpoint") + + match graph_ctx.user["role"]: + case UserRole.SUPERADMIN: + pass + case UserRole.ADMIN: + if row.domain != graph_ctx.user["domain_name"]: + raise GenericForbidden + case UserRole.USER: + if row.created_user != graph_ctx.user["uuid"]: + raise GenericForbidden + + query = query.filter(EndpointAutoScalingRuleRow.endpoint == _endpoint_id) group_rows = (await db_session.scalars(query)).all() result = [cls.from_row(graph_ctx, row) for row in group_rows] total_cnt = await db_session.scalar(cnt_query) @@ -209,20 +251,24 @@ class CreateEndpointAutoScalingRuleNode(graphene.Mutation): allowed_roles = (UserRole.USER, 
UserRole.ADMIN, UserRole.SUPERADMIN) class Arguments: - endpoint_id = graphene.String(required=True) + endpoint = graphene.String(required=True) props = EndpointAutoScalingRuleInput(required=True) + ok = graphene.Boolean() + msg = graphene.String() + rule = graphene.Field(lambda: EndpointAutoScalingRuleNode, required=False) + @classmethod async def mutate( cls, root, info: graphene.ResolveInfo, - endpoint_id: str, + endpoint: str, props: EndpointAutoScalingRuleInput, ) -> "CreateEndpointAutoScalingRuleNode": - _, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint_id) + _, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint) if not raw_endpoint_id: - raw_endpoint_id = endpoint_id + raw_endpoint_id = endpoint try: _endpoint_id = uuid.UUID(raw_endpoint_id) @@ -248,13 +294,13 @@ async def mutate( try: _source = AutoScalingMetricSource[props.metric_source] - except ValueError: + except (KeyError, ValueError): raise InvalidAPIParameters( f"Unsupported AutoScalingMetricSource {props.metric_source}" ) try: _comparator = AutoScalingMetricComparator[props.comparator] - except ValueError: + except (KeyError, ValueError): raise InvalidAPIParameters( f"Unsupported AutoScalingMetricComparator {props.comparator}" ) @@ -267,7 +313,7 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode: created_rule = await row.create_auto_scaling_rule( db_session, _source, - props.name, + props.metric_name, _threshold, _comparator, props.step_size, @@ -276,7 +322,7 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode: return CreateEndpointAutoScalingRuleNode( ok=True, msg="Auto scaling rule created", - network=EndpointAutoScalingRuleNode.from_row(info.context, created_rule), + rule=EndpointAutoScalingRuleNode.from_row(info.context, created_rule), ) return await gql_mutation_wrapper(CreateEndpointAutoScalingRuleNode, _do_mutate) @@ -289,6 +335,10 @@ class Arguments: id = graphene.String(required=True) props = ModifyEndpointAutoScalingRuleInput(required=True) + 
ok = graphene.Boolean() + msg = graphene.String() + rule = graphene.Field(lambda: EndpointAutoScalingRuleNode, required=False) + @classmethod async def mutate( cls, @@ -327,12 +377,12 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode: if (_newval := props.metric_source) and _newval is not Undefined: try: row.metric_source = AutoScalingMetricSource[_newval] - except ValueError: + except (KeyError, ValueError): raise InvalidAPIParameters(f"Unsupported AutoScalingMetricSource {_newval}") if (_newval := props.comparator) and _newval is not Undefined: try: row.comparator = AutoScalingMetricComparator[_newval] - except ValueError: + except (KeyError, ValueError): raise InvalidAPIParameters( f"Unsupported AutoScalingMetricComparator {_newval}" ) @@ -349,7 +399,7 @@ async def _do_mutate() -> CreateEndpointAutoScalingRuleNode: return ModifyEndpointAutoScalingRuleNode( ok=True, msg="Auto scaling rule updated", - network=EndpointAutoScalingRuleNode.from_row(info.context, row), + rule=EndpointAutoScalingRuleNode.from_row(info.context, row), ) return await gql_mutation_wrapper(ModifyEndpointAutoScalingRuleNode, _do_mutate) @@ -361,6 +411,9 @@ class DeleteEndpointAutoScalingRuleNode(graphene.Mutation): class Arguments: id = graphene.String(required=True) + ok = graphene.Boolean() + msg = graphene.String() + @classmethod async def mutate( cls, diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index 1565a4f9e6..3e34be10bc 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -591,7 +591,7 @@ def get_used_days(self, local_tz: tzfile) -> Optional[int]: async def bulk_load_by_session_id( session: SASession, session_ids: list[uuid.UUID] ) -> list["KernelRow"]: - query = sa.select(KernelRow).where(KernelRow.session.in_(session_ids)) + query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids)) return (await session.execute(query)).scalars().all() @staticmethod diff 
--git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 61891bea3f..37c9fa296e 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -1442,6 +1442,13 @@ async def _autoscale_endpoints( for rule in rules: should_trigger = False + if len(endpoint_by_id[rule.endpoint].routings) == 0: + log.debug( + "_autoscale_endpoints(e: {}, r: {}): endpoint does not have any replicas, skipping", + rule.endpoint, + rule.id, + ) + continue match rule.metric_source: # kernel metrics should be evaluated by the average of the metric across every kernels @@ -1452,7 +1459,7 @@ async def _autoscale_endpoints( for kernel in kernels_by_session_id[route.session]: if not kernel_statistics_by_id[kernel.id]: continue - live_stat = json.loads(kernel_statistics_by_id[kernel.id]) + live_stat = kernel_statistics_by_id[kernel.id] if rule.metric_name not in live_stat: continue metric_found_kernel_count += 1 @@ -1465,30 +1472,62 @@ async def _autoscale_endpoints( case AutoScalingMetricSource.INFERENCE_FRAMEWORK: if not endpoint_statistics_by_id[rule.endpoint]: continue - live_stat = json.loads(endpoint_statistics_by_id[rule.endpoint]) + live_stat = endpoint_statistics_by_id[rule.endpoint] if rule.metric_name not in live_stat: + log.debug( + "_autoscale_endpoints(e: {}, r: {}): metric {} does not exist, skipping", + rule.endpoint, + rule.id, + rule.metric_name, + ) continue - current_value = Decimal(live_stat[rule.metric_name]["current"]) + current_value = Decimal(live_stat[rule.metric_name]["current"]) / len( + endpoint_by_id[rule.endpoint].routings + ) case _: raise AssertionError("Should not reach here") # FIXME: Replace with named error match rule.comparator: case AutoScalingMetricComparator.LESS_THAN: - should_trigger = current_value < Decimal(rule.threshold) + should_trigger = current_value < rule.threshold case AutoScalingMetricComparator.LESS_THAN_OR_EQUAL: - should_trigger = 
current_value <= Decimal(rule.threshold) - case AutoScalingMetricComparator.GREATHER_THAN: - should_trigger = current_value > Decimal(rule.threshold) - case AutoScalingMetricComparator.GREATHER_THAN_OR_EQUAL: - should_trigger = current_value >= Decimal(rule.threshold) - - # changes applied here will be reflected at consequent queries (at `scale_services()`) - # so we do not have to propagate the changes on the function level - if should_trigger and rule.last_triggered_at < ( - current_datetime - timedelta(seconds=rule.cooldown_seconds) - ): - rule.endpoint_row.replicas += rule.step - rule.last_triggered_at = current_datetime + should_trigger = current_value <= rule.threshold + case AutoScalingMetricComparator.GREATER_THAN: + should_trigger = current_value > rule.threshold + case AutoScalingMetricComparator.GREATER_THAN_OR_EQUAL: + should_trigger = current_value >= rule.threshold + + log.debug( + "_autoscale_endpoints(e: {}, r: {}): {} {} {}: {}", + rule.endpoint, + rule.id, + current_value, + rule.comparator.value, + rule.threshold, + should_trigger, + ) + if should_trigger: + if rule.last_triggered_at is None or rule.last_triggered_at.replace(tzinfo=None) < ( + current_datetime - timedelta(seconds=rule.cooldown_seconds) + ): + # changes applied here will be reflected at consequent queries (at `scale_services()`) + # so we do not have to propagate the changes on the function level + rule.endpoint_row.replicas += rule.step_size + if rule.endpoint_row.replicas < 0: + rule.endpoint_row.replicas = 0 + rule.last_triggered_at = current_datetime + log.debug( + "_autoscale_endpoints(e: {}, r: {}): added {} to replica count", + rule.endpoint, + rule.id, + rule.step_size, + ) + else: + log.debug( + "_autoscale_endpoints(e: {}, r: {}): rule on cooldown period; deferring execution", + rule.endpoint, + rule.id, + ) async def scale_services( self, @@ -1514,7 +1553,7 @@ def _pipeline(r: Redis) -> RedisPipeline: return pipe async def _autoscale_txn() -> None: - async with 
self.db.begin_sssion(commit_on_end=True) as session: + async with self.db.begin_session(commit_on_end=True) as session: await self._autoscale_endpoints(session) await execute_with_retry(_autoscale_txn)