Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🥅 Kill servers on engine death #63

Merged
merged 5 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 9 additions & 20 deletions src/vllm_tgis_adapter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import asyncio
import contextlib
import signal
from concurrent.futures import FIRST_EXCEPTION
from concurrent.futures import FIRST_COMPLETED
from typing import TYPE_CHECKING

import uvloop
Expand Down Expand Up @@ -35,6 +34,7 @@ async def start_servers(args: argparse.Namespace) -> None:
run_http_server(args, engine),
name="http_server",
)
# The http server task will catch interrupt signals for us
tasks.append(http_server_task)

grpc_server_task = loop.create_task(
Expand All @@ -43,28 +43,17 @@ async def start_servers(args: argparse.Namespace) -> None:
)
tasks.append(grpc_server_task)

def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
for task in tasks:
task.cancel()

async def override_signal_handler() -> None:
loop = asyncio.get_running_loop()

for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, signal_handler)

tasks.append(loop.create_task(override_signal_handler()))

with contextlib.suppress(asyncio.CancelledError):
await asyncio.wait(
tasks,
return_when=FIRST_EXCEPTION,
)

# Both server tasks will exit normally on shutdown, so we await
# FIRST_COMPLETED to catch either one shutting down.
joerunde marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.wait(tasks, return_when=FIRST_COMPLETED)
# Once either server shuts down, cancel the other
for task in tasks:
task.cancel()

# Final wait for both servers to finish
await asyncio.wait(tasks)

check_for_failed_tasks(tasks)


Expand Down
36 changes: 29 additions & 7 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,19 @@ async def _handle_exception(
context: ServicerContext = kwargs.get("context", None) or args[-1]
is_generate_fn = "generate" in func.__name__.lower()

# self.engine on the servicer
engine = args[0].engine
# If the engine has died, then the server cannot process any further
# requests. We want to immediately stop the process in this case to avoid
# any downtime while waiting for probes to fail.
if engine.errored and not engine.is_running:
# awaiting a server.stop() in here won't work because we're in
# the context of a running request.
# Instead we set an event to signal another coroutine to stop the
# server.
stop_event = args[0].stop_event
stop_event.set()

# First just try to replicate the TGIS-style log messages
# for generate_* rpcs
if is_generate_fn:
Expand Down Expand Up @@ -167,8 +180,10 @@ def __init__(
engine: AsyncEngineClient | AsyncLLMEngine,
args: argparse.Namespace,
health_servicer: health.HealthServicer,
stop_event: asyncio.Event,
):
self.engine: AsyncEngineClient = engine
self.stop_event = stop_event

# This is set in post_init()
self.config: ModelConfig | None = None
Expand Down Expand Up @@ -870,13 +885,14 @@ async def ModelInfo(
async def start_grpc_server(
args: argparse.Namespace,
engine: AsyncLLMEngine | AsyncEngineClient,
stop_event: asyncio.Event,
) -> aio.Server:
server = aio.server()

health_servicer = health.HealthServicer()
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

generation = TextGenerationService(engine, args, health_servicer)
generation = TextGenerationService(engine, args, health_servicer, stop_event)
await generation.post_init()
generation_pb2_grpc.add_GenerationServiceServicer_to_server(generation, server)

Expand Down Expand Up @@ -936,14 +952,20 @@ async def run_grpc_server(
args: argparse.Namespace,
engine: AsyncEngineClient | AsyncLLMEngine,
) -> None:
server = await start_grpc_server(
args,
engine,
)
stop_event = asyncio.Event()
server = await start_grpc_server(args, engine, stop_event)

# Add a task to watch for the stop event, so that the server can kill
# itself from within its own handlers
async def wait_for_server_shutdown() -> None:
await stop_event.wait()
# Kill with no grace period because the engine is dead
await server.stop(0)

try:
while True:
await asyncio.sleep(10)
# Either the server stops itself,
# Or the task running this coroutine gets cancelled
await wait_for_server_shutdown()
except asyncio.CancelledError:
print("Gracefully stopping gRPC server") # noqa: T201
await server.stop(30) # TODO configurable grace
Expand Down
62 changes: 23 additions & 39 deletions src/vllm_tgis_adapter/http.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import uvicorn
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
init_app,
)
Expand All @@ -12,7 +11,6 @@
if TYPE_CHECKING:
import argparse

from fastapi import FastAPI
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient

Expand All @@ -21,30 +19,6 @@
logger = init_logger(__name__)


async def serve_http(
app: FastAPI,
**uvicorn_kwargs, # noqa: ANN003
) -> None:
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
path = getattr(route, "path", None)

if methods is None or path is None:
continue

logger.info("Route: %s, Methods: %s", path, ", ".join(methods))

config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)

try:
await server.serve()
except asyncio.CancelledError:
logger.info("Gracefully stopping http server")
await server.shutdown()


async def run_http_server(
args: argparse.Namespace,
engine: AsyncLLMEngine | AsyncEngineClient,
Expand All @@ -55,15 +29,25 @@ async def run_http_server(

app = await init_app(engine, args) # type: ignore[arg-type]

await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
serve_kwargs = {
"host": args.host,
"port": args.port,
"log_level": args.uvicorn_log_level,
"timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
"ssl_keyfile": args.ssl_keyfile,
"ssl_certfile": args.ssl_certfile,
"ssl_ca_certs": args.ssl_ca_certs,
"ssl_cert_reqs": args.ssl_cert_reqs,
}
serve_kwargs.update(uvicorn_kwargs)

try:
shutdown_coro = await serve_http(app, engine, **serve_kwargs)
except TypeError:
# vllm 0.5.4 backwards compatibility
# HTTP server will not shut itself down when the engine dies
shutdown_coro = await serve_http(app, **serve_kwargs)

# launcher.serve_http returns a shutdown coroutine to await
# (The double await is intentional)
await shutdown_coro