diff --git a/src/_bentoml_sdk/__init__.py b/src/_bentoml_sdk/__init__.py index 52d3e686be8..f965d58c789 100644 --- a/src/_bentoml_sdk/__init__.py +++ b/src/_bentoml_sdk/__init__.py @@ -1,4 +1,10 @@ +from __future__ import annotations from bentoml._internal.utils.pkg import pkg_version_info +from typing_extensions import deprecated +from typing import TypeVar, Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from bentoml._internal.external_typing import ASGIApp if (ver := pkg_version_info("pydantic")) < (2,): raise ImportError( @@ -8,7 +14,7 @@ # ruff: noqa -from .decorators import api, on_shutdown, mount_asgi_app, on_deployment, task +from .decorators import api, on_shutdown, asgi_app, on_deployment, task from .service import get_current_service from .service import depends from .service import Service, ServiceConfig @@ -16,11 +22,22 @@ from .service import runner_service from .io_models import IODescriptor +T = TypeVar("T") + + +@deprecated("Deprecated in favor of `bentoml.asgi_app`") +def mount_asgi_app( + app: ASGIApp, *, path: str = "/", name: str | None = None +) -> Callable[[T], T]: + return asgi_app(app, path=path, name=name) + + __all__ = [ "api", "task", "on_shutdown", "on_deployment", + "asgi_app", "mount_asgi_app", "depends", "Service", diff --git a/src/_bentoml_sdk/decorators.py b/src/_bentoml_sdk/decorators.py index c10ae391b06..1cedbdfc12e 100644 --- a/src/_bentoml_sdk/decorators.py +++ b/src/_bentoml_sdk/decorators.py @@ -103,7 +103,7 @@ def wrapper(func: t.Callable[t.Concatenate[t.Any, P], R]) -> APIMethod[P, R]: return wrapper -def mount_asgi_app( +def asgi_app( app: ASGIApp, *, path: str = "/", name: str | None = None ) -> t.Callable[[R], R]: """Mount an ASGI app to the service. diff --git a/src/_bentoml_sdk/gradio.py b/src/_bentoml_sdk/gradio.py index ef7bae7d9c8..a1c915a8dec 100644 --- a/src/_bentoml_sdk/gradio.py +++ b/src/_bentoml_sdk/gradio.py @@ -5,7 +5,7 @@ from bentoml.exceptions import MissingDependencyException -from .decorators import mount_asgi_app +from .decorators import asgi_app R = t.TypeVar("R") @@ -55,7 +55,7 @@ def decorator(obj: R) -> R: blocks.root_path = path blocks.favicon_path = favicon_path gradio_app = gr.routes.App.create_app(blocks, app_kwargs={"root_path": path}) - mount_asgi_app(gradio_app, path=path, name=name)(obj) + asgi_app(gradio_app, path=path, name=name)(obj) # @bentoml.service() decorator returns a wrapper instead of the original class # Check if the object is an instance of Service diff --git a/src/bentoml/__init__.py b/src/bentoml/__init__.py index 6fbcdf35dd6..54618fe0e3b 100644 --- a/src/bentoml/__init__.py +++ b/src/bentoml/__init__.py @@ -104,6 +104,7 @@ from _bentoml_impl.loader import importing from _bentoml_sdk import IODescriptor from _bentoml_sdk import api + from _bentoml_sdk import asgi_app from _bentoml_sdk import depends from _bentoml_sdk import get_current_service from _bentoml_sdk import images @@ -252,6 +253,7 @@ "depends", "on_shutdown", "on_deployment", + "asgi_app", "mount_asgi_app", "get_current_service", "IODescriptor", @@ -363,6 +365,7 @@ def __getattr__(name: str) -> Any: "validators", "Field", "get_current_service", + "asgi_app", "mount_asgi_app", # new implementation "SyncHTTPClient", diff --git a/src/bentoml/_internal/server/http/access.py b/src/bentoml/_internal/server/http/access.py index 9a7e122283d..ce240e01210 100644 --- a/src/bentoml/_internal/server/http/access.py +++ b/src/bentoml/_internal/server/http/access.py @@ -65,19 +65,24 @@ async def __call__( receive: ext.ASGIReceive, send: ext.ASGISend, ) -> None: - if not scope["type"].startswith("http"): + if not scope["type"].startswith(("http", "websocket")): await self.app(scope, receive, send) return start = default_timer() client = scope.get("client") scheme = scope["scheme"] - method = scope["method"] + method = scope.get("method", "") path = scope["path"] if path.startswith(tuple(self.skip_paths)): await self.app(scope, receive, send) return + if client: + address = f"{client[0]}:{client[1]}" + else: + address = "_" + if self.has_request_content_length or self.has_request_content_type: for key, value in scope["headers"]: if key == CONTENT_LENGTH: @@ -85,7 +90,23 @@ async def __call__( elif key == CONTENT_TYPE: request_content_type.set(value) + async def wrapped_receive() -> "ext.ASGIMessage": + message = await receive() + if message["type"] == "websocket.connect": + self.logger.info( + "%s (scheme=%s,path=%s) - Client connected", address, scheme, path + ) + elif message["type"] == "websocket.disconnect": + self.logger.info( + "%s (scheme=%s,path=%s) - Client disconnected", + address, + scheme, + path, + ) + return message + async def wrapped_send(message: "ext.ASGIMessage") -> None: + latency = max(default_timer() - start, 0) * 1000 if message["type"] == "http.response.start": status.set(message["status"]) if self.has_response_content_length or self.has_response_content_type: @@ -94,17 +115,23 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None: response_content_length.set(value) elif key == CONTENT_TYPE: response_content_type.set(value) + elif message["type"] == "websocket.close": + self.logger.info( + "%s (scheme=%s,path=%s) - Connection closed", address, scheme, path + ) + elif message["type"] == "websocket.accept": + self.logger.info( + "%s (scheme=%s,path=%s) - Connection established", + address, + scheme, + path, + ) elif message["type"] == "http.response.body": if "more_body" in message and message["more_body"]: await send(message) return - if client: - address = f"{client[0]}:{client[1]}" - else: - address = "_" - request = [f"scheme={scheme}", f"method={method}", f"path={path}"] if self.has_request_content_type: request.append(f"type={request_content_type.get().decode()}") @@ -117,7 +144,6 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None: if self.has_response_content_length: response.append(f"length={response_content_length.get().decode()}") - latency = max(default_timer() - start, 0) * 1000 await send(message) self.logger.info( "%s (%s) (%s) %.3fms", @@ -130,4 +156,4 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None: await send(message) - await self.app(scope, receive, wrapped_send) + await self.app(scope, wrapped_receive, wrapped_send) diff --git a/src/bentoml/_internal/server/http/instruments.py b/src/bentoml/_internal/server/http/instruments.py index e5b717dd954..15bf7e6d26d 100644 --- a/src/bentoml/_internal/server/http/instruments.py +++ b/src/bentoml/_internal/server/http/instruments.py @@ -188,6 +188,25 @@ def _setup( labelnames=["endpoint", "service_name", "service_version", "runner_name"], multiprocess_mode="livesum", ) + self.metrics_websocket_connections = metrics_client.Gauge( + namespace=self.namespace, + name="websocket_connections", + documentation="Total number of websocket connections", + labelnames=["endpoint", "service_name", "service_version", "runner_name"], + multiprocess_mode="livesum", + ) + self.metrics_websocket_data_received = metrics_client.Summary( + namespace=self.namespace, + name="websocket_data_received", + documentation="Total number of bytes received from websocket", + labelnames=["endpoint", "service_name", "service_version", "runner_name"], + ) + self.metrics_websocket_data_sent = metrics_client.Summary( + namespace=self.namespace, + name="websocket_data_sent", + documentation="Total number of bytes sent to websocket", + labelnames=["endpoint", "service_name", "service_version", "runner_name"], + ) self._is_setup = True async def __call__( @@ -198,11 +217,11 @@ async def __call__( ) -> None: if not self._is_setup: self._setup() - if not scope["type"].startswith("http"): + if not scope["type"].startswith(("http", "websocket")): await self.app(scope, receive, send) return - if scope["path"] == "/metrics": + if scope["type"].startswith("http") and scope["path"] == "/metrics": from starlette.responses import Response response = Response( @@ -214,37 +233,81 @@ async def __call__( return endpoint = scope["path"] - START_TIME_VAR.set(default_timer()) + start_time = default_timer() + status_code = 0 + + async def wrapped_receive() -> "ext.ASGIMessage": + message = await receive() + if message["type"] == "websocket.disconnect": + self.metrics_websocket_connections.labels( + endpoint=endpoint, + service_name=server_context.bento_name, + service_version=server_context.bento_version, + runner_name=server_context.service_name, + ).dec() + elif message["type"] == "websocket.receive": + if message.get("bytes") is not None: + data_len = len(message["bytes"]) + else: + data_len = len(message["text"]) + self.metrics_websocket_data_received.labels( + endpoint=endpoint, + service_name=server_context.bento_name, + service_version=server_context.bento_version, + runner_name=server_context.service_name, + ).observe(data_len) + return message async def wrapped_send(message: "ext.ASGIMessage") -> None: + nonlocal status_code if message["type"] == "http.response.start": - STATUS_VAR.set(message["status"]) + status_code = message["status"] elif message["type"] == "http.response.body": if ("more_body" not in message) or not message["more_body"]: - assert START_TIME_VAR.get() != 0 - assert STATUS_VAR.get() != 0 - # instrument request total count self.metrics_request_total.labels( endpoint=endpoint, service_name=server_context.bento_name, service_version=server_context.bento_version, - http_response_code=STATUS_VAR.get(), + http_response_code=status_code, runner_name=server_context.service_name, ).inc() # instrument request duration - total_time = max(default_timer() - START_TIME_VAR.get(), 0) + total_time = max(default_timer() - start_time, 0) self.metrics_request_duration.labels( # type: ignore endpoint=endpoint, service_name=server_context.bento_name, service_version=server_context.bento_version, - http_response_code=STATUS_VAR.get(), + http_response_code=status_code, runner_name=server_context.service_name, ).observe(total_time) + elif message["type"] == "websocket.send": + if message.get("bytes") is not None: + data_len = len(message["bytes"]) + else: + data_len = len(message["text"]) + self.metrics_websocket_data_sent.labels( + endpoint=endpoint, + service_name=server_context.bento_name, + service_version=server_context.bento_version, + runner_name=server_context.service_name, + ).observe(data_len) + elif message["type"] == "websocket.accept": + self.metrics_websocket_connections.labels( + endpoint=endpoint, + service_name=server_context.bento_name, + service_version=server_context.bento_version, + runner_name=server_context.service_name, + ).inc() + elif message["type"] == "websocket.close": + self.metrics_websocket_connections.labels( + endpoint=endpoint, + service_name=server_context.bento_name, + service_version=server_context.bento_version, + runner_name=server_context.service_name, + ).dec() - START_TIME_VAR.set(0) - STATUS_VAR.set(0) await send(message) with self.metrics_request_in_progress.labels( @@ -253,5 +316,5 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None: service_version=server_context.bento_version, runner_name=server_context.service_name, ).track_inprogress(): - await self.app(scope, receive, wrapped_send) + await self.app(scope, wrapped_receive, wrapped_send) return diff --git a/tests/unit/bentoml_io/test_decorators.py b/tests/unit/bentoml_io/test_decorators.py index 9f8a9b3a178..f3c3b44f9fd 100644 --- a/tests/unit/bentoml_io/test_decorators.py +++ b/tests/unit/bentoml_io/test_decorators.py @@ -18,7 +18,7 @@ def test_mount_asgi_app(): app = FastAPI() - @bentoml.mount_asgi_app(app, path="/test") + @bentoml.asgi_app(app, path="/test") @bentoml.service(metrics={"enabled": False}) class TestService: @app.get("/hello") @@ -37,7 +37,7 @@ def test_mount_asgi_app_later(): app = FastAPI() @bentoml.service(metrics={"enabled": False}) - @bentoml.mount_asgi_app(app, path="/test") + @bentoml.asgi_app(app, path="/test") class TestService: @app.get("/hello") def hello(self):