
Use random port for backend #390

Merged 3 commits on Jul 31, 2024
13 changes: 8 additions & 5 deletions vllm/entrypoints/openai/api_server.py
@@ -45,7 +45,7 @@
     OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, get_open_port
 from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -107,15 +107,18 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]:
     else:
         # remote backend
         ## First need to start the backend process
+        port = get_open_port(envs.VLLM_RPC_PORT)
         rpc_server_process = Process(target=run_rpc_server,
-                                     args=(engine_args,
-                                           UsageContext.OPENAI_API_SERVER))
+                                     args=(engine_args,
+                                           UsageContext.OPENAI_API_SERVER,
+                                           port))
         rpc_server_process.start()
 
         ## Then build the client for the backend process
         # TODO: figure out a way around passing the tokenizer
-        backend = RPCClient(
-            tokenizer=AutoTokenizer.from_pretrained(args.model))
+        backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(
+            args.model),
+                            port=port)
         await backend.wait_for_server()
 
     try:
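For context, here is a minimal sketch of the wiring this hunk introduces: the frontend picks the RPC port once and hands the same value to both the spawned backend process and the client, so neither side depends on a hard-coded 5570. The names `run_backend_sketch` and `ClientSketch` are placeholders, not the real vLLM `run_rpc_server` / `RPCClient`.

```python
# Sketch only: placeholder stand-ins for run_rpc_server / RPCClient; the port
# value mirrors the VLLM_RPC_PORT default used in this PR.
from multiprocessing import Process


def run_backend_sketch(port: int) -> None:
    # The real run_rpc_server would bind a ZMQ ROUTER socket on this port.
    print(f"backend would bind tcp://localhost:{port}")


class ClientSketch:

    def __init__(self, port: int) -> None:
        # The real RPCClient stores the path and connects per request.
        self.path = f"tcp://localhost:{port}"


if __name__ == "__main__":
    port = 5570  # in the PR: port = get_open_port(envs.VLLM_RPC_PORT)
    proc = Process(target=run_backend_sketch, args=(port, ))
    proc.start()
    client = ClientSketch(port)
    print(client.path)
    proc.join()
```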
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/rpc/__init__.py
@@ -7,7 +7,6 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 
-VLLM_RPC_PATH = "tcp://localhost:5570"
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
 
11 changes: 6 additions & 5 deletions vllm/entrypoints/openai/rpc/client.py
@@ -5,7 +5,7 @@
 import zmq.asyncio
 
 from vllm.config import DecodingConfig, ModelConfig
-from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, VLLM_RPC_PATH,
+from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
 from vllm.inputs import PromptInputs
@@ -18,13 +18,14 @@
 class RPCClient:
 
     # TODO: check if opening all these sockets is an antipattern?
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, port: int):
         # ZMQ context.
         self.context = zmq.asyncio.Context()
 
         # TODO: do the tokenizer properly.
         self.tokenizer = tokenizer
         self.decoding_config = DecodingConfig()
+        self.path = f"tcp://localhost:{port}"
 
     def close(self):
         """Destroy the ZeroMQ Context."""
@@ -36,7 +37,7 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
 
         # Connect to socket.
         socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(VLLM_RPC_PATH)
+        socket.connect(self.path)
 
         # Ping RPC Server with request.
         socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
@@ -76,7 +77,7 @@ async def get_model_config(self) -> ModelConfig:
 
         # Connect to socket.
         socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(VLLM_RPC_PATH)
+        socket.connect(self.path)
 
         # Ping RPCServer with GET_MODEL_CONFIG request.
         socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))
@@ -122,7 +123,7 @@ async def generate(
         # Note that we use DEALER to enable asynchronous communication
         # to enable streaming.
         socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(VLLM_RPC_PATH)
+        socket.connect(self.path)
 
         # Send RPCGenerateRequest to the RPCServer.
         socket.send_multipart([
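All three client hunks replace the `VLLM_RPC_PATH` constant with `self.path`, built from the port passed to the constructor. A minimal standalone sketch of that DEALER-per-request pattern follows; it assumes pyzmq is installed and a ROUTER peer is already bound on `port`, and the `"DO_LOG_STATS"` payload is a placeholder rather than a real vLLM request object.

```python
# Sketch only: connect a DEALER socket to the dynamically chosen port,
# mirroring the self.path usage in the diff above.
import asyncio
import pickle

import zmq
import zmq.asyncio


async def one_way_request(port: int, request: object) -> bytes:
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.DEALER)
    socket.connect(f"tcp://localhost:{port}")  # replaces VLLM_RPC_PATH
    try:
        await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
        return await socket.recv()  # e.g. a pickled VLLM_RPC_SUCCESS_STR
    finally:
        socket.close()
        context.term()


# asyncio.run(one_way_request(5570, "DO_LOG_STATS"))
```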
20 changes: 9 additions & 11 deletions vllm/entrypoints/openai/rpc/server.py
@@ -8,9 +8,8 @@
 from typing_extensions import Never
 
 from vllm import AsyncEngineArgs, AsyncLLMEngine
-from vllm.entrypoints.openai.rpc import (VLLM_RPC_PATH, VLLM_RPC_SUCCESS_STR,
-                                         RPCAbortRequest, RPCGenerateRequest,
-                                         RPCUtilityRequest)
+from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
+                                         RPCGenerateRequest, RPCUtilityRequest)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 
@@ -23,7 +22,7 @@ class RPCServer:
     # Alternative, use a smaller number of sockets with conditioning on the
     # data that is passed through the socket.
     def __init__(self, async_engine_args: AsyncEngineArgs,
-                 usage_context: UsageContext):
+                 usage_context: UsageContext, port: int):
         # Initialize engine first.
         self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
                                                       usage_context)
@@ -33,7 +32,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
 
         # Init socket for readiness state.
         self.socket = self.context.socket(zmq.constants.ROUTER)
-        self.socket.bind(VLLM_RPC_PATH)
+        self.socket.bind(f"tcp://localhost:{port}")
 
     def cleanup(self):
         """Cleanup all resources."""
@@ -51,10 +50,9 @@ async def get_model_config(self, identity):
         """Send the ModelConfig """
         model_config = await self.engine.get_model_config()
 
-        self.socket.send_multipart([
-            identity,
-            pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)
-        ])
+        self.socket.send_multipart(
+            [identity,
+             pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)])
 
     async def do_log_stats(self, identity):
         await self.engine.do_log_stats()
@@ -166,6 +164,6 @@ def signal_handler() -> None:
 
 
 def run_rpc_server(async_engine_args: AsyncEngineArgs,
-                   usage_context: UsageContext):
-    server = RPCServer(async_engine_args, usage_context)
+                   usage_context: UsageContext, port: int):
+    server = RPCServer(async_engine_args, usage_context, port)
     asyncio.run(run_server(server))
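The server side is symmetric: the ROUTER socket now binds to whatever port the parent process selected and passed into `run_rpc_server`. A minimal sketch (not vLLM's `RPCServer`) of that bind plus the identity-based reply routing seen in `get_model_config`:

```python
# Sketch only: bind a ROUTER socket to the port handed down from the parent
# process and echo each DEALER frame back to the identity that sent it.
import asyncio

import zmq
import zmq.asyncio


async def serve(port: int) -> None:
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.ROUTER)
    socket.bind(f"tcp://localhost:{port}")  # replaces the hard-coded VLLM_RPC_PATH
    try:
        while True:
            identity, payload = await socket.recv_multipart()
            # ROUTER prepends the sender identity; include it in the reply so
            # the message is routed back to the right DEALER.
            await socket.send_multipart([identity, payload])
    finally:
        socket.close()
        context.term()


# asyncio.run(serve(5570))
```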
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -4,6 +4,7 @@
 if TYPE_CHECKING:
     VLLM_HOST_IP: str = ""
     VLLM_PORT: Optional[int] = None
+    VLLM_RPC_PORT: int = 5570
     VLLM_USE_MODELSCOPE: bool = False
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
     VLLM_INSTANCE_ID: Optional[str] = None
@@ -142,6 +143,11 @@ def get_default_config_root():
     lambda: int(os.getenv('VLLM_PORT', '0'))
     if 'VLLM_PORT' in os.environ else None,
 
+    # Used when the frontend api server is running in multi-processing mode,
+    # to communicate with the backend engine process over ZMQ.
+    'VLLM_RPC_PORT':
+    lambda: int(os.getenv('VLLM_RPC_PORT', '5570')),
+
     # If true, will load models from ModelScope instead of Hugging Face Hub.
     # note that the value is true or false, not numbers
     "VLLM_USE_MODELSCOPE":
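The `envs.py` entries are zero-argument lambdas, so the variable is read each time `envs.VLLM_RPC_PORT` is accessed rather than once at import. A small sketch of that lazy pattern; the `get_env` helper below is hypothetical and stands in for the module's attribute hook.

```python
# Sketch only: the lazy environment-variable pattern used in vllm/envs.py.
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    # Read at access time, falling back to the 5570 default from this PR.
    "VLLM_RPC_PORT": lambda: int(os.getenv("VLLM_RPC_PORT", "5570")),
}


def get_env(name: str) -> Any:
    return environment_variables[name]()


os.environ["VLLM_RPC_PORT"] = "6000"
assert get_env("VLLM_RPC_PORT") == 6000  # picks up the override, not 5570
```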
6 changes: 4 additions & 2 deletions vllm/utils.py
@@ -384,8 +384,10 @@ def get_distributed_init_method(ip: str, port: int) -> str:
     return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
 
 
-def get_open_port() -> int:
-    port = envs.VLLM_PORT
+def get_open_port(port: Optional[int] = None) -> int:
+    if port is None:
+        # Default behavior here is to return a port for multi-gpu communication
+        port = envs.VLLM_PORT
     if port is not None:
         while True:
             try:
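The rest of `get_open_port` is cut off in the hunk above, so the fallback details are an assumption here: the new signature suggests trying the caller-supplied port first and otherwise falling back to an OS-assigned free one. A standalone sketch of that selection logic:

```python
# Sketch only: approximate get_open_port semantics; the real fallback details
# are truncated in the diff above, so treat this as an illustration.
import socket
from typing import Optional


def get_open_port_sketch(port: Optional[int] = None) -> int:
    if port is not None:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", port))
                return port
        except OSError:
            pass  # preferred port is busy; fall back to an OS-assigned one
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


print(get_open_port_sketch(5570))
```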