diff --git a/changes/3277.feature.md b/changes/3277.feature.md new file mode 100644 index 0000000000..c3e45b0911 --- /dev/null +++ b/changes/3277.feature.md @@ -0,0 +1 @@ +Support model service auto scaling diff --git a/src/ai/backend/client/cli/__init__.py b/src/ai/backend/client/cli/__init__.py index f6453d6ea1..655d8e1946 100644 --- a/src/ai/backend/client/cli/__init__.py +++ b/src/ai/backend/client/cli/__init__.py @@ -6,6 +6,7 @@ from . import model # noqa # type: ignore from . import server_log # noqa # type: ignore from . import service # noqa # type: ignore +from . import service_auto_scaling_rule # noqa # type: ignore from . import session # noqa # type: ignore from . import session_template # noqa # type: ignore from . import vfolder # noqa # type: ignore diff --git a/src/ai/backend/client/cli/service_auto_scaling_rule.py b/src/ai/backend/client/cli/service_auto_scaling_rule.py new file mode 100644 index 0000000000..43da1d0657 --- /dev/null +++ b/src/ai/backend/client/cli/service_auto_scaling_rule.py @@ -0,0 +1,243 @@ +import decimal +import sys +import uuid +from typing import Any, Iterable, Optional + +import click + +from ai.backend.cli.params import OptionalType +from ai.backend.cli.types import ExitCode, Undefined, undefined +from ai.backend.client.cli.extensions import pass_ctx_obj +from ai.backend.client.cli.service import get_service_id +from ai.backend.client.cli.types import CLIContext +from ai.backend.client.exceptions import BackendAPIError +from ai.backend.client.session import Session +from ai.backend.common.types import AutoScalingMetricComparator, AutoScalingMetricSource + +from ..func.service_auto_scaling_rule import _default_fields as _default_get_fields +from ..output.fields import service_auto_scaling_rule_fields +from .pretty import print_done +from .service import service + +_default_list_fields = ( + service_auto_scaling_rule_fields["id"], + service_auto_scaling_rule_fields["metric_source"], + service_auto_scaling_rule_fields["metric_name"], + service_auto_scaling_rule_fields["comparator"], + service_auto_scaling_rule_fields["threshold"], +) + + +@service.group() +def auto_scaling_rule(): + """Set of model service auto scaling rule operations""" + + +@auto_scaling_rule.command() +@pass_ctx_obj +@click.argument("service", type=str, metavar="SERVICE_NAME_OR_ID") +@click.option("--metric-source", type=click.Choice([*AutoScalingMetricSource]), required=True) +@click.option("--metric-name", type=str, required=True) +@click.option("--threshold", type=str, required=True) +@click.option("--comparator", type=click.Choice([*AutoScalingMetricComparator]), required=True) +@click.option("--step-size", type=int, required=True) +@click.option("--cooldown-seconds", type=int, required=True) +@click.option("--min-replicas", type=int) +@click.option("--max-replicas", type=int) +def create( + ctx: CLIContext, + service: str, + *, + metric_source: AutoScalingMetricSource, + metric_name: str, + threshold: str, + comparator: AutoScalingMetricComparator, + step_size: int, + cooldown_seconds: int, + min_replicas: Optional[int] = None, + max_replicas: Optional[int] = None, +) -> None: + """Create a new auto scaling rule.""" + + with Session() as session: + try: + _threshold = decimal.Decimal(threshold) + except decimal.InvalidOperation: + ctx.output.print_fail(f"{threshold} is not a valid Decimal") + sys.exit(ExitCode.FAILURE) + + try: + service_id = uuid.UUID(get_service_id(session, service)) + rule = session.ServiceAutoScalingRule.create( + service_id, + metric_source, + 
metric_name,
+                _threshold,
+                comparator,
+                step_size,
+                cooldown_seconds,
+                min_replicas=min_replicas,
+                max_replicas=max_replicas,
+            )
+            print_done(f"Auto Scaling Rule (ID {rule.rule_id}) created.")
+        except Exception as e:
+            ctx.output.print_error(e)
+            sys.exit(ExitCode.FAILURE)
+
+
+@auto_scaling_rule.command()
+@pass_ctx_obj
+@click.argument("service", type=str, metavar="SERVICE_NAME_OR_ID")
+@click.option(
+    "-f",
+    "--format",
+    default=None,
+    help="Display only specified fields. When specifying multiple fields separate them with comma (,).",
+)
+@click.option("--filter", "filter_", default=None, help="Set the query filter expression.")
+@click.option("--order", default=None, help="Set the query ordering expression.")
+@click.option("--offset", default=0, help="The index of the current page start for pagination.")
+@click.option("--limit", type=int, default=None, help="The page size for pagination.")
+def list(ctx: CLIContext, service: str, format, filter_, order, offset, limit):
+    """List the auto scaling rules configured for the given model service."""
+
+    if format:
+        try:
+            fields = [service_auto_scaling_rule_fields[f.strip()] for f in format.split(",")]
+        except KeyError as e:
+            ctx.output.print_fail(f"Field {str(e)} not found")
+            sys.exit(ExitCode.FAILURE)
+    else:
+        fields = None
+    with Session() as session:
+        service_id = uuid.UUID(get_service_id(session, service))
+
+        try:
+            fetch_func = lambda pg_offset, pg_size: session.ServiceAutoScalingRule.paginated_list(
+                service_id,
+                page_offset=pg_offset,
+                page_size=pg_size,
+                filter=filter_,
+                order=order,
+                fields=fields,
+            )
+            ctx.output.print_paginated_list(
+                fetch_func,
+                initial_page_offset=offset,
+                page_size=limit,
+            )
+        except Exception as e:
+            ctx.output.print_error(e)
+            sys.exit(ExitCode.FAILURE)
+
+
+@auto_scaling_rule.command()
+@pass_ctx_obj
+@click.argument("rule", type=str, metavar="RULE_ID")
+@click.option(
+    "-f",
+    "--format",
+    default=None,
+    help="Display only specified fields. When specifying multiple fields separate them with comma (,).",
+)
+def get(ctx: CLIContext, rule, format):
+    """Print the attributes of the given auto scaling rule."""
+    fields: Iterable[Any]
+    if format:
+        try:
+            fields = [service_auto_scaling_rule_fields[f.strip()] for f in format.split(",")]
+        except KeyError as e:
+            ctx.output.print_fail(f"Field {str(e)} not found")
+            sys.exit(ExitCode.FAILURE)
+    else:
+        fields = _default_get_fields
+
+    with Session() as session:
+        try:
+            rule_info = session.ServiceAutoScalingRule(uuid.UUID(rule)).get(fields=fields)
+        except (ValueError, BackendAPIError):
+            ctx.output.print_fail(f"Auto scaling rule {rule} not found.")
+            sys.exit(ExitCode.FAILURE)
+
+        ctx.output.print_item(rule_info, fields)
+
+
+@auto_scaling_rule.command()
+@pass_ctx_obj
+@click.argument("rule", type=str, metavar="RULE_ID")
+@click.option("--metric-source", type=OptionalType(AutoScalingMetricSource), default=undefined)
+@click.option("--metric-name", type=OptionalType(str), default=undefined)
+@click.option("--threshold", type=OptionalType(str), default=undefined)
+@click.option("--comparator", type=OptionalType(AutoScalingMetricComparator), default=undefined)
+@click.option("--step-size", type=OptionalType(int), default=undefined)
+@click.option("--cooldown-seconds", type=OptionalType(int), default=undefined)
+@click.option(
+    "--min-replicas",
+    type=OptionalType(int),
+    help="Set as -1 to remove min_replicas restriction.",
+    default=undefined,
+)
+@click.option(
+    "--max-replicas",
+    type=OptionalType(int),
+    help="Set as -1 to remove max_replicas restriction.",
+    default=undefined,
+)
+def update(
+    ctx: CLIContext,
+    rule: str,
+    *,
+    metric_source: str | Undefined,
+    metric_name: str | Undefined,
+    threshold: str | Undefined,
+    comparator: str | Undefined,
+    step_size: int | Undefined,
+    cooldown_seconds: int | Undefined,
+    min_replicas: Optional[int] | Undefined,
+    max_replicas: Optional[int] | Undefined,
+):
+    with Session() as session:
+        try:
+            _threshold = decimal.Decimal(threshold) if threshold is not undefined else undefined
+        except decimal.InvalidOperation:
+            ctx.output.print_fail(f"{threshold} is not a valid Decimal")
+            sys.exit(ExitCode.FAILURE)
+
+        if min_replicas == -1:
+            min_replicas = None
+        if max_replicas == -1:
+            max_replicas = None
+
+        try:
+            _rule = session.ServiceAutoScalingRule(uuid.UUID(rule))
+            _rule.get()
+            _rule.update(
+                metric_source=metric_source,
+                metric_name=metric_name,
+                threshold=_threshold,
+                comparator=comparator,
+                step_size=step_size,
+                cooldown_seconds=cooldown_seconds,
+                min_replicas=min_replicas,
+                max_replicas=max_replicas,
+            )
+            print_done(f"Auto Scaling Rule (ID {_rule.rule_id}) updated.")
+        except BackendAPIError as e:
+            ctx.output.print_fail(e.data["title"])
+            sys.exit(ExitCode.FAILURE)
+
+
+@auto_scaling_rule.command()
+@pass_ctx_obj
+@click.argument("rule", type=str, metavar="RULE_ID")
+def delete(ctx: CLIContext, rule):
+    with Session() as session:
+        _rule = session.ServiceAutoScalingRule(uuid.UUID(rule))
+        try:
+            _rule.get(fields=[service_auto_scaling_rule_fields["id"]])
+            _rule.delete()
+            print_done(f"Auto scaling rule {_rule.rule_id} has been deleted.")
+        except BackendAPIError as e:
+            ctx.output.print_fail(f"Failed to delete rule {_rule.rule_id}:")
+            ctx.output.print_error(e)
+            sys.exit(ExitCode.FAILURE)
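For reviewers, a quick sketch of how the new subcommands are expected to be invoked, assuming Click's default underscore-to-dash command naming; the service name, metric name, and values are illustrative:

```console
# Scale out by one replica when the average CUDA utilization across the
# service's kernels reaches 90 or more, with a 5-minute cooldown:
$ backend.ai service auto-scaling-rule create my-service \
    --metric-source KERNEL --metric-name cuda_util \
    --threshold 90 --comparator GREATER_THAN_OR_EQUAL \
    --step-size 1 --cooldown-seconds 300 --max-replicas 4

# Inspect and adjust rules later:
$ backend.ai service auto-scaling-rule list my-service
$ backend.ai service auto-scaling-rule update <RULE_ID> --threshold 80
$ backend.ai service auto-scaling-rule delete <RULE_ID>
```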
diff --git a/src/ai/backend/client/func/service_auto_scaling_rule.py b/src/ai/backend/client/func/service_auto_scaling_rule.py
new file mode 100644
index 0000000000..7bf05d40f8
--- /dev/null
+++ b/src/ai/backend/client/func/service_auto_scaling_rule.py
@@ -0,0 +1,207 @@
+import textwrap
+from decimal import Decimal
+from typing import Any, Optional, Sequence
+from uuid import UUID
+
+from ai.backend.client.func.base import BaseFunction, api_function
+from ai.backend.client.output.types import FieldSpec, RelayPaginatedResult
+from ai.backend.client.pagination import execute_paginated_relay_query
+from ai.backend.client.session import api_session
+from ai.backend.client.types import set_if_set
+from ai.backend.common.types import AutoScalingMetricComparator, AutoScalingMetricSource
+
+from ...cli.types import Undefined, undefined
+from ..output.fields import service_auto_scaling_rule_fields
+
+_default_fields: Sequence[FieldSpec] = (
+    service_auto_scaling_rule_fields["id"],
+    service_auto_scaling_rule_fields["metric_source"],
+    service_auto_scaling_rule_fields["metric_name"],
+    service_auto_scaling_rule_fields["comparator"],
+    service_auto_scaling_rule_fields["threshold"],
+    service_auto_scaling_rule_fields["endpoint"],
+    service_auto_scaling_rule_fields["step_size"],
+    service_auto_scaling_rule_fields["cooldown_seconds"],
+    service_auto_scaling_rule_fields["min_replicas"],
+    service_auto_scaling_rule_fields["max_replicas"],
+    service_auto_scaling_rule_fields["created_at"],
+    service_auto_scaling_rule_fields["last_triggered_at"],
+)
+
+
+class ServiceAutoScalingRule(BaseFunction):
+    rule_id: UUID
+
+    @api_function
+    @classmethod
+    async def paginated_list(
+        cls,
+        endpoint_id: UUID,
+        *,
+        fields: Sequence[FieldSpec] | None = None,
+        page_offset: int = 0,
+        page_size: int = 20,
+        filter: Optional[str] = None,
+        order: Optional[str] = None,
+    ) -> RelayPaginatedResult[dict]:
+        return await execute_paginated_relay_query(
+            "endpoint_auto_scaling_rule_nodes",
+            {
+                "endpoint": (str(endpoint_id), "String!"),
+                "filter": (filter, "String"),
+                "order": (order, "String"),
+            },
+            fields or _default_fields,
+            limit=page_size,
+            offset=page_offset,
+        )
+
+    @api_function
+    @classmethod
+    async def create(
+        cls,
+        service: UUID,
+        metric_source: AutoScalingMetricSource,
+        metric_name: str,
+        threshold: Decimal,
+        comparator: AutoScalingMetricComparator,
+        step_size: int,
+        cooldown_seconds: int,
+        *,
+        min_replicas: Optional[int] = None,
+        max_replicas: Optional[int] = None,
+    ) -> "ServiceAutoScalingRule":
+        q = textwrap.dedent(
+            """
+            mutation(
+                $endpoint: String!,
+                $metric_source: AutoScalingMetricSource!,
+                $metric_name: String!,
+                $threshold: String!,
+                $comparator: AutoScalingMetricComparator!,
+                $step_size: Int!,
+                $cooldown_seconds: Int!,
+                $min_replicas: Int,
+                $max_replicas: Int
+            ) {
+                create_endpoint_auto_scaling_rule_node(
+                    endpoint: $endpoint,
+                    props: {
+                        metric_source: $metric_source,
+                        metric_name: $metric_name,
+                        threshold: $threshold,
+                        comparator: $comparator,
+                        step_size: $step_size,
+                        cooldown_seconds: $cooldown_seconds,
+                        min_replicas: $min_replicas,
+                        max_replicas: $max_replicas
+                    }
+                ) {
+                    rule {
+                        row_id
+                    }
+                }
+            }
+            """
+        )
+        data = await api_session.get().Admin._query(
+            q,
+            {
+                "endpoint": str(service),
+                "metric_source": metric_source,
+                "metric_name": metric_name,
+                "threshold": str(threshold),  # the GraphQL variable is declared as String!
+                "comparator": comparator,
+                "step_size": step_size,
+                "cooldown_seconds": cooldown_seconds,
+                "min_replicas": min_replicas,
+                "max_replicas": max_replicas,
+            },
+        )
+
+        return cls(rule_id=UUID(data["create_endpoint_auto_scaling_rule_node"]["rule"]["row_id"]))
+
+    def __init__(self, rule_id: UUID) -> None:
+        super().__init__()
+        self.rule_id = rule_id
+
+    @api_function
+    async def get(
+        self,
+        fields: Sequence[FieldSpec] | None = None,
+    ) -> dict:
+        query = textwrap.dedent(
+            """\
+            query($rule_id: String!) {
+                endpoint_auto_scaling_rule_node(id: $rule_id) {$fields}
+            }
+            """
+        )
+        query = query.replace("$fields", " ".join(f.field_ref for f in (fields or _default_fields)))
+        variables = {"rule_id": str(self.rule_id)}
+        data = await api_session.get().Admin._query(query, variables)
+        return data["endpoint_auto_scaling_rule_node"]
+
+    @api_function
+    async def update(
+        self,
+        *,
+        metric_source: AutoScalingMetricSource | Undefined = undefined,
+        metric_name: str | Undefined = undefined,
+        threshold: Decimal | Undefined = undefined,
+        comparator: AutoScalingMetricComparator | Undefined = undefined,
+        step_size: int | Undefined = undefined,
+        cooldown_seconds: int | Undefined = undefined,
+        min_replicas: Optional[int] | Undefined = undefined,
+        max_replicas: Optional[int] | Undefined = undefined,
+    ) -> dict:
+        q = textwrap.dedent(
+            """
+            mutation(
+                $rule_id: String!,
+                $input: ModifyEndpointAutoScalingRuleInput!,
+            ) {
+                modify_endpoint_auto_scaling_rule_node(
+                    id: $rule_id,
+                    props: $input
+                ) {
+                    ok
+                    msg
+                }
+            }
+            """
+        )
+        inputs: dict[str, Any] = {}
+        set_if_set(inputs, "metric_source", metric_source)
+        set_if_set(inputs, "metric_name", metric_name)
+        set_if_set(inputs, "threshold", str(threshold) if threshold is not undefined else undefined)
+        set_if_set(inputs, "comparator", comparator)
+        set_if_set(inputs, "step_size", step_size)
+        set_if_set(inputs, "cooldown_seconds", cooldown_seconds)
+        set_if_set(inputs, "min_replicas", min_replicas)
+        set_if_set(inputs, "max_replicas", max_replicas)
+        data = await api_session.get().Admin._query(
+            q,
+            {"rule_id": str(self.rule_id), "input": inputs},
+        )
+
+        return data["modify_endpoint_auto_scaling_rule_node"]
+
+    @api_function
+    async def delete(self) -> dict:
+        q = textwrap.dedent(
+            """
+            mutation($rule_id: String!) {
+                delete_endpoint_auto_scaling_rule_node(id: $rule_id) {
+                    ok
+                    msg
+                }
+            }
+            """
+        )
+
+        variables = {
+            "rule_id": str(self.rule_id),
+        }
+        data = await api_session.get().Admin._query(q, variables)
+        return data["delete_endpoint_auto_scaling_rule_node"]
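A minimal sketch of driving the new functional API directly (synchronous `Session` usage per the existing client conventions; the endpoint UUID and metric name are illustrative):

```python
from decimal import Decimal
from uuid import UUID

from ai.backend.client.session import Session
from ai.backend.common.types import (
    AutoScalingMetricComparator,
    AutoScalingMetricSource,
)

with Session() as session:
    # Attach a rule to an existing model service endpoint (illustrative ID).
    rule = session.ServiceAutoScalingRule.create(
        UUID("00000000-0000-0000-0000-000000000000"),
        AutoScalingMetricSource.KERNEL,
        "cuda_util",
        Decimal("90"),
        AutoScalingMetricComparator.GREATER_THAN_OR_EQUAL,
        1,    # step_size
        300,  # cooldown_seconds
        max_replicas=4,
    )
    # Fetch it back, then relax the threshold.
    print(session.ServiceAutoScalingRule(rule.rule_id).get())
    session.ServiceAutoScalingRule(rule.rule_id).update(threshold=Decimal("80"))
```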
diff --git a/src/ai/backend/client/output/fields.py b/src/ai/backend/client/output/fields.py
index ab45df52e2..bf67b54719 100644
--- a/src/ai/backend/client/output/fields.py
+++ b/src/ai/backend/client/output/fields.py
@@ -356,3 +356,19 @@
     FieldSpec("created_at"),
     FieldSpec("updated_at", "Last Updated"),
 ])
+
+
+service_auto_scaling_rule_fields = FieldSet([
+    FieldSpec(field_ref="row_id", field_name="id", alt_name="id"),
+    FieldSpec("endpoint"),
+    FieldSpec("metric_source"),
+    FieldSpec("metric_name"),
+    FieldSpec("threshold"),
+    FieldSpec("comparator"),
+    FieldSpec("step_size"),
+    FieldSpec("cooldown_seconds"),
+    FieldSpec("min_replicas"),
+    FieldSpec("max_replicas"),
+    FieldSpec("created_at"),
+    FieldSpec("last_triggered_at", "Last Triggered"),
+])
diff --git a/src/ai/backend/client/session.py b/src/ai/backend/client/session.py
index 7723a62650..a28c788dfb 100644
--- a/src/ai/backend/client/session.py
+++ b/src/ai/backend/client/session.py
@@ -270,6 +270,7 @@ class BaseSession(metaclass=abc.ABCMeta):
         "ServerLog",
         "Permission",
         "Service",
+        "ServiceAutoScalingRule",
         "Model",
         "QuotaScope",
         "Network",
@@ -313,6 +314,7 @@ def __init__(
         from .func.scaling_group import ScalingGroup
         from .func.server_log import ServerLog
         from .func.service import Service
+        from .func.service_auto_scaling_rule import ServiceAutoScalingRule
         from .func.session import ComputeSession
         from .func.session_template import SessionTemplate
         from .func.storage import Storage
@@ -344,6 +346,7 @@ def __init__(
         self.ServerLog = ServerLog
         self.Permission = Permission
         self.Service = Service
+        self.ServiceAutoScalingRule = ServiceAutoScalingRule
         self.Model = Model
         self.QuotaScope = QuotaScope
         self.Network = Network
diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py
index 71e388f24a..48c87bc650 100644
--- a/src/ai/backend/common/types.py
+++ b/src/ai/backend/common/types.py
@@ -38,6 +38,7 @@
     Union,
     cast,
     overload,
+    override,
 )
 
 import attrs
@@ -94,6 +95,8 @@
     "EtcdRedisConfig",
     "RedisConnectionInfo",
     "RuntimeVariant",
+    "AutoScalingMetricSource",
+    "AutoScalingMetricComparator",
     "MODEL_SERVICE_RUNTIME_PROFILES",
 )
@@ -1316,3 +1319,101 @@ def metric_string(self) -> str:
         val = metric.metric_value_string(self.metric_name, self.metric_primitive)
         result += f"{val}\n"
     return result
+
+
+class CIStrEnum(enum.StrEnum):
+    """
+    A StrEnum variant that allows case-insensitive matching of the members while the values are
+    lowercased.
+    """
+
+    @override
+    @classmethod
+    def _missing_(cls, value: Any) -> Self | None:
+        assert isinstance(value, str)  # since this is a StrEnum
+        value = value.lower()
+        # To prevent infinite recursion, we don't rely on "cls(value)" but manually search the
+        # members as the official stdlib example suggests.
+        for member in cls:
+            if member.value == value:
+                return member
+        return None
+
+    # The default behavior of `enum.auto()` is to set the value to the lowercased member name.
+
+    @classmethod
+    def as_trafaret(cls) -> t.Trafaret:
+        return CIStrEnumTrafaret(cls)
+
+
+class CIUpperStrEnum(CIStrEnum):
+    """
+    A StrEnum variant that allows case-insensitive matching of the members while the values are
+    UPPERCASED.
+    """
+
+    @override
+    @classmethod
+    def _missing_(cls, value: Any) -> Self | None:
+        assert isinstance(value, str)  # since this is a StrEnum
+        value = value.upper()
+        for member in cls:
+            if member.value == value:
+                return member
+        return None
+
+    @override
+    @staticmethod
+    def _generate_next_value_(name, start, count, last_values) -> str:
+        return name.upper()
+
+    @classmethod
+    def as_trafaret(cls) -> t.Trafaret:
+        return CIUpperStrEnumTrafaret(cls)
+
+
+T_enum = TypeVar("T_enum", bound=enum.Enum)
+
+
+class CIStrEnumTrafaret(t.Trafaret, Generic[T_enum]):
+    """
+    A case-insensitive version of trafaret to parse StrEnum values.
+    """
+
+    def __init__(self, enum_cls: type[T_enum]) -> None:
+        self.enum_cls = enum_cls
+
+    def check_and_return(self, value: str) -> T_enum:
+        try:
+            # Assume that the enum values are lowercased.
+            return self.enum_cls(value.lower())
+        except (KeyError, ValueError):
+            self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)
+
+
+class CIUpperStrEnumTrafaret(t.Trafaret, Generic[T_enum]):
+    """
+    A case-insensitive version of trafaret to parse uppercased StrEnum values.
+    """
+
+    def __init__(self, enum_cls: type[T_enum]) -> None:
+        self.enum_cls = enum_cls
+
+    def check_and_return(self, value: str) -> T_enum:
+        try:
+            # Assume that the enum values are uppercased.
+            return self.enum_cls(value.upper())
+        except (KeyError, ValueError):
+            self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)
+
+
+class AutoScalingMetricSource(CIUpperStrEnum):
+    KERNEL = enum.auto()
+    INFERENCE_FRAMEWORK = enum.auto()
+
+
+class AutoScalingMetricComparator(CIUpperStrEnum):
+    LESS_THAN = enum.auto()
+    LESS_THAN_OR_EQUAL = enum.auto()
+    GREATER_THAN = enum.auto()
+    GREATER_THAN_OR_EQUAL = enum.auto()
diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql
index cfdf9a29ea..31042a1208 100644
--- a/src/ai/backend/manager/api/schema.graphql
+++ b/src/ai/backend/manager/api/schema.graphql
@@ -220,6 +220,12 @@ type Queries {
 
   """Added in 24.12.0."""
   networks(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): NetworkConnection
+
+  """Added in 24.12.0."""
+  endpoint_auto_scaling_rule_node(id: String!): EndpointAutoScalingRuleNode
+
+  """Added in 24.12.0."""
+  endpoint_auto_scaling_rule_nodes(endpoint: String!, filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): EndpointAutoScalingRuleConnection
 }
 
 """
@@ -1651,6 +1657,61 @@ type NetworkEdge {
   cursor: String!
 }
 
+"""Added in 25.01.0."""
+type EndpointAutoScalingRuleNode implements Node {
+  """The ID of the object"""
+  id: ID!
+  row_id: UUID!
+  metric_source: AutoScalingMetricSource!
+  metric_name: String!
+  threshold: String!
+  comparator: AutoScalingMetricComparator!
+  step_size: Int!
+  cooldown_seconds: Int!
+  min_replicas: Int
+  max_replicas: Int
+  created_at: DateTime!
+  last_triggered_at: DateTime
+  endpoint: UUID!
+}
+
+"""Added in 25.01.0."""
+enum AutoScalingMetricSource {
+  KERNEL
+  INFERENCE_FRAMEWORK
+}
+
+"""Added in 25.01.0."""
+enum AutoScalingMetricComparator {
+  LESS_THAN
+  LESS_THAN_OR_EQUAL
+  GREATER_THAN
+  GREATER_THAN_OR_EQUAL
+}
+
+"""Added in 25.01.0."""
+type EndpointAutoScalingRuleConnection {
+  """Pagination data for this connection."""
+  pageInfo: PageInfo!
+
+  """Contains the nodes in this connection."""
+  edges: [EndpointAutoScalingRuleEdge]!
+ + """Total count of the GQL nodes of the query.""" + count: Int +} + +""" +Added in 25.01.0. A Relay edge containing a `EndpointAutoScalingRule` and its cursor. +""" +type EndpointAutoScalingRuleEdge { + """The item at the end of the edge""" + node: EndpointAutoScalingRuleNode + + """A cursor for use in pagination""" + cursor: String! +} + """All available GraphQL mutations.""" type Mutations { modify_agent(id: String!, props: ModifyAgentInput!): ModifyAgent @@ -1848,6 +1909,15 @@ type Mutations { """Object id. Can be either global id or object id. Added in 24.09.0.""" id: String! ): DeleteContainerRegistryNode + + """Added in 24.12.0.""" + create_endpoint_auto_scaling_rule_node(endpoint: String!, props: EndpointAutoScalingRuleInput!): CreateEndpointAutoScalingRuleNode + + """Added in 24.12.0.""" + modify_endpoint_auto_scaling_rule_node(id: String!, props: ModifyEndpointAutoScalingRuleInput!): ModifyEndpointAutoScalingRuleNode + + """Added in 24.12.0.""" + delete_endpoint_auto_scaling_rule_node(id: String!): DeleteEndpointAutoScalingRuleNode create_container_registry(hostname: String!, props: CreateContainerRegistryInput!): CreateContainerRegistry modify_container_registry(hostname: String!, props: ModifyContainerRegistryInput!): ModifyContainerRegistry delete_container_registry(hostname: String!): DeleteContainerRegistry @@ -2581,6 +2651,50 @@ type DeleteContainerRegistryNode { container_registry: ContainerRegistryNode } +"""Added in 25.01.0.""" +type CreateEndpointAutoScalingRuleNode { + ok: Boolean + msg: String + rule: EndpointAutoScalingRuleNode +} + +"""Added in 25.01.0.""" +input EndpointAutoScalingRuleInput { + metric_source: AutoScalingMetricSource! + metric_name: String! + threshold: String! + comparator: AutoScalingMetricComparator! + step_size: Int! + cooldown_seconds: Int! + min_replicas: Int + max_replicas: Int +} + +"""Added in 25.01.0.""" +type ModifyEndpointAutoScalingRuleNode { + ok: Boolean + msg: String + rule: EndpointAutoScalingRuleNode +} + +"""Added in 25.01.0.""" +input ModifyEndpointAutoScalingRuleInput { + metric_source: AutoScalingMetricSource + metric_name: String + threshold: String + comparator: AutoScalingMetricComparator + step_size: Int + cooldown_seconds: Int + min_replicas: Int + max_replicas: Int +} + +"""Added in 25.01.0.""" +type DeleteEndpointAutoScalingRuleNode { + ok: Boolean + msg: String +} + type CreateContainerRegistry { container_registry: ContainerRegistry } diff --git a/src/ai/backend/manager/models/alembic/versions/fb89f5d7817b_create_endpoint_auto_scaling_rules_table.py b/src/ai/backend/manager/models/alembic/versions/fb89f5d7817b_create_endpoint_auto_scaling_rules_table.py new file mode 100644 index 0000000000..245d7dbdb2 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/fb89f5d7817b_create_endpoint_auto_scaling_rules_table.py @@ -0,0 +1,53 @@ +"""create endpoint_auto_scaling_rules table + +Revision ID: fb89f5d7817b +Revises: 0bb88d5a46bf +Create Date: 2024-12-20 01:48:21.009056 + +""" + +import sqlalchemy as sa +from alembic import op + +from ai.backend.manager.models.base import GUID, IDColumn + +# revision identifiers, used by Alembic. +revision = "fb89f5d7817b" +down_revision = "0bb88d5a46bf" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
###
+    op.create_table(
+        "endpoint_auto_scaling_rules",
+        IDColumn(),
+        sa.Column("metric_source", sa.VARCHAR(64), nullable=False),
+        sa.Column("metric_name", sa.Text(), nullable=False),
+        sa.Column("threshold", sa.Text(), nullable=False),
+        sa.Column("comparator", sa.VARCHAR(64), nullable=False),
+        sa.Column("step_size", sa.Integer(), nullable=False),
+        sa.Column("cooldown_seconds", sa.Integer(), nullable=False),
+        sa.Column("min_replicas", sa.Integer(), nullable=True),
+        sa.Column("max_replicas", sa.Integer(), nullable=True),
+        sa.Column(
+            "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True
+        ),
+        sa.Column("last_triggered_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("endpoint", GUID(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["endpoint"],
+            ["endpoints.id"],
+            name=op.f("fk_endpoint_auto_scaling_rules_endpoint_endpoints"),
+            ondelete="CASCADE",
+        ),
+        sa.PrimaryKeyConstraint("id", name=op.f("pk_endpoint_auto_scaling_rules")),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("endpoint_auto_scaling_rules")
+    # ### end Alembic commands ###
diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py
index b9e8100887..510e9df2d6 100644
--- a/src/ai/backend/manager/models/base.py
+++ b/src/ai/backend/manager/models/base.py
@@ -14,6 +14,7 @@
     MutableMapping,
     Sequence,
 )
+from decimal import Decimal
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -1656,3 +1657,30 @@ def generate_sql_info_for_gql_connection(
             "Set 'first' or 'last' to a smaller integer."
         )
     return ret
+
+
+class DecimalType(TypeDecorator, Decimal):
+    """
+    Database type adaptor for Decimal
+    """
+
+    impl = sa.VARCHAR
+    cache_ok = True
+
+    def process_bind_param(
+        self,
+        value: Optional[Decimal],
+        dialect: Dialect,
+    ) -> Optional[str]:
+        # Use an explicit None check so that Decimal("0") is stored as "0", not NULL.
+        return f"{value:f}" if value is not None else None
+
+    def process_result_value(
+        self,
+        value: Optional[str],
+        dialect: Dialect,
+    ) -> Optional[Decimal]:
+        return Decimal(value) if value is not None else None
+
+    @property
+    def python_type(self) -> type[Decimal]:
+        return Decimal
diff --git a/src/ai/backend/manager/models/endpoint.py b/src/ai/backend/manager/models/endpoint.py
index 63694a6ef0..388cded220 100644
--- a/src/ai/backend/manager/models/endpoint.py
+++ b/src/ai/backend/manager/models/endpoint.py
@@ -1,9 +1,22 @@
+from __future__ import annotations
+
 import datetime
 import logging
 import uuid
+from collections.abc import (
+    Mapping,
+    Sequence,
+)
+from decimal import Decimal
 from enum import Enum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Sequence, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Optional,
+    Self,
+    cast,
+)
 
 import graphene
 import jwt
@@ -25,10 +38,13 @@ from ai.backend.common.types import (
     MODEL_SERVICE_RUNTIME_PROFILES,
     AccessKey,
+    AutoScalingMetricComparator,
+    AutoScalingMetricSource,
     ClusterMode,
     ImageAlias,
     MountPermission,
     MountTypes,
+    RedisConnectionInfo,
     ResourceSlot,
     RuntimeVariant,
     SessionTypes,
@@ -52,6 +68,7 @@ from .base import (
     GUID,
     Base,
+    DecimalType,
     EndpointIDColumn,
     EnumValueType,
     ForeignKeyIDColumn,
@@ -94,6 +111,7 @@
     "EndpointTokenRow",
     "EndpointToken",
     "EndpointTokenList",
+    "EndpointAutoScalingRuleRow",
 )
@@ -211,6 +229,9 @@ class EndpointRow(Base):
 
     routings = relationship("RoutingRow", back_populates="endpoint_row")
     tokens = relationship("EndpointTokenRow", back_populates="endpoint_row")
+    endpoint_auto_scaling_rules = 
relationship( + "EndpointAutoScalingRuleRow", back_populates="endpoint_row" + ) image_row = relationship("ImageRow", back_populates="endpoints") model_row = relationship("VFolderRow", back_populates="endpoints") created_user_row = relationship( @@ -286,7 +307,7 @@ async def get( load_created_user=False, load_session_owner=False, load_model=False, - ) -> "EndpointRow": + ) -> Self: """ :raises: sqlalchemy.orm.exc.NoResultFound """ @@ -330,7 +351,7 @@ async def list( load_created_user=False, load_session_owner=False, status_filter=[EndpointLifecycle.CREATED], - ) -> List["EndpointRow"]: + ) -> list[Self]: query = ( sa.select(EndpointRow) .order_by(sa.desc(EndpointRow.created_at)) @@ -355,6 +376,47 @@ async def list( result = await session.execute(query) return result.scalars().all() + @classmethod + async def batch_load( + cls, + session: AsyncSession, + endpoint_ids: Sequence[uuid.UUID], + domain: Optional[str] = None, + project: Optional[uuid.UUID] = None, + user_uuid: Optional[uuid.UUID] = None, + load_routes=False, + load_image=False, + load_tokens=False, + load_created_user=False, + load_session_owner=False, + status_filter=[EndpointLifecycle.CREATED], + ) -> Sequence[Self]: + query = ( + sa.select(EndpointRow) + .order_by(sa.desc(EndpointRow.created_at)) + .filter( + EndpointRow.lifecycle_stage.in_(status_filter) & EndpointRow.id.in_(endpoint_ids) + ) + ) + if load_routes: + query = query.options(selectinload(EndpointRow.routings)) + if load_tokens: + query = query.options(selectinload(EndpointRow.tokens)) + if load_image: + query = query.options(selectinload(EndpointRow.image_row)) + if load_created_user: + query = query.options(selectinload(EndpointRow.created_user_row)) + if load_session_owner: + query = query.options(selectinload(EndpointRow.session_owner_row)) + if project: + query = query.filter(EndpointRow.project == project) + if domain: + query = query.filter(EndpointRow.domain == domain) + if user_uuid: + query = query.filter(EndpointRow.session_owner == user_uuid) + result = await session.execute(query) + return result.scalars().all() + @classmethod async def list_by_model( cls, @@ -369,7 +431,7 @@ async def list_by_model( load_created_user=False, load_session_owner=False, status_filter=[EndpointLifecycle.CREATED], - ) -> List["EndpointRow"]: + ) -> Sequence[Self]: query = ( sa.select(EndpointRow) .order_by(sa.desc(EndpointRow.created_at)) @@ -396,6 +458,33 @@ async def list_by_model( result = await session.execute(query) return result.scalars().all() + async def create_auto_scaling_rule( + self, + session: AsyncSession, + metric_source: AutoScalingMetricSource, + metric_name: str, + threshold: Decimal, + comparator: AutoScalingMetricComparator, + step_size: int, + cooldown_seconds: int = 300, + min_replicas: int | None = None, + max_replicas: int | None = None, + ) -> EndpointAutoScalingRuleRow: + row = EndpointAutoScalingRuleRow( + id=uuid.uuid4(), + endpoint=self.id, + metric_source=metric_source, + metric_name=metric_name, + threshold=threshold, + comparator=comparator, + step_size=step_size, + cooldown_seconds=cooldown_seconds, + min_replicas=min_replicas, + max_replicas=max_replicas, + ) + session.add(row) + return row + class EndpointTokenRow(Base): __tablename__ = "endpoint_tokens" @@ -450,7 +539,7 @@ async def list( project: Optional[uuid.UUID] = None, user_uuid: Optional[uuid.UUID] = None, load_endpoint=False, - ) -> Iterable["EndpointTokenRow"]: + ) -> Sequence[Self]: query = ( sa.select(EndpointTokenRow) .filter(EndpointTokenRow.endpoint == endpoint_id) @@ -477,7 
+566,7 @@ async def get( project: Optional[uuid.UUID] = None, user_uuid: Optional[uuid.UUID] = None, load_endpoint=False, - ) -> "EndpointTokenRow": + ) -> Self: query = sa.select(EndpointTokenRow).filter(EndpointTokenRow.token == token) if load_endpoint: query = query.options(selectinload(EndpointTokenRow.tokens)) @@ -494,6 +583,73 @@ async def get( return row +class EndpointAutoScalingRuleRow(Base): + __tablename__ = "endpoint_auto_scaling_rules" + + id = IDColumn() + metric_source = sa.Column( + "metric_source", StrEnumType(AutoScalingMetricSource, use_name=False), nullable=False + ) + metric_name = sa.Column("metric_name", sa.Text(), nullable=False) + threshold = sa.Column("threshold", DecimalType(), nullable=False) + comparator = sa.Column( + "comparator", StrEnumType(AutoScalingMetricComparator, use_name=False), nullable=False + ) + step_size = sa.Column("step_size", sa.Integer(), nullable=False) + cooldown_seconds = sa.Column("cooldown_seconds", sa.Integer(), nullable=False, default=300) + + min_replicas = sa.Column("min_replicas", sa.Integer(), nullable=True) + max_replicas = sa.Column("max_replicas", sa.Integer(), nullable=True) + + created_at = sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ) + last_triggered_at = sa.Column( + "last_triggered_at", + sa.DateTime(timezone=True), + nullable=True, + ) + + endpoint = sa.Column( + "endpoint", + GUID, + sa.ForeignKey("endpoints.id", ondelete="CASCADE"), + nullable=False, + ) + + endpoint_row = relationship("EndpointRow", back_populates="endpoint_auto_scaling_rules") + + @classmethod + async def list(cls, session: AsyncSession, load_endpoint=False) -> Sequence[Self]: + query = sa.select(EndpointAutoScalingRuleRow) + if load_endpoint: + query = query.options(selectinload(EndpointAutoScalingRuleRow.endpoint_row)) + result = await session.execute(query) + return result.scalars().all() + + @classmethod + async def get( + cls, session: AsyncSession, id: uuid.UUID, load_endpoint=False + ) -> "EndpointAutoScalingRuleRow": + query = sa.select(EndpointAutoScalingRuleRow).filter(EndpointAutoScalingRuleRow.id == id) + if load_endpoint: + query = query.options(selectinload(EndpointAutoScalingRuleRow.endpoint_row)) + result = await session.execute(query) + row = result.scalar() + if not row: + raise ObjectNotFound("endpoint_auto_scaling_rule") + return row + + async def remove_rule( + self, + session: AsyncSession, + ) -> None: + await session.delete(self) + + class ModelServicePredicateChecker: @staticmethod async def check_scaling_group( @@ -709,15 +865,15 @@ class RuntimeVariantInfo(graphene.ObjectType): human_readable_name = graphene.String() @classmethod - def from_enum(cls, enum: RuntimeVariant) -> "RuntimeVariantInfo": + def from_enum(cls, enum: RuntimeVariant) -> Self: return cls(name=enum.value, human_readable_name=MODEL_SERVICE_RUNTIME_PROFILES[enum].name) class EndpointStatistics: @classmethod - async def batch_load_by_endpoint( + async def batch_load_by_endpoint_impl( cls, - ctx: "GraphQueryContext", + redis_stat: RedisConnectionInfo, endpoint_ids: Sequence[uuid.UUID], ) -> Sequence[Optional[Mapping[str, Any]]]: async def _build_pipeline(redis: Redis) -> Pipeline: @@ -727,7 +883,7 @@ async def _build_pipeline(redis: Redis) -> Pipeline: return pipe stats = [] - results = await redis_helper.execute(ctx.redis_stat, _build_pipeline) + results = await redis_helper.execute(redis_stat, _build_pipeline) for result in results: if result is not None: 
stats.append(msgpack.unpackb(result)) @@ -736,9 +892,17 @@ async def _build_pipeline(redis: Redis) -> Pipeline: return stats @classmethod - async def batch_load_by_replica( + async def batch_load_by_endpoint( cls, ctx: "GraphQueryContext", + endpoint_ids: Sequence[uuid.UUID], + ) -> Sequence[Optional[Mapping[str, Any]]]: + return await cls.batch_load_by_endpoint_impl(ctx.redis_stat, endpoint_ids) + + @classmethod + async def batch_load_by_replica( + cls, + ctx: GraphQueryContext, endpoint_replica_ids: Sequence[tuple[uuid.UUID, uuid.UUID]], ) -> Sequence[Optional[Mapping[str, Any]]]: async def _build_pipeline(redis: Redis) -> Pipeline: @@ -842,7 +1006,7 @@ async def from_row( cls, ctx, # ctx: GraphQueryContext, row: EndpointRow, - ) -> "Endpoint": + ) -> Self: return cls( endpoint_id=row.id, # image="", # deprecated, row.image_object.name, @@ -926,7 +1090,7 @@ async def load_slice( project: Optional[uuid.UUID] = None, filter: Optional[str] = None, order: Optional[str] = None, - ) -> Sequence["Endpoint"]: + ) -> Sequence[Self]: query = ( sa.select(EndpointRow) .select_from( @@ -993,7 +1157,7 @@ async def load_item( domain_name: Optional[str] = None, user_uuid: Optional[uuid.UUID] = None, project: uuid.UUID | None = None, - ) -> "Endpoint": + ) -> Self: """ :raises: ai.backend.manager.api.exceptions.EndpointNotFound """ @@ -1160,10 +1324,10 @@ async def mutate( info: graphene.ResolveInfo, endpoint_id: uuid.UUID, props: ModifyEndpointInput, - ) -> "ModifyEndpoint": + ) -> Self: graph_ctx: GraphQueryContext = info.context - async def _do_mutate() -> ModifyEndpoint: + async def _do_mutate() -> Self: async with graph_ctx.db.begin_session() as db_session: try: endpoint_row = await EndpointRow.get( @@ -1376,8 +1540,10 @@ def _get_vfolder_id(id_input: str) -> uuid.UUID: await db_session.commit() - return ModifyEndpoint( - True, "success", await Endpoint.from_row(graph_ctx, endpoint_row) + return cls( + True, + "success", + await Endpoint.from_row(graph_ctx, endpoint_row), ) return await gql_mutation_wrapper( @@ -1404,7 +1570,7 @@ async def from_row( cls, ctx, # ctx: GraphQueryContext, row: EndpointTokenRow, - ) -> "EndpointToken": + ) -> Self: return cls( token=row.token, endpoint_id=row.endpoint, @@ -1450,7 +1616,7 @@ async def load_slice( project: Optional[uuid.UUID] = None, domain_name: Optional[str] = None, user_uuid: Optional[uuid.UUID] = None, - ) -> Sequence["EndpointToken"]: + ) -> Sequence[Self]: query = ( sa.select(EndpointTokenRow) .limit(limit) @@ -1480,13 +1646,13 @@ async def load_slice( @classmethod async def load_all( cls, - ctx, # ctx: GraphQueryContext + ctx: GraphQueryContext, endpoint_id: uuid.UUID, *, project: Optional[uuid.UUID] = None, domain_name: Optional[str] = None, user_uuid: Optional[uuid.UUID] = None, - ) -> Sequence["EndpointToken"]: + ) -> Sequence[Self]: async with ctx.db.begin_readonly_session() as session: rows = await EndpointTokenRow.list( session, @@ -1495,7 +1661,7 @@ async def load_all( domain=domain_name, user_uuid=user_uuid, ) - return [await EndpointToken.from_row(ctx, row) for row in rows] + return [await cls.from_row(ctx, row) for row in rows] @classmethod async def load_item( @@ -1506,7 +1672,7 @@ async def load_item( project: Optional[uuid.UUID] = None, domain_name: Optional[str] = None, user_uuid: Optional[uuid.UUID] = None, - ) -> "EndpointToken": + ) -> Self: try: async with ctx.db.begin_readonly_session() as session: row = await EndpointTokenRow.get( @@ -1514,7 +1680,7 @@ async def load_item( ) except NoResultFound: raise EndpointTokenNotFound - 
return await EndpointToken.from_row(ctx, row) + return await cls.from_row(ctx, row) async def resolve_valid_until( self, diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 81669804f1..2e5f5c2447 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -81,6 +81,13 @@ DomainPermissionValueField, ModifyDomainNode, ) +from .gql_models.endpoint import ( + CreateEndpointAutoScalingRuleNode, + DeleteEndpointAutoScalingRuleNode, + EndpointAutoScalingRuleConnection, + EndpointAutoScalingRuleNode, + ModifyEndpointAutoScalingRuleNode, +) from .gql_models.fields import AgentPermissionField, ScopeField from .gql_models.group import GroupConnection, GroupNode from .gql_models.image import ( @@ -335,6 +342,16 @@ class Mutations(graphene.ObjectType): description="Added in 24.09.0." ) + create_endpoint_auto_scaling_rule_node = CreateEndpointAutoScalingRuleNode.Field( + description="Added in 24.12.0." + ) + modify_endpoint_auto_scaling_rule_node = ModifyEndpointAutoScalingRuleNode.Field( + description="Added in 24.12.0." + ) + delete_endpoint_auto_scaling_rule_node = DeleteEndpointAutoScalingRuleNode.Field( + description="Added in 24.12.0." + ) + # Legacy mutations create_container_registry = CreateContainerRegistry.Field() modify_container_registry = ModifyContainerRegistry.Field() @@ -899,6 +916,18 @@ class Queries(graphene.ObjectType): ) networks = PaginatedConnectionField(NetworkConnection, description="Added in 24.12.0.") + endpoint_auto_scaling_rule_node = graphene.Field( + EndpointAutoScalingRuleNode, + id=graphene.String(required=True), + description="Added in 24.12.0.", + ) + + endpoint_auto_scaling_rule_nodes = PaginatedConnectionField( + EndpointAutoScalingRuleConnection, + endpoint=graphene.String(required=True), + description="Added in 24.12.0.", + ) + @staticmethod @privileged_query(UserRole.SUPERADMIN) async def resolve_agent( @@ -2618,6 +2647,40 @@ async def resolve_networks( last, ) + @staticmethod + async def resolve_endpoint_auto_scaling_rule_node( + root: Any, + info: graphene.ResolveInfo, + id: str, + ) -> EndpointAutoScalingRuleNode: + return await EndpointAutoScalingRuleNode.get_node(info, id) + + @staticmethod + async def resolve_endpoint_auto_scaling_rule_nodes( + root: Any, + info: graphene.ResolveInfo, + endpoint: str, + *, + filter: str | None = None, + order: str | None = None, + offset: int | None = None, + after: str | None = None, + first: int | None = None, + before: str | None = None, + last: int | None = None, + ) -> ConnectionResolverResult: + return await EndpointAutoScalingRuleNode.get_connection( + info, + endpoint, + filter_expr=filter, + order_expr=order, + offset=offset, + after=after, + first=first, + before=before, + last=last, + ) + class GQLMutationPrivilegeCheckMiddleware: def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: diff --git a/src/ai/backend/manager/models/gql_models/endpoint.py b/src/ai/backend/manager/models/gql_models/endpoint.py new file mode 100644 index 0000000000..03f8f05e34 --- /dev/null +++ b/src/ai/backend/manager/models/gql_models/endpoint.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import decimal +import uuid +from typing import TYPE_CHECKING, Mapping, Self + +import graphene +from dateutil.parser import parse as dtparse +from graphene.types.datetime import DateTime as GQLDateTime +from graphql import Undefined +from sqlalchemy.orm.exc import NoResultFound + +from ai.backend.common.types import ( + 
AutoScalingMetricComparator, + AutoScalingMetricSource, +) + +from ...api.exceptions import ( + GenericForbidden, + InvalidAPIParameters, + ObjectNotFound, +) +from ..base import ( + FilterExprArg, + OrderExprArg, + generate_sql_info_for_gql_connection, + gql_mutation_wrapper, + orm_set_if_set, +) +from ..endpoint import ( + EndpointAutoScalingRuleRow, + EndpointRow, +) +from ..gql_relay import AsyncNode, Connection, ConnectionResolverResult +from ..minilang.ordering import OrderSpecItem, QueryOrderParser +from ..minilang.queryfilter import FieldSpecItem, QueryFilterParser +from ..user import UserRole +from ..utils import generate_desc_for_enum_kvlist + +if TYPE_CHECKING: + from ..gql import GraphQueryContext + + +_queryfilter_fieldspec: Mapping[str, FieldSpecItem] = { + "id": ("id", None), + "metric_source": ("metric_source", None), + "metric_name": ("metric_name", None), + "threshold": ("threshold", None), + "comparator": ("comparator", None), + "step_size": ("step_size", None), + "cooldown_seconds": ("cooldown_seconds", None), + "created_at": ("created_at", dtparse), + "last_triggered_at": ("last_triggered_at", dtparse), + "endpoint": ("endpoint", None), +} + +_queryorder_colmap: Mapping[str, OrderSpecItem] = { + "id": ("id", None), + "metric_source": ("metric_source", None), + "metric_name": ("metric_name", None), + "threshold": ("threshold", None), + "comparator": ("comparator", None), + "step_size": ("step_size", None), + "cooldown_seconds": ("cooldown_seconds", None), + "created_at": ("created_at", None), + "last_triggered_at": ("last_triggered_at", None), + "endpoint": ("endpoint", None), +} + + +class EndpointAutoScalingRuleNode(graphene.ObjectType): + class Meta: + interfaces = (AsyncNode,) + description = "Added in 25.01.0." + + row_id = graphene.UUID(required=True) + + metric_source = graphene.Field( + graphene.Enum.from_enum(AutoScalingMetricSource, description="Added in 25.01.0."), + required=True, + ) + metric_name = graphene.String(required=True) + threshold = graphene.String(required=True) + comparator = graphene.Field( + graphene.Enum.from_enum(AutoScalingMetricComparator, description="Added in 25.01.0."), + required=True, + ) + step_size = graphene.Int(required=True) + cooldown_seconds = graphene.Int(required=True) + + min_replicas = graphene.Int() + max_replicas = graphene.Int() + + created_at = GQLDateTime(required=True) + last_triggered_at = GQLDateTime() + + endpoint = graphene.UUID(required=True) + + @classmethod + def from_row(cls, graph_ctx: GraphQueryContext, row: EndpointAutoScalingRuleRow) -> Self: + return cls( + id=row.id, + row_id=row.id, + metric_source=row.metric_source, + metric_name=row.metric_name, + threshold=row.threshold, + comparator=row.comparator, + step_size=row.step_size, + cooldown_seconds=row.cooldown_seconds, + min_replicas=row.min_replicas, + max_replicas=row.max_replicas, + created_at=row.created_at, + last_triggered_at=row.last_triggered_at, + endpoint=row.endpoint, + ) + + @classmethod + async def get_node(cls, info: graphene.ResolveInfo, rule_id: str) -> Self: + graph_ctx: GraphQueryContext = info.context + + _, raw_rule_id = AsyncNode.resolve_global_id(info, rule_id) + if not raw_rule_id: + raw_rule_id = rule_id + try: + _rule_id = uuid.UUID(raw_rule_id) + except ValueError: + raise ObjectNotFound("endpoint_auto_scaling_rule") + + async with graph_ctx.db.begin_readonly_session() as db_session: + rule_row = await EndpointAutoScalingRuleRow.get( + db_session, _rule_id, load_endpoint=True + ) + match graph_ctx.user["role"]: + case 
UserRole.SUPERADMIN:
+                    pass
+                case UserRole.ADMIN:
+                    if rule_row.endpoint_row.domain != graph_ctx.user["domain_name"]:
+                        raise GenericForbidden
+                case UserRole.USER:
+                    if rule_row.endpoint_row.created_user != graph_ctx.user["uuid"]:
+                        raise GenericForbidden
+
+            return cls.from_row(graph_ctx, rule_row)
+
+    @classmethod
+    async def get_connection(
+        cls,
+        info: graphene.ResolveInfo,
+        endpoint: str,
+        *,
+        filter_expr: str | None = None,
+        order_expr: str | None = None,
+        offset: int | None = None,
+        after: str | None = None,
+        first: int | None = None,
+        before: str | None = None,
+        last: int | None = None,
+    ) -> ConnectionResolverResult[Self]:
+        graph_ctx: GraphQueryContext = info.context
+        _filter_arg = (
+            FilterExprArg(filter_expr, QueryFilterParser(_queryfilter_fieldspec))
+            if filter_expr is not None
+            else None
+        )
+        _order_expr = (
+            OrderExprArg(order_expr, QueryOrderParser(_queryorder_colmap))
+            if order_expr is not None
+            else None
+        )
+        (
+            query,
+            cnt_query,
+            _,
+            cursor,
+            pagination_order,
+            page_size,
+        ) = generate_sql_info_for_gql_connection(
+            info,
+            EndpointAutoScalingRuleRow,
+            EndpointAutoScalingRuleRow.id,
+            _filter_arg,
+            _order_expr,
+            offset,
+            after=after,
+            first=first,
+            before=before,
+            last=last,
+        )
+
+        async with graph_ctx.db.begin_readonly_session() as db_session:
+            _, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint)
+            if not raw_endpoint_id:
+                raw_endpoint_id = endpoint
+            try:
+                _endpoint_id = uuid.UUID(raw_endpoint_id)
+            except ValueError:
+                raise ObjectNotFound("endpoint")
+            try:
+                row = await EndpointRow.get(db_session, _endpoint_id)
+            except NoResultFound:
+                raise ObjectNotFound(object_name="endpoint")
+
+            match graph_ctx.user["role"]:
+                case UserRole.SUPERADMIN:
+                    pass
+                case UserRole.ADMIN:
+                    # `row` is already an EndpointRow here, so check its own columns.
+                    if row.domain != graph_ctx.user["domain_name"]:
+                        raise GenericForbidden
+                case UserRole.USER:
+                    if row.created_user != graph_ctx.user["uuid"]:
+                        raise GenericForbidden
+
+            query = query.filter(EndpointAutoScalingRuleRow.endpoint == _endpoint_id)
+            group_rows = (await db_session.scalars(query)).all()
+            result = [cls.from_row(graph_ctx, row) for row in group_rows]
+            total_cnt = await db_session.scalar(cnt_query)
+            return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt)
+
+
+class EndpointAutoScalingRuleConnection(Connection):
+    class Meta:
+        node = EndpointAutoScalingRuleNode
+        description = "Added in 25.01.0."
+
+
+class EndpointAutoScalingRuleInput(graphene.InputObjectType):
+    class Meta:
+        description = "Added in 25.01.0."
+
+    metric_source = graphene.Field(
+        graphene.Enum.from_enum(
+            AutoScalingMetricSource,
+            description=(
+                f"Available values: {generate_desc_for_enum_kvlist(AutoScalingMetricSource)}"
+            ),
+        ),
+        required=True,
+    )
+    metric_name = graphene.String(required=True)
+    threshold = graphene.String(required=True)
+    comparator = graphene.Field(
+        graphene.Enum.from_enum(
+            AutoScalingMetricComparator,
+            description=(
+                f"Available values: {generate_desc_for_enum_kvlist(AutoScalingMetricComparator)}"
+            ),
+        ),
+        required=True,
+    )
+    step_size = graphene.Int(required=True)
+    cooldown_seconds = graphene.Int(required=True)
+    min_replicas = graphene.Int()
+    max_replicas = graphene.Int()
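For reference, a create mutation against this schema might look like the following (the `endpoint` value may be a relay global ID or a raw endpoint UUID, which is illustrative here; note that `threshold` travels as a string):

```graphql
mutation {
  create_endpoint_auto_scaling_rule_node(
    endpoint: "00000000-0000-0000-0000-000000000000"  # illustrative endpoint UUID
    props: {
      metric_source: KERNEL
      metric_name: "cuda_util"
      threshold: "90"
      comparator: GREATER_THAN_OR_EQUAL
      step_size: 1
      cooldown_seconds: 300
      max_replicas: 4
    }
  ) {
    ok
    msg
    rule { row_id threshold comparator }
  }
}
```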
+
+
+class ModifyEndpointAutoScalingRuleInput(graphene.InputObjectType):
+    class Meta:
+        description = "Added in 25.01.0."
+
+    metric_source = graphene.Field(
+        graphene.Enum.from_enum(
+            AutoScalingMetricSource,
+            description=(
+                f"Available values: {', '.join([p.name for p in AutoScalingMetricSource])}"
+            ),
+        ),
+        default_value=Undefined,
+    )
+    metric_name = graphene.String()
+    threshold = graphene.String()
+    comparator = graphene.Field(
+        graphene.Enum.from_enum(
+            AutoScalingMetricComparator,
+            description=(
+                f"Available values: {', '.join([p.name for p in AutoScalingMetricComparator])}"
+            ),
+        ),
+        default_value=Undefined,
+    )
+    step_size = graphene.Int()
+    cooldown_seconds = graphene.Int()
+    min_replicas = graphene.Int()
+    max_replicas = graphene.Int()
+
+
+class CreateEndpointAutoScalingRuleNode(graphene.Mutation):
+    allowed_roles = (UserRole.USER, UserRole.ADMIN, UserRole.SUPERADMIN)
+
+    class Arguments:
+        endpoint = graphene.String(required=True)
+        props = EndpointAutoScalingRuleInput(required=True)
+
+    class Meta:
+        description = "Added in 25.01.0."
+
+    ok = graphene.Boolean()
+    msg = graphene.String()
+    rule = graphene.Field(lambda: EndpointAutoScalingRuleNode, required=False)
+
+    @classmethod
+    async def mutate(
+        cls,
+        root,
+        info: graphene.ResolveInfo,
+        endpoint: str,
+        props: EndpointAutoScalingRuleInput,
+    ) -> Self:
+        _, raw_endpoint_id = AsyncNode.resolve_global_id(info, endpoint)
+        if not raw_endpoint_id:
+            raw_endpoint_id = endpoint
+        if not props.metric_source:
+            raise InvalidAPIParameters("metric_source is a required field")
+        if not props.comparator:
+            raise InvalidAPIParameters("comparator is a required field")
+
+        try:
+            _endpoint_id = uuid.UUID(raw_endpoint_id)
+        except ValueError:
+            raise ObjectNotFound("endpoint")
+
+        graph_ctx: GraphQueryContext = info.context
+        async with graph_ctx.db.begin_session(commit_on_end=True) as db_session:
+            try:
+                row = await EndpointRow.get(db_session, _endpoint_id)
+            except NoResultFound:
+                raise ObjectNotFound(object_name="endpoint")
+
+            match graph_ctx.user["role"]:
+                case UserRole.SUPERADMIN:
+                    pass
+                case UserRole.ADMIN:
+                    if row.domain != graph_ctx.user["domain_name"]:
+                        raise GenericForbidden
+                case UserRole.USER:
+                    if row.created_user != graph_ctx.user["uuid"]:
+                        raise GenericForbidden
+
+            try:
+                _threshold = decimal.Decimal(props.threshold)
+            except decimal.InvalidOperation:
+                raise InvalidAPIParameters(f"Cannot convert {props.threshold} to Decimal")
+
+            async def _do_mutate() -> Self:
+                created_rule = await row.create_auto_scaling_rule(
+                    db_session,
+                    props.metric_source,
+                    props.metric_name,
+                    _threshold,
+                    props.comparator,
+                    props.step_size,
+                    cooldown_seconds=props.cooldown_seconds,
+                    min_replicas=props.min_replicas,
+                    max_replicas=props.max_replicas,
+                )
+                return cls(
+                    ok=True,
+                    msg="Auto scaling rule created",
+                    rule=EndpointAutoScalingRuleNode.from_row(info.context, created_rule),
+                )
+
+            return await gql_mutation_wrapper(cls, _do_mutate)
+
+
+class ModifyEndpointAutoScalingRuleNode(graphene.Mutation):
+    allowed_roles = (UserRole.USER, UserRole.ADMIN, UserRole.SUPERADMIN)
+
+    class Arguments:
+        id = graphene.String(required=True)
+        props = ModifyEndpointAutoScalingRuleInput(required=True)
+
+    class Meta:
+        description = "Added in 25.01.0."
+ + ok = graphene.Boolean() + msg = graphene.String() + rule = graphene.Field(lambda: EndpointAutoScalingRuleNode, required=False) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + id: str, + props: ModifyEndpointAutoScalingRuleInput, + ) -> Self: + _, rule_id = AsyncNode.resolve_global_id(info, id) + if not rule_id: + rule_id = id + + try: + _rule_id = uuid.UUID(rule_id) + except ValueError: + raise ObjectNotFound("auto_scaling_rule") + + graph_ctx: GraphQueryContext = info.context + async with graph_ctx.db.begin_session(commit_on_end=True) as db_session: + try: + row = await EndpointAutoScalingRuleRow.get(db_session, _rule_id, load_endpoint=True) + except NoResultFound: + raise ObjectNotFound(object_name="auto_scaling_rule") + + match graph_ctx.user["role"]: + case UserRole.SUPERADMIN: + pass + case UserRole.ADMIN: + if row.endpoint_row.domain != graph_ctx.user["domain_name"]: + raise GenericForbidden + case UserRole.USER: + if row.endpoint_row.created_user != graph_ctx.user["uuid"]: + raise GenericForbidden + + async def _do_mutate() -> Self: + if (_newval := props.threshold) and _newval is not Undefined: + try: + row.threshold = decimal.Decimal(_newval) + except decimal.InvalidOperation: + raise InvalidAPIParameters(f"Cannot convert {_newval} to Decimal") + + orm_set_if_set(props, row, "metric_source") + orm_set_if_set(props, row, "metric_name") + orm_set_if_set(props, row, "comparator") + orm_set_if_set(props, row, "step_size") + orm_set_if_set(props, row, "cooldown_seconds") + orm_set_if_set(props, row, "min_replicas") + orm_set_if_set(props, row, "max_replicas") + + return cls( + ok=True, + msg="Auto scaling rule updated", + rule=EndpointAutoScalingRuleNode.from_row(info.context, row), + ) + + return await gql_mutation_wrapper(cls, _do_mutate) + + +class DeleteEndpointAutoScalingRuleNode(graphene.Mutation): + allowed_roles = (UserRole.USER, UserRole.ADMIN, UserRole.SUPERADMIN) + + class Arguments: + id = graphene.String(required=True) + + class Meta: + description = "Added in 25.01.0." 
+
+    ok = graphene.Boolean()
+    msg = graphene.String()
+
+    @classmethod
+    async def mutate(
+        cls,
+        root,
+        info: graphene.ResolveInfo,
+        id: str,
+    ) -> Self:
+        _, rule_id = AsyncNode.resolve_global_id(info, id)
+        if not rule_id:
+            rule_id = id
+
+        try:
+            _rule_id = uuid.UUID(rule_id)
+        except ValueError:
+            raise ObjectNotFound("auto_scaling_rule")
+
+        graph_ctx: GraphQueryContext = info.context
+        async with graph_ctx.db.begin_session(commit_on_end=True) as db_session:
+            try:
+                row = await EndpointAutoScalingRuleRow.get(db_session, _rule_id, load_endpoint=True)
+            except NoResultFound:
+                raise ObjectNotFound(object_name="auto_scaling_rule")
+
+            match graph_ctx.user["role"]:
+                case UserRole.SUPERADMIN:
+                    pass
+                case UserRole.ADMIN:
+                    if row.endpoint_row.domain != graph_ctx.user["domain_name"]:
+                        raise GenericForbidden
+                case UserRole.USER:
+                    if row.endpoint_row.created_user != graph_ctx.user["uuid"]:
+                        raise GenericForbidden
+
+            async def _do_mutate() -> Self:
+                await db_session.delete(row)
+                return cls(
+                    ok=True,
+                    msg="Auto scaling rule removed",
+                )
+
+            return await gql_mutation_wrapper(cls, _do_mutate)
diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py
index e54f4827fd..c0dac4aee4 100644
--- a/src/ai/backend/manager/models/kernel.py
+++ b/src/ai/backend/manager/models/kernel.py
@@ -587,6 +587,13 @@ def get_used_days(self, local_tz: tzfile) -> Optional[int]:
         )
         return None
 
+    @staticmethod
+    async def batch_load_by_session_id(
+        session: SASession, session_ids: list[uuid.UUID]
+    ) -> list["KernelRow"]:
+        query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
+        return (await session.execute(query)).scalars().all()
+
     @staticmethod
     async def get_kernel(
         db: ExtendedAsyncSAEngine, kern_id: uuid.UUID, allow_stale: bool = False
@@ -803,11 +810,13 @@ class SessionInfo(TypedDict):
 
 class KernelStatistics:
     @classmethod
-    async def batch_load_by_kernel(
+    async def batch_load_by_kernel_impl(
        cls,
-        ctx: GraphQueryContext,
+        redis_stat: RedisConnectionInfo,
         session_ids: Sequence[SessionId],
     ) -> Sequence[Optional[Mapping[str, Any]]]:
+        """For cases where kernel metrics need to be collected in bulk internally."""
+
         async def _build_pipeline(redis: Redis) -> Pipeline:
             pipe = redis.pipeline()
             for sess_id in session_ids:
@@ -815,7 +824,7 @@ async def _build_pipeline(redis: Redis) -> Pipeline:
             return pipe
 
         stats = []
-        results = await redis_helper.execute(ctx.redis_stat, _build_pipeline)
+        results = await redis_helper.execute(redis_stat, _build_pipeline)
         for result in results:
             if result is not None:
                 stats.append(msgpack.unpackb(result))
@@ -823,6 +832,15 @@ async def _build_pipeline(redis: Redis) -> Pipeline:
                 stats.append(None)
         return stats
 
+    @classmethod
+    async def batch_load_by_kernel(
+        cls,
+        ctx: GraphQueryContext,
+        session_ids: Sequence[SessionId],
+    ) -> Sequence[Optional[Mapping[str, Any]]]:
+        """Wrapper of `KernelStatistics.batch_load_by_kernel_impl()` for aiodataloader."""
+        return await cls.batch_load_by_kernel_impl(ctx.redis_stat, session_ids)
+
     @classmethod
     async def batch_load_inference_metrics_by_kernel(
         cls,
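The `generate_desc_for_enum_kvlist()` helper added to `models/utils.py` in the next hunk renders enum members for GraphQL descriptions. A self-contained approximation of its behavior (the simplified enum here is for illustration only):

```python
import enum

class AutoScalingMetricSource(enum.StrEnum):
    KERNEL = "KERNEL"
    INFERENCE_FRAMEWORK = "INFERENCE_FRAMEWORK"

def generate_desc_for_enum_kvlist(e: type[enum.StrEnum]) -> str:
    # Renders each member as 'value' (NAME), joined with commas.
    return ", ".join(f"{str(value)!r} ({name})" for name, value in e.__members__.items())

print(generate_desc_for_enum_kvlist(AutoScalingMetricSource))
# 'KERNEL' (KERNEL), 'INFERENCE_FRAMEWORK' (INFERENCE_FRAMEWORK)
```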
else "VACUUM" log.info(f"Perfoming {vacuum_sql} operation...") await conn.exec_driver_sql(vacuum_sql) + + +def generate_desc_for_enum_kvlist(e: type[enum.StrEnum]) -> str: + items = [] + for name, value in e.__members__.items(): + items.append(f"{str(value)!r} ({name})") + return ", ".join(items) diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 833c617670..71debf09cb 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -5,6 +5,7 @@ import json import logging import uuid +from collections import defaultdict from collections.abc import ( Awaitable, Callable, @@ -34,7 +35,7 @@ from sqlalchemy.orm import noload, selectinload from ai.backend.common import redis_helper -from ai.backend.common.defs import REDIS_LIVE_DB +from ai.backend.common.defs import REDIS_LIVE_DB, REDIS_STAT_DB from ai.backend.common.distributed import GlobalTimer from ai.backend.common.events import ( AgentStartedEvent, @@ -59,6 +60,8 @@ from ai.backend.common.types import ( AgentId, AgentSelectionStrategy, + AutoScalingMetricComparator, + AutoScalingMetricSource, ClusterMode, RedisConnectionInfo, ResourceSlot, @@ -81,9 +84,12 @@ from ..models import ( AgentRow, AgentStatus, + EndpointAutoScalingRuleRow, EndpointLifecycle, EndpointRow, + EndpointStatistics, KernelRow, + KernelStatistics, KernelStatus, RouteStatus, RoutingRow, @@ -200,6 +206,7 @@ class SchedulerDispatcher(aobject): update_session_status_timer: GlobalTimer redis_live: RedisConnectionInfo + redis_stat: RedisConnectionInfo def __init__( self, @@ -222,6 +229,11 @@ def __init__( name="scheduler.live", db=REDIS_LIVE_DB, ) + self.redis_stat = redis_helper.get_redis_object( + self.shared_config.data["redis"], + name="stat", + db=REDIS_STAT_DB, + ) async def __ainit__(self) -> None: coalescing_opts: CoalescingOptions = { @@ -1367,6 +1379,169 @@ async def _mark_session_and_kernel_creating( except asyncio.TimeoutError: log.warning("start(): timeout while executing start_session()") + async def _autoscale_endpoints( + self, + session: SASession, + ) -> None: + current_datetime = datetime.now() + rules = await EndpointAutoScalingRuleRow.list(session, load_endpoint=True) + + # currently auto scaling supports two types of stat as source: kernel and endpoint + # to fetch aggregated kernel metrics among every kernels managed by a single endpoint + # we first need to collect every routings, and then the sessions tied to each routing, + # and finally the child kernels of each session + endpoints = await EndpointRow.batch_load( + session, [rule.endpoint for rule in rules], load_routes=True + ) + endpoint_by_id: dict[uuid.UUID, EndpointRow] = { + endpoint.id: endpoint for endpoint in endpoints + } + metric_requested_sessions: list[uuid.UUID] = list() + metric_requested_kernels: list[uuid.UUID] = list() + metric_requested_endpoints: list[uuid.UUID] = list() + + kernel_statistics_by_id: dict[uuid.UUID, Any] = {} + endpoint_statistics_by_id: dict[uuid.UUID, Any] = {} + kernels_by_session_id: dict[uuid.UUID, list[KernelRow]] = defaultdict(lambda: []) + + for rule in rules: + match rule.metric_source: + case AutoScalingMetricSource.KERNEL: + metric_requested_sessions += [ + route.session for route in endpoint_by_id[rule.endpoint].routings + ] + case AutoScalingMetricSource.INFERENCE_FRAMEWORK: + metric_requested_endpoints.append(rule.endpoint) + + kernel_rows = await KernelRow.batch_load_by_session_id( + session, list(metric_requested_sessions) + ) + for 
diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py
index 833c617670..71debf09cb 100644
--- a/src/ai/backend/manager/scheduler/dispatcher.py
+++ b/src/ai/backend/manager/scheduler/dispatcher.py
@@ -5,6 +5,7 @@
 import json
 import logging
 import uuid
+from collections import defaultdict
 from collections.abc import (
     Awaitable,
     Callable,
@@ -34,7 +35,7 @@
 from sqlalchemy.orm import noload, selectinload
 
 from ai.backend.common import redis_helper
-from ai.backend.common.defs import REDIS_LIVE_DB
+from ai.backend.common.defs import REDIS_LIVE_DB, REDIS_STAT_DB
 from ai.backend.common.distributed import GlobalTimer
 from ai.backend.common.events import (
     AgentStartedEvent,
@@ -59,6 +60,8 @@
 from ai.backend.common.types import (
     AgentId,
     AgentSelectionStrategy,
+    AutoScalingMetricComparator,
+    AutoScalingMetricSource,
     ClusterMode,
     RedisConnectionInfo,
     ResourceSlot,
@@ -81,9 +84,12 @@
 from ..models import (
     AgentRow,
     AgentStatus,
+    EndpointAutoScalingRuleRow,
     EndpointLifecycle,
     EndpointRow,
+    EndpointStatistics,
     KernelRow,
+    KernelStatistics,
     KernelStatus,
     RouteStatus,
     RoutingRow,
@@ -200,6 +206,7 @@ class SchedulerDispatcher(aobject):
     update_session_status_timer: GlobalTimer
 
     redis_live: RedisConnectionInfo
+    redis_stat: RedisConnectionInfo
 
     def __init__(
         self,
@@ -222,6 +229,11 @@ def __init__(
             name="scheduler.live",
             db=REDIS_LIVE_DB,
         )
+        self.redis_stat = redis_helper.get_redis_object(
+            self.shared_config.data["redis"],
+            name="stat",
+            db=REDIS_STAT_DB,
+        )
 
     async def __ainit__(self) -> None:
         coalescing_opts: CoalescingOptions = {
@@ -1367,6 +1379,169 @@ async def _mark_session_and_kernel_creating(
         except asyncio.TimeoutError:
             log.warning("start(): timeout while executing start_session()")
 
+    async def _autoscale_endpoints(
+        self,
+        session: SASession,
+    ) -> None:
+        current_datetime = datetime.now()
+        rules = await EndpointAutoScalingRuleRow.list(session, load_endpoint=True)
+
+        # Auto scaling currently supports two types of stat source: kernel and endpoint.
+        # To fetch aggregated kernel metrics across all kernels managed by a single
+        # endpoint, we first need to collect all routings, then the session tied to
+        # each routing, and finally the child kernels of each session.
+        endpoints = await EndpointRow.batch_load(
+            session, [rule.endpoint for rule in rules], load_routes=True
+        )
+        endpoint_by_id: dict[uuid.UUID, EndpointRow] = {
+            endpoint.id: endpoint for endpoint in endpoints
+        }
+        metric_requested_sessions: list[uuid.UUID] = []
+        metric_requested_kernels: list[uuid.UUID] = []
+        metric_requested_endpoints: list[uuid.UUID] = []
+
+        kernel_statistics_by_id: dict[uuid.UUID, Any] = {}
+        endpoint_statistics_by_id: dict[uuid.UUID, Any] = {}
+        kernels_by_session_id: dict[uuid.UUID, list[KernelRow]] = defaultdict(list)
+
+        for rule in rules:
+            match rule.metric_source:
+                case AutoScalingMetricSource.KERNEL:
+                    metric_requested_sessions += [
+                        route.session for route in endpoint_by_id[rule.endpoint].routings
+                    ]
+                case AutoScalingMetricSource.INFERENCE_FRAMEWORK:
+                    metric_requested_endpoints.append(rule.endpoint)
+
+        kernel_rows = await KernelRow.batch_load_by_session_id(
+            session, list(metric_requested_sessions)
+        )
+        for kernel in kernel_rows:
+            kernels_by_session_id[kernel.session_id].append(kernel)
+            metric_requested_kernels.append(kernel.id)
+
+        # To speed things up and lower the pressure on Redis, we must load all
+        # metrics in bulk rather than querying each key separately.
+        kernel_live_stats = await KernelStatistics.batch_load_by_kernel_impl(
+            self.redis_stat,
+            cast(list[SessionId], list(metric_requested_kernels)),
+        )
+        endpoint_live_stats = await EndpointStatistics.batch_load_by_endpoint_impl(
+            self.redis_stat,
+            cast(list[SessionId], list(metric_requested_endpoints)),
+        )
+
+        kernel_statistics_by_id = {
+            kernel_id: metric
+            for kernel_id, metric in zip(metric_requested_kernels, kernel_live_stats)
+        }
+        endpoint_statistics_by_id = {
+            endpoint_id: metric
+            for endpoint_id, metric in zip(metric_requested_endpoints, endpoint_live_stats)
+        }
+
+        for rule in rules:
+            should_trigger = False
+            if len(endpoint_by_id[rule.endpoint].routings) == 0:
+                log.debug(
+                    "_autoscale_endpoints(e: {}, r: {}): endpoint does not have any replicas, skipping",
+                    rule.endpoint,
+                    rule.id,
+                )
+                continue
+
+            match rule.metric_source:
+                # Kernel metrics are evaluated as the average of the metric across
+                # all kernels belonging to the endpoint.
+                case AutoScalingMetricSource.KERNEL:
+                    metric_aggregated_value = Decimal("0")
+                    metric_found_kernel_count = 0
+                    for route in endpoint_by_id[rule.endpoint].routings:
+                        for kernel in kernels_by_session_id[route.session]:
+                            if not kernel_statistics_by_id[kernel.id]:
+                                continue
+                            live_stat = kernel_statistics_by_id[kernel.id]
+                            if rule.metric_name not in live_stat:
+                                continue
+                            metric_found_kernel_count += 1
+                            metric_aggregated_value += Decimal(
+                                live_stat[rule.metric_name]["current"]
+                            )
+                    if metric_found_kernel_count == 0:
+                        continue
+                    current_value = metric_aggregated_value / Decimal(metric_found_kernel_count)
+                case AutoScalingMetricSource.INFERENCE_FRAMEWORK:
+                    if not endpoint_statistics_by_id[rule.endpoint]:
+                        continue
+                    live_stat = endpoint_statistics_by_id[rule.endpoint]
+                    if rule.metric_name not in live_stat:
+                        log.debug(
+                            "_autoscale_endpoints(e: {}, r: {}): metric {} does not exist, skipping",
+                            rule.endpoint,
+                            rule.id,
+                            rule.metric_name,
+                        )
+                        continue
+                    current_value = Decimal(live_stat[rule.metric_name]["current"]) / len(
+                        endpoint_by_id[rule.endpoint].routings
+                    )
+                case _:
+                    raise AssertionError("Should not reach here")  # FIXME: Replace with named error
+
+            match rule.comparator:
+                case AutoScalingMetricComparator.LESS_THAN:
+                    should_trigger = current_value < rule.threshold
+                case AutoScalingMetricComparator.LESS_THAN_OR_EQUAL:
+                    should_trigger = current_value <= rule.threshold
+                case AutoScalingMetricComparator.GREATER_THAN:
+                    should_trigger = current_value > rule.threshold
+                case AutoScalingMetricComparator.GREATER_THAN_OR_EQUAL:
+                    should_trigger = current_value >= rule.threshold
+
+            log.debug(
+                "_autoscale_endpoints(e: {}, r: {}): {} {} {}: {}",
+                rule.endpoint,
+                rule.id,
+                current_value,
+                rule.comparator.value,
+                rule.threshold,
+                should_trigger,
+            )
+            if should_trigger:
+                new_replicas = rule.endpoint_row.replicas + rule.step_size
+                if (rule.min_replicas is not None and new_replicas < rule.min_replicas) or (
+                    rule.max_replicas is not None and new_replicas > rule.max_replicas
+                ):
+                    log.debug(
+                        "_autoscale_endpoints(e: {}, r: {}): new replica count {} violates min ({}) / max ({}) replica restriction; skipping",
+                        rule.endpoint,
+                        rule.id,
+                        new_replicas,
+                        rule.min_replicas,
+                        rule.max_replicas,
+                    )
+                    continue
+                if rule.last_triggered_at is None or rule.last_triggered_at.replace(tzinfo=None) < (
+                    current_datetime - timedelta(seconds=rule.cooldown_seconds)
+                ):
+                    # Changes applied here will be picked up by subsequent queries
+                    # (in `scale_services()`), so we do not have to propagate them
+                    # at the function level.
+                    rule.endpoint_row.replicas = max(new_replicas, 0)
+                    rule.last_triggered_at = current_datetime
+                    log.debug(
+                        "_autoscale_endpoints(e: {}, r: {}): added {} to replica count",
+                        rule.endpoint,
+                        rule.id,
+                        rule.step_size,
+                    )
+                else:
+                    log.debug(
+                        "_autoscale_endpoints(e: {}, r: {}): rule on cooldown period; deferring execution",
+                        rule.endpoint,
+                        rule.id,
+                    )
+
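To restate the trigger decision in `_autoscale_endpoints()` compactly: a rule fires only when its comparator holds against the aggregated metric value and the cooldown window has elapsed since the last trigger. A self-contained restatement of that logic; the string comparator keys ("lt", "le", "gt", "ge") are assumed stand-ins for the real `AutoScalingMetricComparator` values:

# --- reviewer sketch (not part of the patch) ---
import operator
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Optional

COMPARATORS = {
    "lt": operator.lt,
    "le": operator.le,
    "gt": operator.gt,
    "ge": operator.ge,
}


def should_scale(
    current_value: Decimal,
    threshold: Decimal,
    comparator: str,
    last_triggered_at: Optional[datetime],
    cooldown_seconds: int,
    now: datetime,
) -> bool:
    # 1) the comparator must hold for the aggregated metric value
    if not COMPARATORS[comparator](current_value, threshold):
        return False
    # 2) the cooldown window since the last trigger must have elapsed
    if last_triggered_at is not None and last_triggered_at >= now - timedelta(
        seconds=cooldown_seconds
    ):
        return False
    return True
# --- end sketch ---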
     async def scale_services(
         self,
         context: None,
@@ -1390,6 +1565,12 @@ def _pipeline(r: Redis) -> RedisPipeline:
             )
             return pipe
 
+        async def _autoscale_txn() -> None:
+            async with self.db.begin_session(commit_on_end=True) as session:
+                await self._autoscale_endpoints(session)
+
+        await execute_with_retry(_autoscale_txn)
+
         await redis_helper.execute(
             self.redis_live,
             _pipeline,