✨ pipe tracing flag (#400)
(plus rounding out the protocol with an error on `.encode`)

---------

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
joerunde authored Aug 1, 2024
1 parent 1bdbfcb commit f3c0f1c
Showing 3 changed files with 31 additions and 7 deletions.
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -35,6 +35,7 @@ class RPCUtilityRequest(Enum):
     GET_LORA_CONFIG = 6
     DO_LOG_STATS = 7
     CHECK_HEALTH = 8
+    IS_TRACING_ENABLED = 9


 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
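Each utility request in this protocol follows the same round trip: the client pickles an enum member onto a ZeroMQ socket, and the server replies with a pickled payload addressed by identity frame. Below is a self-contained sketch of that DEALER/ROUTER pattern; the inproc address, stdlib pickle (the real code uses cloudpickle), and single-shot flow are simplifications, not vLLM's actual wiring.

import asyncio
import pickle  # stand-in for cloudpickle, which the real code uses
from enum import Enum

import zmq
import zmq.asyncio


class RPCUtilityRequest(Enum):
    IS_TRACING_ENABLED = 9


async def serve_one(router: zmq.asyncio.Socket) -> None:
    # ROUTER receives [identity, payload] and must echo the identity
    # frame back so the reply reaches the right DEALER.
    identity, frame = await router.recv_multipart()
    if pickle.loads(frame) == RPCUtilityRequest.IS_TRACING_ENABLED:
        await router.send_multipart([identity, pickle.dumps(True)])


async def main() -> None:
    ctx = zmq.asyncio.Context()
    addr = "inproc://tracing-demo"
    router = ctx.socket(zmq.ROUTER)
    router.bind(addr)
    server = asyncio.create_task(serve_one(router))

    dealer = ctx.socket(zmq.DEALER)
    dealer.connect(addr)
    await dealer.send_multipart(
        [pickle.dumps(RPCUtilityRequest.IS_TRACING_ENABLED)])
    frame, = await dealer.recv_multipart()
    print(pickle.loads(frame))  # True

    await server
    dealer.close()
    router.close()
    ctx.term()


asyncio.run(main())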
28 changes: 21 additions & 7 deletions vllm/entrypoints/openai/rpc/client.py
@@ -13,7 +13,7 @@
                                  RPCGenerateRequest, RPCUtilityRequest)
 from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -34,6 +34,7 @@ async def setup(self):
         # Get the configs.
         self.model_config = await self._get_model_config_rpc()
         self.decoding_config = await self._get_decoding_config_rpc()
+        self.tracing_flag = await self._is_tracing_enabled_rpc()

         # Create the tokenizer group.
         # TODO: refactor OAI server to avoid needing this info.
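Note the design here: the flag is fetched once during setup() and cached, so later is_tracing_enabled() calls never touch the network. A toy, self-contained illustration of that fetch-once pattern (the class below is a stand-in, not the real client):

import asyncio


class TracingClient:
    """Toy stand-in for the RPC client: fetch once, serve from cache."""

    async def setup(self) -> None:
        # Imagine this awaits the IS_TRACING_ENABLED round trip shown above.
        self.tracing_flag = await self._is_tracing_enabled_rpc()

    async def _is_tracing_enabled_rpc(self) -> bool:
        await asyncio.sleep(0)  # placeholder for the network hop
        return True

    async def is_tracing_enabled(self) -> bool:
        return self.tracing_flag  # no RPC after setup()


async def main() -> None:
    client = TracingClient()
    await client.setup()
    print(await client.is_tracing_enabled())  # True


asyncio.run(main())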
@@ -102,15 +103,14 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
     async def get_tokenizer(self, lora_request: LoRARequest):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)

-    async def get_decoding_config(self):
+    async def get_decoding_config(self) -> DecodingConfig:
         return self.decoding_config

-    async def get_model_config(self):
+    async def get_model_config(self) -> ModelConfig:
         return self.model_config

-    async def is_tracing_enabled(self):
-        # TODO: what is this?
-        return False
+    async def is_tracing_enabled(self) -> bool:
+        return self.tracing_flag

     async def wait_for_server(self):
         """Wait for the RPCServer to start up."""
@@ -141,7 +141,7 @@ async def _get_parallel_config_rpc(self) -> ParallelConfig:
         return await self._send_get_data_rpc_request(
             RPCUtilityRequest.GET_PARALLEL_CONFIG,
             expected_type=ParallelConfig,
-            error_message="Could not get ModelConfig from RPC Server")
+            error_message="Could not get ParallelConfig from RPC Server")

     async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
         """Get SchedulerConfig from the RPCServer"""
@@ -159,6 +159,15 @@ async def _get_lora_config_rpc(self):
             expected_type=LoRAConfig,
             error_message="Could not get LoRAConfig from RPC Server")

+    async def _is_tracing_enabled_rpc(self) -> bool:
+        """Get the is_tracing_enabled flag from the RPCServer"""
+
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.IS_TRACING_ENABLED,
+            expected_type=bool,
+            error_message="Could not get is_tracing_enabled flag from RPC "
+            "Server")
+
     async def abort(self, request_id: str):
         """Send an ABORT_REQUEST signal to the RPC Server"""

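The `_send_get_data_rpc_request` helper itself is outside this diff. Based only on the call sites above, a plausible sketch of its contract follows; the socket handling and every name here are assumptions, not the actual implementation:

from typing import Any, Type, TypeVar

import cloudpickle
import zmq.asyncio

T = TypeVar("T")


async def send_get_data_rpc_request(socket: zmq.asyncio.Socket,
                                    request: Any,
                                    expected_type: Type[T],
                                    error_message: str) -> T:
    """Assumed shape of the client's get-data helper: send the pickled
    request enum, unpickle the reply, and type-check it so callers fail
    fast with a clear message instead of a bad attribute later."""
    await socket.send_multipart([cloudpickle.dumps(request)])
    frame, = await socket.recv_multipart()
    data = cloudpickle.loads(frame)
    if not isinstance(data, expected_type):
        raise ValueError(error_message)
    return data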
@@ -232,3 +241,8 @@ async def check_health(self) -> None:
         if health_message != VLLM_RPC_HEALTHY_STR:
             raise ValueError("Expected healthy response from backend but got "
                              f"{health_message}")
+
+    async def encode(self, *args,
+                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
+        raise NotImplementedError(
+            "Embeddings not supported with multiprocessing backend")
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/rpc/server.py
@@ -71,6 +71,13 @@ async def get_parallel_config(self, identity):
         await self.socket.send_multipart(
             [identity, cloudpickle.dumps(parallel_config)])

+    async def is_tracing_enabled(self, identity):
+        """Send the is_tracing_enabled flag"""
+        tracing_flag = await self.engine.is_tracing_enabled()
+
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(tracing_flag)])
+
     async def do_log_stats(self, identity):
         """Log stats and confirm success."""
         await self.engine.do_log_stats()
@@ -153,6 +160,8 @@ def _make_handler_coro(self, identity,
             return self.is_server_ready(identity)
         elif request == RPCUtilityRequest.CHECK_HEALTH:
             return self.check_health(identity)
+        elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
+            return self.is_tracing_enabled(identity)
         else:
             raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

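One way to sanity-check the new dispatch arm is with a stubbed handler. The sketch below is test scaffolding (AsyncMock stand-in, trimmed enum), not code from this commit:

import asyncio
from enum import Enum
from unittest.mock import AsyncMock


class RPCUtilityRequest(Enum):
    CHECK_HEALTH = 8
    IS_TRACING_ENABLED = 9


class StubServer:
    """Just the dispatch logic under test, with the handler mocked out."""

    def __init__(self) -> None:
        self.is_tracing_enabled = AsyncMock()

    def _make_handler_coro(self, identity, request):
        if request == RPCUtilityRequest.IS_TRACING_ENABLED:
            return self.is_tracing_enabled(identity)
        raise ValueError(f"Unknown RPCUtilityRequest type: {request}")


server = StubServer()
asyncio.run(server._make_handler_coro(b"client-1",
                                      RPCUtilityRequest.IS_TRACING_ENABLED))
server.is_tracing_enabled.assert_awaited_once_with(b"client-1")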