Skip to content

Commit

Permalink
fix: add basic metrics and access logs for websocket (#5109)
Browse files Browse the repository at this point in the history
* feat: rename mount_asgi_app to asgi_app

Signed-off-by: Frost Ming <me@frostming.com>

* fix: add basic metrics for websocket

Signed-off-by: Frost Ming <me@frostming.com>

---------

Signed-off-by: Frost Ming <me@frostming.com>
  • Loading branch information
frostming authored Dec 5, 2024
1 parent d04dd47 commit e481037
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 28 deletions.
19 changes: 18 additions & 1 deletion src/_bentoml_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from __future__ import annotations
from bentoml._internal.utils.pkg import pkg_version_info
from typing_extensions import deprecated
from typing import TypeVar, Callable, TYPE_CHECKING

if TYPE_CHECKING:
from bentoml._internal.external_typing import ASGIApp

if (ver := pkg_version_info("pydantic")) < (2,):
raise ImportError(
Expand All @@ -8,19 +14,30 @@

# ruff: noqa

from .decorators import api, on_shutdown, mount_asgi_app, on_deployment, task
from .decorators import api, on_shutdown, asgi_app, on_deployment, task
from .service import get_current_service
from .service import depends
from .service import Service, ServiceConfig
from .service import service
from .service import runner_service
from .io_models import IODescriptor

T = TypeVar("T")


@deprecated("Deprecated in favor of `bentoml.asgi_app`")
def mount_asgi_app(
    app: ASGIApp, *, path: str = "/", name: str | None = None
) -> Callable[[T], T]:
    """Deprecated alias kept for backward compatibility; use ``asgi_app``.

    All arguments are forwarded unchanged to ``asgi_app`` and its result is
    returned as-is.

    Args:
        app: The ASGI application, passed through to ``asgi_app``.
        path: Keyword-only, defaults to ``"/"``; passed through to ``asgi_app``.
        name: Keyword-only optional name; passed through to ``asgi_app``.

    Returns:
        The decorator produced by ``asgi_app``.
    """
    decorator = asgi_app(app, path=path, name=name)
    return decorator


__all__ = [
"api",
"task",
"on_shutdown",
"on_deployment",
"asgi_app",
"mount_asgi_app",
"depends",
"Service",
Expand Down
2 changes: 1 addition & 1 deletion src/_bentoml_sdk/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def wrapper(func: t.Callable[t.Concatenate[t.Any, P], R]) -> APIMethod[P, R]:
return wrapper


def mount_asgi_app(
def asgi_app(
app: ASGIApp, *, path: str = "/", name: str | None = None
) -> t.Callable[[R], R]:
"""Mount an ASGI app to the service.
Expand Down
4 changes: 2 additions & 2 deletions src/_bentoml_sdk/gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bentoml.exceptions import MissingDependencyException

from .decorators import mount_asgi_app
from .decorators import asgi_app

R = t.TypeVar("R")

Expand Down Expand Up @@ -55,7 +55,7 @@ def decorator(obj: R) -> R:
blocks.root_path = path
blocks.favicon_path = favicon_path
gradio_app = gr.routes.App.create_app(blocks, app_kwargs={"root_path": path})
mount_asgi_app(gradio_app, path=path, name=name)(obj)
asgi_app(gradio_app, path=path, name=name)(obj)

# @bentoml.service() decorator returns a wrapper instead of the original class
# Check if the object is an instance of Service
Expand Down
3 changes: 3 additions & 0 deletions src/bentoml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from _bentoml_impl.loader import importing
from _bentoml_sdk import IODescriptor
from _bentoml_sdk import api
from _bentoml_sdk import asgi_app
from _bentoml_sdk import depends
from _bentoml_sdk import get_current_service
from _bentoml_sdk import images
Expand Down Expand Up @@ -252,6 +253,7 @@
"depends",
"on_shutdown",
"on_deployment",
"asgi_app",
"mount_asgi_app",
"get_current_service",
"IODescriptor",
Expand Down Expand Up @@ -363,6 +365,7 @@ def __getattr__(name: str) -> Any:
"validators",
"Field",
"get_current_service",
"asgi_app",
"mount_asgi_app",
# new implementation
"SyncHTTPClient",
Expand Down
44 changes: 35 additions & 9 deletions src/bentoml/_internal/server/http/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,48 @@ async def __call__(
receive: ext.ASGIReceive,
send: ext.ASGISend,
) -> None:
if not scope["type"].startswith("http"):
if not scope["type"].startswith(("http", "websocket")):
await self.app(scope, receive, send)
return

start = default_timer()
client = scope.get("client")
scheme = scope["scheme"]
method = scope["method"]
method = scope.get("method", "")
path = scope["path"]
if path.startswith(tuple(self.skip_paths)):
await self.app(scope, receive, send)
return

if client:
address = f"{client[0]}:{client[1]}"
else:
address = "_"

if self.has_request_content_length or self.has_request_content_type:
for key, value in scope["headers"]:
if key == CONTENT_LENGTH:
request_content_length.set(value)
elif key == CONTENT_TYPE:
request_content_type.set(value)

async def wrapped_receive() -> "ext.ASGIMessage":
    """Fetch the next inbound ASGI message and log websocket lifecycle events.

    Wraps the enclosing scope's ``receive`` callable. A
    ``websocket.connect`` or ``websocket.disconnect`` message produces one
    access-log line built from the enclosing ``address``, ``scheme`` and
    ``path``; every message is returned to the application unmodified.
    """
    incoming = await receive()
    event = incoming["type"]
    if event == "websocket.disconnect":
        self.logger.info(
            "%s (scheme=%s,path=%s) - Client disconnected", address, scheme, path
        )
    elif event == "websocket.connect":
        self.logger.info(
            "%s (scheme=%s,path=%s) - Client connected", address, scheme, path
        )
    return incoming

async def wrapped_send(message: "ext.ASGIMessage") -> None:
latency = max(default_timer() - start, 0) * 1000
if message["type"] == "http.response.start":
status.set(message["status"])
if self.has_response_content_length or self.has_response_content_type:
Expand All @@ -94,17 +115,23 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None:
response_content_length.set(value)
elif key == CONTENT_TYPE:
response_content_type.set(value)
elif message["type"] == "websocket.close":
self.logger.info(
"%s (scheme=%s,path=%s) - Connection closed", address, scheme, path
)
elif message["type"] == "websocket.accept":
self.logger.info(
"%s (scheme=%s,path=%s) - Connection established",
address,
scheme,
path,
)

elif message["type"] == "http.response.body":
if "more_body" in message and message["more_body"]:
await send(message)
return

if client:
address = f"{client[0]}:{client[1]}"
else:
address = "_"

request = [f"scheme={scheme}", f"method={method}", f"path={path}"]
if self.has_request_content_type:
request.append(f"type={request_content_type.get().decode()}")
Expand All @@ -117,7 +144,6 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None:
if self.has_response_content_length:
response.append(f"length={response_content_length.get().decode()}")

latency = max(default_timer() - start, 0) * 1000
await send(message)
self.logger.info(
"%s (%s) (%s) %.3fms",
Expand All @@ -130,4 +156,4 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None:

await send(message)

await self.app(scope, receive, wrapped_send)
await self.app(scope, wrapped_receive, wrapped_send)
89 changes: 76 additions & 13 deletions src/bentoml/_internal/server/http/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,25 @@ def _setup(
labelnames=["endpoint", "service_name", "service_version", "runner_name"],
multiprocess_mode="livesum",
)
self.metrics_websocket_connections = metrics_client.Gauge(
namespace=self.namespace,
name="websocket_connections",
documentation="Total number of websocket connections",
labelnames=["endpoint", "service_name", "service_version", "runner_name"],
multiprocess_mode="livesum",
)
self.metrics_websocket_data_received = metrics_client.Summary(
namespace=self.namespace,
name="websocket_data_received",
documentation="Total number of bytes received from websocket",
labelnames=["endpoint", "service_name", "service_version", "runner_name"],
)
self.metrics_websocket_data_sent = metrics_client.Summary(
namespace=self.namespace,
name="websocket_data_sent",
documentation="Total number of bytes sent to websocket",
labelnames=["endpoint", "service_name", "service_version", "runner_name"],
)
self._is_setup = True

async def __call__(
Expand All @@ -198,11 +217,11 @@ async def __call__(
) -> None:
if not self._is_setup:
self._setup()
if not scope["type"].startswith("http"):
if not scope["type"].startswith(("http", "websocket")):
await self.app(scope, receive, send)
return

if scope["path"] == "/metrics":
if scope["type"].startswith("http") and scope["path"] == "/metrics":
from starlette.responses import Response

response = Response(
Expand All @@ -214,37 +233,81 @@ async def __call__(
return

endpoint = scope["path"]
START_TIME_VAR.set(default_timer())
start_time = default_timer()
status_code = 0

async def wrapped_receive() -> "ext.ASGIMessage":
    """Fetch the next inbound ASGI message and update websocket metrics.

    Wraps the enclosing scope's ``receive`` callable:

    * ``websocket.disconnect`` decrements the ``websocket_connections`` gauge.
    * ``websocket.receive`` records the payload size, in bytes, on the
      ``websocket_data_received`` summary.

    Any other message type is passed through without touching metrics.
    The message is always returned to the application unmodified.
    """
    message = await receive()
    event = message["type"]
    if event not in ("websocket.disconnect", "websocket.receive"):
        return message

    # Both websocket metrics share the same label set.
    labels = {
        "endpoint": endpoint,
        "service_name": server_context.bento_name,
        "service_version": server_context.bento_version,
        "runner_name": server_context.service_name,
    }
    if event == "websocket.disconnect":
        self.metrics_websocket_connections.labels(**labels).dec()
    else:
        payload = message.get("bytes")
        if payload is None:
            # The metric is documented in bytes, so measure the UTF-8 encoded
            # size of text frames: len(str) counts code points and
            # under-reports non-ASCII payloads. A missing text field counts
            # as an empty payload, and errors="replace" guarantees that
            # metrics collection can never raise on odd input.
            payload = (message.get("text") or "").encode("utf-8", errors="replace")
        self.metrics_websocket_data_received.labels(**labels).observe(len(payload))
    return message

async def wrapped_send(message: "ext.ASGIMessage") -> None:
nonlocal status_code
if message["type"] == "http.response.start":
STATUS_VAR.set(message["status"])
status_code = message["status"]
elif message["type"] == "http.response.body":
if ("more_body" not in message) or not message["more_body"]:
assert START_TIME_VAR.get() != 0
assert STATUS_VAR.get() != 0

# instrument request total count
self.metrics_request_total.labels(
endpoint=endpoint,
service_name=server_context.bento_name,
service_version=server_context.bento_version,
http_response_code=STATUS_VAR.get(),
http_response_code=status_code,
runner_name=server_context.service_name,
).inc()

# instrument request duration
total_time = max(default_timer() - START_TIME_VAR.get(), 0)
total_time = max(default_timer() - start_time, 0)
self.metrics_request_duration.labels( # type: ignore
endpoint=endpoint,
service_name=server_context.bento_name,
service_version=server_context.bento_version,
http_response_code=STATUS_VAR.get(),
http_response_code=status_code,
runner_name=server_context.service_name,
).observe(total_time)
elif message["type"] == "websocket.send":
if message.get("bytes") is not None:
data_len = len(message["bytes"])
else:
data_len = len(message["text"])
self.metrics_websocket_data_sent.labels(
endpoint=endpoint,
service_name=server_context.bento_name,
service_version=server_context.bento_version,
runner_name=server_context.service_name,
).observe(data_len)
elif message["type"] == "websocket.accept":
self.metrics_websocket_connections.labels(
endpoint=endpoint,
service_name=server_context.bento_name,
service_version=server_context.bento_version,
runner_name=server_context.service_name,
).inc()
elif message["type"] == "websocket.close":
self.metrics_websocket_connections.labels(
endpoint=endpoint,
service_name=server_context.bento_name,
service_version=server_context.bento_version,
runner_name=server_context.service_name,
).dec()

START_TIME_VAR.set(0)
STATUS_VAR.set(0)
await send(message)

with self.metrics_request_in_progress.labels(
Expand All @@ -253,5 +316,5 @@ async def wrapped_send(message: "ext.ASGIMessage") -> None:
service_version=server_context.bento_version,
runner_name=server_context.service_name,
).track_inprogress():
await self.app(scope, receive, wrapped_send)
await self.app(scope, wrapped_receive, wrapped_send)
return
4 changes: 2 additions & 2 deletions tests/unit/bentoml_io/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_mount_asgi_app():

app = FastAPI()

@bentoml.mount_asgi_app(app, path="/test")
@bentoml.asgi_app(app, path="/test")
@bentoml.service(metrics={"enabled": False})
class TestService:
@app.get("/hello")
Expand All @@ -37,7 +37,7 @@ def test_mount_asgi_app_later():
app = FastAPI()

@bentoml.service(metrics={"enabled": False})
@bentoml.mount_asgi_app(app, path="/test")
@bentoml.asgi_app(app, path="/test")
class TestService:
@app.get("/hello")
def hello(self):
Expand Down

0 comments on commit e481037

Please sign in to comment.