From ab199a4445900936fc47f94ba3003f39ef5ee9fd Mon Sep 17 00:00:00 2001 From: "zhizhi.zxy" Date: Thu, 5 Sep 2024 20:09:32 +0800 Subject: [PATCH] fix --- llumnix/arg_utils.py | 1 + llumnix/backends/backend_interface.py | 6 ++++-- llumnix/backends/utils.py | 4 ++-- llumnix/backends/vllm/llm_engine.py | 5 +++-- llumnix/backends/vllm/scheduler.py | 4 ++-- llumnix/llm_engine_manager.py | 2 -- llumnix/llumlet/llumlet.py | 1 + llumnix/llumlet/local_migration_scheduler.py | 1 + llumnix/llumlet/request.py | 2 ++ tests/backends/vllm/utils.py | 2 +- tests/global_scheduler/test_dispatch_scheduler.py | 6 ++++-- 11 files changed, 21 insertions(+), 13 deletions(-) diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 3d2569f..dc5575c 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -67,6 +67,7 @@ def create_global_scheduler_configs( self, ) -> Tuple[GlobalSchedulerConfig]: + # Provide default configuration. config_data = get_cfg() if self.config_file: config_data.merge_from_file(self.config_file) diff --git a/llumnix/backends/backend_interface.py b/llumnix/backends/backend_interface.py index bf33b6f..16b2669 100644 --- a/llumnix/backends/backend_interface.py +++ b/llumnix/backends/backend_interface.py @@ -42,9 +42,11 @@ def add_request(self, request_id: str, server_info: ServerInfo, request_expected Args: request_id: Request ID. server_info: The information of the api server where the request come. - request_expected_steps: The expected number of steps for the request to run.The number of steps + request_expected_steps: The expected number of steps for the request to run. The number of steps represents the sum of the times 'engine.step()' has been called by the - backend instances for the request. + backend instances for the request. Currently, `request_expected_steps` + is used to implement prefill-decoding disaggregation. For prefill requests, + `request_expected_steps` is set to 1. *args: Positional arguments that represent request-specific data. **kwargs: Keyword arguments that contain metadata of the backend request (request_id, arrival_time, etc.). diff --git a/llumnix/backends/utils.py b/llumnix/backends/utils.py index 1e110f4..fa9ce2b 100644 --- a/llumnix/backends/utils.py +++ b/llumnix/backends/utils.py @@ -19,11 +19,11 @@ from llumnix.backends.backend_interface import BackendInterface, BackendType -def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kwargs) -> BackendInterface: +def init_backend_engine(instance_id: str, backend_type: BackendType, strict_pre_migration: bool, *args, **kwargs) -> BackendInterface: if backend_type == BackendType.VLLM: # pylint: disable=import-outside-toplevel from llumnix.backends.vllm.llm_engine import BackendVLLM - backend_engine = BackendVLLM(instance_id, *args, **kwargs) + backend_engine = BackendVLLM(instance_id, strict_pre_migration, *args, **kwargs) elif backend_type == BackendType.SIM_VLLM: # pylint: disable=import-outside-toplevel from llumnix.backends.vllm.simulator import BackendSimVLLM diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 8029bc3..a5e11e7 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -104,7 +104,6 @@ def _process_model_outputs( for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs): seq_group = scheduled_seq_group.seq_group if seq_group.get_seqs(SequenceStatus.RUNNING): - # print(seq_group) new_scheduled_seq_groups.append(scheduled_seq_group) new_seq_group_metadata_list.append(seq_group_meta) new_output.append(seq_group_output) @@ -178,6 +177,7 @@ class BackendVLLM(BackendInterface): def __init__( self, instance_id: str, + strict_pre_migration: bool, migration_config: MigrationConfig, engine_args: EngineArgs, placement_group: "PlacementGroup" = None, @@ -189,7 +189,8 @@ def __init__( placement_group=placement_group, node_id=node_id) # multi-instance args - self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config) + self.engine.scheduler = SchedulerLlumnix(strict_pre_migration, self.engine.scheduler_config, + self.engine.cache_config, self.engine.lora_config) self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info) self.engine.output_processor.scheduler = self.engine.scheduler self.instance_id = instance_id diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index ba90c09..71f236e 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -45,7 +45,7 @@ def add_block_table(self, block_table: BlockTable, seq_id: int) -> None: self.block_tables[seq_id] = block_table.copy() class SchedulerLlumnix(Scheduler): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, strict_pre_migration: bool, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.block_manager = BlockManagerLlumnix( block_size=self.cache_config.block_size, @@ -56,7 +56,7 @@ def __init__(self, *args, **kwargs) -> None: self.pre_alloc_cache_dict: Dict[str, BlockTable] = {} self.scheduler_lock = threading.Lock() self.migrating_out_request_last_stage: List[LlumnixRequest] = [] - self.strict_pre_migration = True + self.strict_pre_migration = strict_pre_migration def add_update_instance_info_callback(self, update_instance_info_callback): self.update_instance_info_callback = update_instance_info_callback diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index c56731f..68521b9 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -247,8 +247,6 @@ async def _migrate_control(self) -> None: async def _migrate(self, migration_target:str, migrate_in_num_requests:int) -> None: migrate_instance_pairs = self.global_scheduler.pair_migration(migration_target) - # if len(migrate_instance_pairs)>0: - # logger.info("[_migrate] migrate_instance_pairs {} {}".format(migration_target, migrate_instance_pairs)) try: migration_tasks = [] call_migrate_instance_pairs: List[Tuple[str, str]] = [] diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 74d73d0..6ca339d 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -42,6 +42,7 @@ def __init__(self, self.backend_engine: BackendInterface = init_backend_engine(self.instance_id, backend_type, + self.strict_pre_migration, migration_config, *args, **kwargs) diff --git a/llumnix/llumlet/local_migration_scheduler.py b/llumnix/llumlet/local_migration_scheduler.py index ac65d16..ada8347 100644 --- a/llumnix/llumlet/local_migration_scheduler.py +++ b/llumnix/llumlet/local_migration_scheduler.py @@ -22,6 +22,7 @@ def __init__(self, request_migration_policy: str, backend_engine: BackendInterfa self.request_migration_policy = request_migration_policy self.backend_engine = backend_engine self.strict_pre_migration = strict_pre_migration + def get_migrate_out_request(self, min_request_len=0, max_request_len=np.inf) -> Optional[LlumnixRequest]: # Requests meet the strict pre-migration always have higher prioirity than other migration policy. migrate_out_request = self.get_ready_migration_request(min_request_len, max_request_len) diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index 03e326f..6e5eea0 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -55,6 +55,8 @@ def prompt_len(self) -> int: def output_len(self) -> int: raise NotImplementedError + # Whether the migration of request is divided into multiple stages. For requests that have already reached + # the expected steps, the migration will completed within one stage. @property def blocking_migration(self) -> bool: return self.expected_steps < 0 or (self.expected_steps > 0 and self.expected_steps < self.output_len) diff --git a/tests/backends/vllm/utils.py b/tests/backends/vllm/utils.py index 5aa2af4..9f8d306 100644 --- a/tests/backends/vllm/utils.py +++ b/tests/backends/vllm/utils.py @@ -43,7 +43,7 @@ def initialize_scheduler(*, cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 - scheduler = SchedulerLlumnixTest(scheduler_config, cache_config, lora_config) + scheduler = SchedulerLlumnixTest(True, scheduler_config, cache_config, lora_config) scheduler.update_instance_info_callback = MagicMock() return scheduler diff --git a/tests/global_scheduler/test_dispatch_scheduler.py b/tests/global_scheduler/test_dispatch_scheduler.py index 87b7975..d5c0d56 100644 --- a/tests/global_scheduler/test_dispatch_scheduler.py +++ b/tests/global_scheduler/test_dispatch_scheduler.py @@ -76,7 +76,8 @@ def test_dispatch_load(): instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict - available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict if key in dispatch_scheduler.available_dispatch_instance_set} + available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict + if key in dispatch_scheduler.available_dispatch_instance_set} min_instance_id = next(key for key, value in sorted(available_instance_dict.items(), key=lambda item: item[1].instance_load_dispatch_scale)) instance_id = dispatch_scheduler.dispatch() @@ -99,7 +100,8 @@ def test_dispatch_queue(): instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict - available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict if key in dispatch_scheduler.available_dispatch_instance_set} + available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict + if key in dispatch_scheduler.available_dispatch_instance_set} min_instance_id = next(key for key, value in sorted(available_instance_dict.items(), key=lambda item: item[1].num_waiting_requests)) instance_id = dispatch_scheduler.dispatch()