Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Use ray queue to put request output tokens back to the api server #41

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/bench_test.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: bench_test

on:
push:
branches:
- main
pull_request:
branches:
- main
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/e2e_test.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: e2e_test

on:
push:
branches:
- main
pull_request:
branches:
- main
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/migration_test.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: migration_test

on:
push:
branches:
- main
pull_request:
branches:
- main
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/offline_inference.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: offline_inference

on:
push:
branches:
- main
pull_request:
branches:
- main
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: pylint

on:
push:
branches:
- main
pull_request:
branches:
- main
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: unit_test

on:
push:
branches:
- main
pull_request:
branches:
- main
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/whl.yml → .github/workflows/whl_build.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: whl_build

on:
push:
branches:
- main
pull_request:
branches:
- main
Expand Down
5 changes: 3 additions & 2 deletions configs/base.yml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
SERVER:
HOST: '127.0.0.1'
PORT: 37000
QUEUE_TYPE: "rayqueue"

RAY:
RAY_CLUSTER_PORT: 30037
LAUNCH_RAY_CLUSTER: True

MANAGER:
DISABLE_FIXED_NODE_INIT_INSTANCE: False
DISABLE_INIT_INSTANCE_BY_MANAGER: False
DISABLE_FIXED_NODE_INIT_INSTANCE: True
DISABLE_INIT_INSTANCE_BY_MANAGER: True

INITIAL_INSTANCES: 1

Expand Down
20 changes: 5 additions & 15 deletions examlpes/offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@

from llumnix import launch_ray_cluster, connect_to_ray_cluster, init_manager, init_llumlets
from llumnix import (SamplingParams, ServerInfo, EngineManagerArgs, LLMEngineManager, Llumlet,
EngineArgs, RequestOutput)
EngineArgs)
from llumnix.utils import random_uuid
from llumnix.rpc.queue_server import QueueServer
from llumnix.rpc.queue_client import QueueClient
from llumnix.rpc.utils import get_open_zmq_ipc_path
from llumnix.output_queue.ray_queue_server import RayQueueServer
from llumnix.entrypoints.llumnix_utils import get_ip_address


# Sample prompts.
prompts = [
"Hello, my name is",
Expand Down Expand Up @@ -45,8 +42,7 @@
# Create llumlets.
llumlet_ids: List[str] = None
llumlets: List[Llumlet] = None
llumlet_ids, llumlets = init_llumlets(manager_args, engine_args,
node_id=ray.get_runtime_context().get_node_id())
llumlet_ids, llumlets = init_llumlets(manager_args, engine_args, ray.get_runtime_context().get_node_id(), "rayqueue")


# Create a manager. If the manager is created first, and then the llumlets are created, manager.scale_up
Expand All @@ -55,11 +51,8 @@

# The requests' outputs will be put to the request_output_queue no matter which instance it's running in.
server_id = random_uuid()
ip = get_ip_address()
port = 1234
server_info = ServerInfo(server_id, ip, port)
rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
request_output_queue = QueueServer(rpc_path)
request_output_queue = RayQueueServer()
server_info = ServerInfo(server_id, "rayqueue", request_output_queue, None, None)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
Expand Down Expand Up @@ -94,9 +87,6 @@ async def main():
for actor in named_actors:
try:
actor_handle = ray.get_actor(actor['name'], namespace=actor['namespace'])
except:
continue
try:
ray.kill(actor_handle)
except:
continue
Expand Down
7 changes: 4 additions & 3 deletions llumnix/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
from llumnix.backends.backend_interface import BackendInterface, BackendType


def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kwargs) -> BackendInterface:
def init_backend_engine(instance_id: str, output_queue_type: str,
backend_type: BackendType, *args, **kwargs) -> BackendInterface:
if backend_type == BackendType.VLLM:
# pylint: disable=import-outside-toplevel
from llumnix.backends.vllm.llm_engine import BackendVLLM
backend_engine = BackendVLLM(instance_id, *args, **kwargs)
backend_engine = BackendVLLM(instance_id, output_queue_type, *args, **kwargs)
elif backend_type == BackendType.SIM_VLLM:
# pylint: disable=import-outside-toplevel
from llumnix.backends.vllm.simulator import BackendSimVLLM
backend_engine = BackendSimVLLM(instance_id, *args, **kwargs)
backend_engine = BackendSimVLLM(instance_id, output_queue_type, *args, **kwargs)
else:
raise ValueError(f'Unsupported backend: {backend_type}')
return backend_engine
Expand Down
17 changes: 12 additions & 5 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@
from llumnix.backends.profiling import LatencyMemData
from llumnix.server_info import ServerInfo
from llumnix.internal_config import MigrationConfig
from llumnix.rpc.queue_client import QueueClient
from llumnix.output_queue.output_queue_client_base import OutputQueueClientBase
from llumnix.output_queue.utils import get_output_queue_client

logger = init_logger(__name__)


class AsyncPutQueueThread(threading.Thread):
def __init__(self, instance_id):
def __init__(self, instance_id, output_queue_type):
super().__init__()
self.instance_id = instance_id
self.request_output_queue_client = QueueClient()

self.request_output_queue_client: OutputQueueClientBase \
= get_output_queue_client(output_queue_type)
self.engine_actor_handle = None
self.loop = asyncio.new_event_loop()
self.daemon = True
Expand Down Expand Up @@ -82,20 +85,21 @@ def put_nowait_batch_to_servers(self,


class LLMEngineLlumnix(LLMEngine):
def __init__(self, instance_id: str, *arg, **kwargs) -> None:
def __init__(self, instance_id: str, output_queue_type: str, *arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.instance_id = instance_id
self.step_counter = Counter()
self.instance_info = None
# TODO(s5u13b): Reduce the overhead.
self.async_put_queue_thread = AsyncPutQueueThread(instance_id)
self.async_put_queue_thread = AsyncPutQueueThread(instance_id, output_queue_type)
self.async_put_queue_thread.start()

# pylint: disable=W0221
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
output_queue_type: str,
migration_config: MigrationConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
instance_id: str = None,
Expand Down Expand Up @@ -124,6 +128,7 @@ def from_engine_args(
# Create the LLM engine.
engine = cls(
instance_id=instance_id,
output_queue_type=output_queue_type,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
Expand Down Expand Up @@ -215,12 +220,14 @@ class BackendVLLM(BackendInterface):
def __init__(
self,
instance_id: str,
output_queue_type: str,
migration_config: MigrationConfig,
engine_args: EngineArgs,
placement_group: PlacementGroup = None,
node_id: str = None
) -> None:
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args,
output_queue_type=output_queue_type,
migration_config=migration_config,
instance_id=instance_id,
placement_group=placement_group,
Expand Down
9 changes: 7 additions & 2 deletions llumnix/backends/vllm/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import os
from typing import List
import ray.actor

from vllm.utils import Counter
from vllm.engine.arg_utils import EngineArgs
Expand All @@ -31,6 +32,7 @@ class BackendSimVLLM(BackendVLLM):
def __init__(
self,
instance_id: int,
output_queue_type: str,
migration_config: MigrationConfig,
profiling_result_file_path: str,
gpu_type: str,
Expand All @@ -54,12 +56,15 @@ def __init__(
latency_mem: LatencyMemData = profiling_result.para_dict[sim_parallel_config]

self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(migration_config=migration_config,
latency_mem=latency_mem, engine_args=engine_args)
output_queue_type=output_queue_type,
latency_mem=latency_mem,
engine_args=engine_args)
self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.output_processor.scheduler = self.engine.scheduler
self.migration_config = migration_config
self.instance_id = instance_id
self.step_counter = Counter()

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
# pylint: disable=unused-argument
def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
self.engine.model_executor.send_blocks(len(src_blocks))
2 changes: 2 additions & 0 deletions llumnix/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
_C.SERVER.HOST = "localhost"
# Port number for the server
_C.SERVER.PORT = 8000
# Queue type for request output queue
_C.SERVER.QUEUE_TYPE = "rayqueue"
# Port number for the request output queue
_C.SERVER.REQUEST_OUTPUT_QUEUE_PORT = 1234
# Path to SSL key file for secure connections
Expand Down
27 changes: 9 additions & 18 deletions llumnix/entrypoints/llumnix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
from llumnix.logger import init_logger
from llumnix.utils import random_uuid
from llumnix.arg_utils import EngineManagerArgs
from llumnix.rpc.utils import get_open_zmq_ipc_path
from llumnix.server_info import ServerInfo
from llumnix.rpc.queue_server import QueueServer


logger = init_logger(__name__)

Expand Down Expand Up @@ -131,9 +127,8 @@ def init_manager(engine_manager_args: EngineManagerArgs) -> LLMEngineManager:
logger.info("Get existing LLMEngineManager")
return engine_manager

def init_llumlets(engine_manager_args: EngineManagerArgs,
engine_args,
node_id: str) -> Tuple[List[str], List[Llumlet]]:
def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id: str,
output_queue_type: str) -> Tuple[List[str], List[Llumlet]]:
engine_config = engine_args.create_engine_config()
parallel_config = engine_config.parallel_config
instance_ids: List[str] = []
Expand All @@ -146,6 +141,7 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
instance_id = instance_ids[idx]
if not engine_manager_args.profiling_result_file_path:
llumlet = Llumlet.from_args(
output_queue_type,
engine_manager_args.disable_fixed_node_init_instance,
False,
node_id,
Expand All @@ -157,6 +153,7 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
)
else:
llumlet = Llumlet.from_args(
output_queue_type,
engine_manager_args.disable_fixed_node_init_instance,
False,
node_id,
Expand All @@ -171,22 +168,16 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
llumlets.append(llumlet)
return instance_ids, llumlets

def init_request_output_queue(server_info: ServerInfo) -> QueueServer:
rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
request_output_queue = QueueServer(rpc_path)
return request_output_queue

def init_llumnix_components(engine_manager_args: EngineManagerArgs,
engine_args,
node_id: str,
server_info: ServerInfo) -> Tuple[LLMEngineManager, List[Llumlet], QueueServer]:
request_output_queue = init_request_output_queue(server_info)

output_queue_type: str):
engine_manager = init_manager(engine_manager_args)
if engine_manager_args.disable_init_instance_by_manager:
instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id)
instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id, output_queue_type)
else:
instance_ids, llumlets = retry_manager_method_sync(engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id)
instance_ids, llumlets = retry_manager_method_sync(
engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id, output_queue_type)

available_instance_ids = []
dead_instance_ids = []
Expand All @@ -211,4 +202,4 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs,
logger.info("Init Llumnix components done, {} instances are ready, instance_ids: {}."
.format(len(available_instance_ids), available_instance_ids))

return engine_manager, available_instance_ids, available_llumlets, request_output_queue
return engine_manager, available_instance_ids, available_llumlets
16 changes: 10 additions & 6 deletions llumnix/entrypoints/vllm/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from llumnix.logger import init_logger
from llumnix.utils import random_uuid
from llumnix.backends.vllm.utils import check_engine_args
from llumnix.rpc.queue_server import QueueServer
from llumnix.output_queue.output_queue_server_base import OutputQueueServerBase
from llumnix.output_queue.utils import get_output_queue_server
from llumnix.config import get_llumnix_config, LlumnixConfig

logger = init_logger("llumnix.api_server")
Expand All @@ -43,7 +44,7 @@
instances = {}
instance_num_requests: Dict[str, int] = {}
# request_output_queue could be None if initialized in lifespan.
request_output_queue: QueueServer = None
request_output_queue: OutputQueueServerBase = None
server_info = None
TIMEOUT_KEEP_ALIVE = 5 # seconds.
request_streams: Dict[str, AsyncStream] = {}
Expand Down Expand Up @@ -250,7 +251,8 @@ def add_argument(self, *args, **kwargs):
parser.add_argument('--disable-log-requests-server', action='store_true', help='disable logging requests in server')
parser.add_argument("--ray-cluster-port", type=int)
parser.add_argument('--launch-ray-cluster', action='store_true', help='if launch ray cluster in api server')
parser.add_argument("--request-output-queue-port", type=int)
parser.add_argument("--queue-type", type=str, choices=['rayqueue', 'zmq'], help='queue type for request output queue')
parser.add_argument("--request-output-queue-port", type=int, help='port for zeromq')
parser.add_argument("--config-file", help="path to config file")
parser = EngineManagerArgs.add_cli_args(parser)

Expand Down Expand Up @@ -278,10 +280,12 @@ def add_argument(self, *args, **kwargs):
# Launch the Llumnix components on current node.
server_id = random_uuid()
ip = get_ip_address()
server_info = ServerInfo(server_id, ip, cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT)
node_id = ray.get_runtime_context().get_node_id()
engine_manager, instance_ids, llumlets, request_output_queue = \
init_llumnix_components(engine_manager_args, engine_args, node_id, server_info)
engine_manager, instance_ids, llumlets = \
init_llumnix_components(engine_manager_args, engine_args, node_id, cfg.SERVER.QUEUE_TYPE)
request_output_queue = get_output_queue_server(ip, cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT, cfg.SERVER.QUEUE_TYPE)
server_info = ServerInfo(server_id, cfg.SERVER.QUEUE_TYPE, request_output_queue, ip,
cfg.SERVER.REQUEST_OUTPUT_QUEUE_PORT)

for idx, ins_id in enumerate(instance_ids):
instances[ins_id] = llumlets[idx]
Expand Down
Loading