This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Socket context #393

Merged
7 changes: 6 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -360,6 +360,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
+    shutdown_task = None
     async with build_backend(args) as backend:
 
         server = await build_server(
@@ -383,7 +384,11 @@ def signal_handler() -> None:
             await server_task
         except asyncio.CancelledError:
             logger.info("Gracefully stopping http server")
-            await server.shutdown()
+            shutdown_task = server.shutdown()
+
+    if shutdown_task:
+        # NB: Await server shutdown only after the backend context is exited
+        await shutdown_task
 
 
 if __name__ == "__main__":
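A note on why deferring the await works (my gloss, not from the PR itself): calling server.shutdown() only creates a coroutine object; nothing in its body executes until it is awaited, so the shutdown runs strictly after the async with block has exited. A minimal standalone sketch with illustrative names:

```python
import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def backend():
    print("backend up")
    yield
    print("backend torn down")  # runs before the deferred await below

async def shutdown():
    print("http server shutdown")

async def main():
    shutdown_task = None
    async with backend():
        # Creating the coroutine does not run it yet.
        shutdown_task = shutdown()
    if shutdown_task:
        # Only now does the shutdown coroutine execute,
        # after the backend context has already exited.
        await shutdown_task

asyncio.run(main())
# Output: backend up / backend torn down / http server shutdown
```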
125 changes: 60 additions & 65 deletions vllm/entrypoints/openai/rpc/client.py
@@ -1,4 +1,5 @@
 import pickle
+from contextlib import contextmanager
 from typing import Any, AsyncIterator, Mapping, Optional
 
 import zmq
@@ -47,52 +48,55 @@ def close(self):
         """Destroy the ZeroMQ Context."""
         self.context.destroy()
 
+    @contextmanager
Review thread on the @contextmanager line:

Collaborator: does this need to be asynccontextmanager?

Collaborator (Author): I don't think so, because it doesn't use any async methods or do any blocking.

Collaborator (Author): Based on this, I think it's just poll / send / recv that need to be handled with async:
https://pyzmq.readthedocs.io/en/latest/api/zmq.asyncio.html#socket
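The point in the thread can be sketched in isolation: a plain synchronous @contextmanager can manage a zmq.asyncio socket because creating, connecting, and closing a socket are cheap non-blocking calls; only the I/O methods (poll / send / recv) return awaitables. A minimal standalone sketch, using a zmq.asyncio.Context and an illustrative dealer_socket helper (names are not from this PR):

```python
import asyncio
from contextlib import contextmanager

import zmq
import zmq.asyncio

ctx = zmq.asyncio.Context()

@contextmanager
def dealer_socket(path: str):
    # Creating and connecting the socket are non-blocking,
    # so a synchronous context manager is enough; nothing here awaits.
    sock = ctx.socket(zmq.DEALER)
    try:
        sock.connect(path)
        yield sock
    finally:
        sock.close()

async def ping(path: str) -> bytes:
    # Only the actual I/O is asynchronous: send/recv return awaitables.
    with dealer_socket(path) as sock:
        await sock.send(b"ping")
        return await sock.recv()
```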

+    def socket(self):
+        # Ensure client sockets are always closed after use
+
+        # Connect to RPC socket for Request-Reply pattern,
+        # Note that we use DEALER to enable asynchronous communication
+        # to enable streaming.
+        socket = self.context.socket(zmq.constants.DEALER)
+        try:
+            socket.connect(self.path)
+            yield socket
+        finally:
+            socket.close()
+
     async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                          expected_type: Any,
                                          error_message: str) -> Any:
         """Send an RPC request that is expecting data back."""
 
-        # Connect to socket.
-        socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(self.path)
-
-        # Ping RPCServer with a request.
-        await socket.send(pickle.dumps(request))
-
-        # Await the data from the Server.
-        data = pickle.loads(await socket.recv())
-
-        if not isinstance(data, expected_type):
-            # LoRAConfig can be None.
-            if expected_type == LoRAConfig and data is None:
-                pass
-            else:
-                socket.close()
-                raise ValueError(error_message)
-
-        socket.close()
+        with self.socket() as socket:
+
+            # Ping RPCServer with a request.
+            await socket.send(pickle.dumps(request))
+
+            # Await the data from the Server.
+            data = pickle.loads(await socket.recv())
+
+            if not isinstance(data, expected_type):
+                # LoRAConfig can be None.
+                if expected_type == LoRAConfig and data is None:
+                    pass
+                else:
+                    raise ValueError(error_message)
 
         return data

     async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                         error_message: str):
         """Send one-way RPC request to trigger an action."""
-
-        # Connect to socket.
-        socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(self.path)
-
-        # Ping RPC Server with request.
-        await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
-
-        # Await acknowledgement from RPCServer.
-        response = pickle.loads(await socket.recv())
+        with self.socket() as socket:
+            # Ping RPC Server with request.
+            await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
+
+            # Await acknowledgement from RPCServer.
+            response = pickle.loads(await socket.recv())
 
         if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
-            socket.close()
             raise ValueError(error_message)
 
-        socket.close()
-
         return response

     async def get_tokenizer(self, lora_request: LoRARequest):
@@ -180,56 +184,47 @@ async def generate(
     ) -> AsyncIterator[RequestOutput]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
 
-        # Connect to RPC socket for Request-Reply pattern,
-        # Note that we use DEALER to enable asynchronous communication
-        # to enable streaming.
-        socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(self.path)
-
-        # Send RPCGenerateRequest to the RPCServer.
-        await socket.send_multipart([
-            pickle.dumps(
-                RPCGenerateRequest(
-                    inputs=inputs,
-                    sampling_params=sampling_params,
-                    request_id=request_id,
-                    lora_request=lora_request,
-                    trace_headers=trace_headers,
-                    prompt_adapter_request=prompt_adapter_request),
-                pickle.HIGHEST_PROTOCOL)
-        ])
-
-        # Stream back the results from the RPC Server.
-        while True:
-            message = await socket.recv()
-            request_output = pickle.loads(message)
-
-            if isinstance(request_output, Exception):
-                socket.close()
-                raise request_output
-
-            if request_output.finished:
-                break
-            yield request_output
-
-        yield request_output
-        socket.close()
+        with self.socket() as socket:
+
+            # Send RPCGenerateRequest to the RPCServer.
+            await socket.send_multipart([
+                pickle.dumps(
+                    RPCGenerateRequest(
+                        inputs=inputs,
+                        sampling_params=sampling_params,
+                        request_id=request_id,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        prompt_adapter_request=prompt_adapter_request),
+                    pickle.HIGHEST_PROTOCOL)
+            ])
+
+            # Stream back the results from the RPC Server.
+            while True:
+                message = await socket.recv()
+                request_output = pickle.loads(message)
+
+                if isinstance(request_output, Exception):
+                    raise request_output
+
+                if request_output.finished:
+                    break
+                yield request_output
+
+            yield request_output
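One subtlety of the new generate (an observation, not from the PR discussion): it now yields from inside the with self.socket() block, so the socket stays open across yields and is closed when the generator completes, raises, or is abandoned; closing the generator early injects GeneratorExit at the suspended yield, which unwinds the with block. A standalone sketch of that behavior, with illustrative names:

```python
import asyncio
from contextlib import contextmanager

@contextmanager
def resource():
    print("open")
    try:
        yield "sock"
    finally:
        print("closed")  # runs even if the generator is abandoned early

async def stream():
    with resource() as sock:
        for i in range(3):
            yield i

async def main():
    agen = stream()
    print(await agen.__anext__())  # prints: open, then 0
    # Closing the generator early injects GeneratorExit at the yield,
    # which unwinds the with-block and releases the resource.
    await agen.aclose()            # prints: closed

asyncio.run(main())
```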

     async def check_health(self) -> None:
         """Raise if unhealthy"""
 
-        # Connect to socket.
-        socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(self.path)
-
-        # Ping RPCServer with CHECK_HEALTH request.
-        await socket.send(pickle.dumps(RPCUtilityRequest.CHECK_HEALTH))
-
-        # Await the reply from the server.
-        # TODO: do we need an internal timeout here?
-        # Or do we expect the external probe to timeout and let this chill?
-        health_message = pickle.loads(await socket.recv())
-        socket.close()
+        with self.socket() as socket:
+
+            # Ping RPCServer with CHECK_HEALTH request.
+            await socket.send(pickle.dumps(RPCUtilityRequest.CHECK_HEALTH))
+
+            # Await the reply from the server.
+            # TODO: do we need an internal timeout here?
+            # Or do we expect the external probe to timeout and let this chill?
+            health_message = pickle.loads(await socket.recv())
 
         if isinstance(health_message, Exception):
             raise health_message
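Regarding the TODO in check_health: if an internal timeout were ever wanted, pyzmq's asyncio sockets support an awaitable poll that could guard the recv. A hypothetical sketch — the helper name, timeout value, and error type are assumptions, not part of this PR:

```python
import pickle

import zmq

HEALTH_TIMEOUT_MS = 5000  # assumed value, not from the PR

async def recv_with_timeout(socket, timeout_ms: int = HEALTH_TIMEOUT_MS):
    # poll() on a zmq.asyncio socket is awaitable and returns the
    # ready-event mask, or 0 if the timeout expired first.
    if await socket.poll(timeout=timeout_ms, flags=zmq.POLLIN) == 0:
        raise TimeoutError("RPCServer did not reply to health check")
    return pickle.loads(await socket.recv())
```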
6 changes: 3 additions & 3 deletions vllm/transformers_utils/tokenizer_group/__init__.py
(Indentation-only change: the continuation lines of the signature are re-aligned.)
@@ -16,9 +16,9 @@
 
 
 def init_tokenizer_from_configs(model_config: ModelConfig,
-                                scheduler_config: SchedulerConfig,
-                                parallel_config: ParallelConfig,
-                                enable_lora: bool):
+                                scheduler_config: SchedulerConfig,
+                                parallel_config: ParallelConfig,
+                                enable_lora: bool):
     init_kwargs = dict(tokenizer_id=model_config.tokenizer,
                        enable_lora=enable_lora,
                        max_num_seqs=scheduler_config.max_num_seqs,