Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyi-ECNU committed Sep 5, 2024
1 parent 26f4d6a commit ab199a4
Show file tree
Hide file tree
Showing 11 changed files with 21 additions and 13 deletions.
1 change: 1 addition & 0 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def create_global_scheduler_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:

# Provide default configuration.
config_data = get_cfg()
if self.config_file:
config_data.merge_from_file(self.config_file)
Expand Down
6 changes: 4 additions & 2 deletions llumnix/backends/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ def add_request(self, request_id: str, server_info: ServerInfo, request_expected
Args:
request_id: Request ID.
server_info: The information of the api server where the request come.
request_expected_steps: The expected number of steps for the request to run.The number of steps
request_expected_steps: The expected number of steps for the request to run. The number of steps
represents the sum of the times 'engine.step()' has been called by the
backend instances for the request.
backend instances for the request. Currently, `request_expected_steps`
is used to implement prefill-decoding disaggregation. For prefill requests,
`request_expected_steps` is set to 1.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
Expand Down
4 changes: 2 additions & 2 deletions llumnix/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from llumnix.backends.backend_interface import BackendInterface, BackendType


def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kwargs) -> BackendInterface:
def init_backend_engine(instance_id: str, backend_type: BackendType, strict_pre_migration: bool, *args, **kwargs) -> BackendInterface:
if backend_type == BackendType.VLLM:
# pylint: disable=import-outside-toplevel
from llumnix.backends.vllm.llm_engine import BackendVLLM
backend_engine = BackendVLLM(instance_id, *args, **kwargs)
backend_engine = BackendVLLM(instance_id, strict_pre_migration, *args, **kwargs)
elif backend_type == BackendType.SIM_VLLM:
# pylint: disable=import-outside-toplevel
from llumnix.backends.vllm.simulator import BackendSimVLLM
Expand Down
5 changes: 3 additions & 2 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def _process_model_outputs(
for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
seq_group = scheduled_seq_group.seq_group
if seq_group.get_seqs(SequenceStatus.RUNNING):
# print(seq_group)
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
Expand Down Expand Up @@ -178,6 +177,7 @@ class BackendVLLM(BackendInterface):
def __init__(
self,
instance_id: str,
strict_pre_migration: bool,
migration_config: MigrationConfig,
engine_args: EngineArgs,
placement_group: "PlacementGroup" = None,
Expand All @@ -189,7 +189,8 @@ def __init__(
placement_group=placement_group,
node_id=node_id)
# multi-instance args
self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.scheduler = SchedulerLlumnix(strict_pre_migration, self.engine.scheduler_config,
self.engine.cache_config, self.engine.lora_config)
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
Expand Down
4 changes: 2 additions & 2 deletions llumnix/backends/vllm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def add_block_table(self, block_table: BlockTable, seq_id: int) -> None:
self.block_tables[seq_id] = block_table.copy()

class SchedulerLlumnix(Scheduler):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, strict_pre_migration: bool, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.block_manager = BlockManagerLlumnix(
block_size=self.cache_config.block_size,
Expand All @@ -56,7 +56,7 @@ def __init__(self, *args, **kwargs) -> None:
self.pre_alloc_cache_dict: Dict[str, BlockTable] = {}
self.scheduler_lock = threading.Lock()
self.migrating_out_request_last_stage: List[LlumnixRequest] = []
self.strict_pre_migration = True
self.strict_pre_migration = strict_pre_migration

def add_update_instance_info_callback(self, update_instance_info_callback):
self.update_instance_info_callback = update_instance_info_callback
Expand Down
2 changes: 0 additions & 2 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,6 @@ async def _migrate_control(self) -> None:

async def _migrate(self, migration_target:str, migrate_in_num_requests:int) -> None:
migrate_instance_pairs = self.global_scheduler.pair_migration(migration_target)
# if len(migrate_instance_pairs)>0:
# logger.info("[_migrate] migrate_instance_pairs {} {}".format(migration_target, migrate_instance_pairs))
try:
migration_tasks = []
call_migrate_instance_pairs: List[Tuple[str, str]] = []
Expand Down
1 change: 1 addition & 0 deletions llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self,

self.backend_engine: BackendInterface = init_backend_engine(self.instance_id,
backend_type,
self.strict_pre_migration,
migration_config,
*args,
**kwargs)
Expand Down
1 change: 1 addition & 0 deletions llumnix/llumlet/local_migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, request_migration_policy: str, backend_engine: BackendInterfa
self.request_migration_policy = request_migration_policy
self.backend_engine = backend_engine
self.strict_pre_migration = strict_pre_migration

def get_migrate_out_request(self, min_request_len=0, max_request_len=np.inf) -> Optional[LlumnixRequest]:
# Requests that meet the strict pre-migration condition always have higher priority than requests selected by any other migration policy.
migrate_out_request = self.get_ready_migration_request(min_request_len, max_request_len)
Expand Down
2 changes: 2 additions & 0 deletions llumnix/llumlet/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def prompt_len(self) -> int:
def output_len(self) -> int:
raise NotImplementedError

# Whether the migration of this request is divided into multiple stages.
# Per the condition below, migration is blocking when `expected_steps` is negative
# (presumably meaning the request has no step limit — confirm against callers), or
# when the request's output length has not yet reached its expected number of steps.
# For requests that have already reached the expected steps, the migration will be
# completed within one stage.
@property
def blocking_migration(self) -> bool:
    return self.expected_steps < 0 or (self.expected_steps > 0 and self.expected_steps < self.output_len)
Expand Down
2 changes: 1 addition & 1 deletion tests/backends/vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def initialize_scheduler(*,
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
scheduler = SchedulerLlumnixTest(scheduler_config, cache_config, lora_config)
scheduler = SchedulerLlumnixTest(True, scheduler_config, cache_config, lora_config)
scheduler.update_instance_info_callback = MagicMock()
return scheduler

Expand Down
6 changes: 4 additions & 2 deletions tests/global_scheduler/test_dispatch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def test_dispatch_load():
instance_num_requests[instance_id] = 0
dispatch_scheduler.instance_num_requests = instance_num_requests
dispatch_scheduler.instance_info = instance_info_dict
available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict if key in dispatch_scheduler.available_dispatch_instance_set}
available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict
if key in dispatch_scheduler.available_dispatch_instance_set}
min_instance_id = next(key for key, value in sorted(available_instance_dict.items(),
key=lambda item: item[1].instance_load_dispatch_scale))
instance_id = dispatch_scheduler.dispatch()
Expand All @@ -99,7 +100,8 @@ def test_dispatch_queue():
instance_num_requests[instance_id] = 0
dispatch_scheduler.instance_num_requests = instance_num_requests
dispatch_scheduler.instance_info = instance_info_dict
available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict if key in dispatch_scheduler.available_dispatch_instance_set}
available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict
if key in dispatch_scheduler.available_dispatch_instance_set}
min_instance_id = next(key for key, value in sorted(available_instance_dict.items(),
key=lambda item: item[1].num_waiting_requests))
instance_id = dispatch_scheduler.dispatch()
Expand Down

0 comments on commit ab199a4

Please sign in to comment.