From 92b3b9bfd11d154b985249e0e43cf52f975ef5c7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 24 Oct 2024 00:16:44 -0700 Subject: [PATCH] [core] simplify seq group code (#9569) Co-authored-by: Zhuohan Li Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- tests/core/test_chunked_prefill_scheduler.py | 153 -------------- tests/core/test_scheduler.py | 204 +------------------ vllm/core/scheduler.py | 2 +- vllm/engine/llm_engine.py | 40 ++-- vllm/engine/output_processor/single_step.py | 127 ++---------- vllm/sequence.py | 102 ++-------- 6 files changed, 62 insertions(+), 566 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 308dad1850c9a..acd82065ae457 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -4,7 +4,6 @@ import pytest # noqa from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup @@ -347,158 +346,6 @@ def test_prompt_limit_exceed(): assert out.ignored_seq_groups[0] == seq_group -def test_swap(): - """Verify swapping works with chunked prefill requests""" - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - # The last request should be swapped out. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - # The running prefill is now swapped. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 0 - assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != [] - assert out.blocks_to_swap_in == [] - - # Add 1 more task. Swap should be prioritized over new prefill. - _, seq_group = create_dummy_prompt("2", prompt_length=60) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. 
- assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != [] - assert out.blocks_to_swap_out == [] - - -def test_running_prefill_prioritized_over_swap(): - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - # The request should be swapped out. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - # The running prefill is now swapped. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 0 - assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != [] - assert out.blocks_to_swap_in == [] - - # Add 1 more task. Swap is not possible, so prefill is running. - scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER - - _, seq_group2 = create_dummy_prompt("2", - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group2) - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == [] - assert out.blocks_to_swap_out == [] - assert out.scheduled_seq_groups[0].seq_group == seq_group2 - - # Now although swap is possible, running prefill is prioritized. - scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == [] - assert out.blocks_to_swap_out == [] - assert not seq_group2.is_prefill() - assert out.scheduled_seq_groups[0].seq_group == seq_group2 - append_new_token(seq_group2, 1) - - # Decoding is prioritized. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 1 - assert out.blocks_to_swap_in == [] - assert out.blocks_to_swap_out == [] - assert not seq_group2.is_prefill() - assert out.scheduled_seq_groups[0].seq_group == seq_group2 - append_new_token(seq_group2, 1) - - # Since we abort the sequence group, we can finally swap. 
- scheduler.abort_seq_group(seq_group2.request_id) - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != [] - assert out.blocks_to_swap_out == [] - - def test_chunked_prefill_preempt(): """Verify preempt works with chunked prefill requests""" block_size = 4 diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 00b6349b9f8c5..5ff32be611592 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,7 @@ from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup, SequenceStatus +from vllm.sequence import SequenceGroup from .utils import (append_new_token, append_new_token_seq_group, create_dummy_prompt, get_sequence_groups, @@ -296,55 +296,6 @@ def test_scheduler_delay_factor(): append_new_token(out, 1) -def test_swapped_out_prioritized(): - block_size = 4 - scheduler = initialize_scheduler(max_num_seqs=6, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - # best_of=2 * 3 == 6 sequences. - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 3 - append_new_token(out, 1) - - # The last request should be swapped out. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "2" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 2 - assert out.num_batched_tokens == 2 - assert out.blocks_to_swap_out != [] - assert out.blocks_to_swap_in == [] - append_new_token(out, 1) - - # Add 1 more task. Swap should be prioritized over prefill. - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - append_new_token(out, 1) - assert len(out.scheduled_seq_groups) == 3 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 3 - assert out.blocks_to_swap_in != [] - assert out.blocks_to_swap_out == [] - - def initialize_scheduler( *, max_num_seqs=1000, @@ -646,60 +597,6 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -def test_decode_swap_beam_search(): - """ - Test best_of > 1 swap out blocks - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_gpu_blocks=64, - num_cpu_blocks=64) - curr_loras = None - budget = create_token_budget() - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - scheduler._add_seq_group_to_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) - budget.add_num_batched_tokens( - seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING)) - - # The last request should be swapped out. 
- scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "2" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - scheduler.block_manager.swap_out = MagicMock() - expected_swap_mapping = [("5", "7")] - scheduler.block_manager.swap_out.return_value = expected_swap_mapping - - output = scheduler._schedule_running(budget, curr_loras) - remainig_running = scheduler.running - assert len(remainig_running) == 0 - assert len(output.decode_seq_groups) == 2 - assert len(output.prefill_seq_groups) == 0 - assert output.decode_seq_groups[0].seq_group.request_id == "0" - assert output.decode_seq_groups[1].seq_group.request_id == "1" - assert len(output.preempted) == 0 - assert len(output.swapped_out) == 1 - # Budget should refledct preempted requests. - assert budget.num_batched_tokens == 2 - # since there are 2 sequences, 2 should be subtracted. - assert budget.num_curr_seqs == 4 - # Both should be preempted, not swapped. - assert output.blocks_to_swap_out == expected_swap_mapping - # Nothing is copied. - assert output.blocks_to_copy == [] - - def test_schedule_decode_blocks_to_copy_update(): """ Verify blocks_to_copy is updated. @@ -736,105 +633,6 @@ def test_schedule_decode_blocks_to_copy_update(): assert output.blocks_to_copy == [(2, 3)] -def test_schedule_swapped_simple(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size) - curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] - _, seq_group = create_dummy_prompt("1", - prompt_length=4, - best_of=2, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(4, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 0 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - # swap in is the reverse of swap out - blocks_to_swap_in_reverse = [] - for swapin, swapout in output.blocks_to_swap_in: - blocks_to_swap_in_reverse.append((swapout, swapin)) - assert blocks_to_swap_out == blocks_to_swap_in_reverse - - -def test_schedule_swapped_max_token_budget(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget(token_budget=1) - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 1 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - - # Verify num_batched_tokens are respected. 
- budget = create_token_budget(token_budget=1) - add_token_budget(budget, 1, 0) - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 1 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 0 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - -def test_schedule_swapped_max_seqs(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] - for i in range(4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=4) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget(max_num_seqs=2) - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 2 - assert budget.num_batched_tokens == 2 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 2 - assert len(output.prefill_seq_groups) == 0 - - # Verify num_curr_seqs are respected. - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 2 - assert budget.num_batched_tokens == 2 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - def test_schedule_swapped_max_loras(): block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8d3fce106dd2c..88733b8f53b86 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -290,7 +290,7 @@ def scheduler_running_outputs_builder(): def scheduled_seq_group_builder(): - return ScheduledSequenceGroup(SequenceGroup("", [], -1), + return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), token_chunk_size=0) # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0d73ed7c8e7ab..1dd0f097c74ff 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -647,10 +647,24 @@ def _add_processed_request( prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> SequenceGroup: + ) -> Optional[SequenceGroup]: """Add a processed request to the engine's request pool. return the created sequence group. """ + if isinstance(params, SamplingParams) and params.n > 1: + ParallelSampleSequenceGroup.add_request( + request_id, + self, + params, + processed_inputs=processed_inputs, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + return None + self._validate_model_inputs(processed_inputs) # Create the sequences. block_size = self.cache_config.block_size @@ -721,7 +735,7 @@ def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> Optional[SequenceGroup]: + ) -> None: ... 
@overload @@ -735,7 +749,7 @@ def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> Optional[SequenceGroup]: + ) -> None: ... @deprecate_kwargs( @@ -754,7 +768,7 @@ def add_request( priority: int = 0, *, inputs: Optional[PromptType] = None, # DEPRECATED - ) -> Optional[SequenceGroup]: + ) -> None: """Add a request to the engine's request pool. The request is added to the request pool and will be processed by the @@ -798,22 +812,6 @@ def add_request( >>> # continue the request processing >>> ... """ - - if isinstance(params, SamplingParams) and params.n > 1: - ParallelSampleSequenceGroup.add_request( - request_id, - self, - params, - prompt=prompt, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - inputs=inputs, - ) - return None - if inputs is not None: prompt = inputs assert prompt is not None and params is not None @@ -844,7 +842,7 @@ def add_request( processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( "mm_processor_kwargs") - return self._add_processed_request( + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, params=params, diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 9f8ebaf1f4d8c..da3185f33dbe9 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -6,9 +6,8 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sequence import (CompletionSequenceGroupOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup, + SequenceGroupOutput) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -114,104 +113,22 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput, is_async: bool) -> None: sampling_params = seq_group.sampling_params - if sampling_params.n == 1: - # only have one output sample - sample = outputs.samples[0] - # only have one sequence - seq = seq_group.seqs[0] - if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs) - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) - else: - new_char_count = 0 - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count, - sampling_params, - lora_req=seq_group.lora_request, - ) - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) - return - - # TODO: Add support for async for beam search - assert not is_async - - # Process samples - samples = outputs.samples - parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - parent_child_dict: Dict[int, List[SequenceOutput]] = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } - for sample in samples: - # Guard against a KeyError which can occur if the request was - # aborted while the output was generated - if (child_list := - parent_child_dict.get(sample.parent_seq_id)) is not None: - child_list.append(sample) - # List of 
(child, parent) - child_seqs: List[Tuple[Sequence, Sequence]] = [] - - # Process the child samples for each parent sequence - for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] - if len(child_samples) == 0: - # This parent sequence has no children samples. Remove - # the parent sequence from the sequence group since it will - # not be used in the future iterations. - parent.status = SequenceStatus.FINISHED_ABORTED - seq_group.remove(parent.seq_id) - for scheduler in self.scheduler: - scheduler.free_seq(parent) - continue - # Fork the parent sequence if there are multiple child samples. - for child_sample in child_samples[:-1]: - new_child_seq_id: int = next(self.seq_counter) - child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) - child_seqs.append((child, parent)) - # Continue the parent sequence for the last child sample. - # We reuse the parent sequence here to reduce redundant memory - # copies, especially when using non-beam search sampling methods. - last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) - child_seqs.append((parent, parent)) - - for seq, _ in child_seqs: - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) - else: - new_char_count = 0 - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count, - sampling_params, - lora_req=seq_group.lora_request, - ) - - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - for scheduler in self.scheduler: - scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - # NOTE: we need to fork the new sequences before freeing the - # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) - return + + sample = outputs.samples[0] + seq = seq_group.first_seq + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) diff --git a/vllm/sequence.py b/vllm/sequence.py index 93f58f00ef77b..fc936fbab0ea7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -681,6 +681,7 @@ def __init__( ) -> None: self.request_id = request_id self.seqs = seqs + self.first_seq = seqs[0] self.arrival_time = arrival_time self.is_single_seq = len(seqs) == 1 self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -705,15 +706,11 @@ def __init__( @property def prompt(self) -> Optional[str]: - # All sequences in the group should have the same prompt. - # We use the prompt of an arbitrary sequence. - return self.seqs[0].prompt + return self.first_seq.prompt @property def prompt_token_ids(self) -> List[int]: - # All sequences in the group should have the same prompt. - # We use the prompt of an arbitrary sequence. 
- return self.seqs[0].prompt_token_ids + return self.first_seq.prompt_token_ids @property def encoder_prompt(self) -> Optional[str]: @@ -733,17 +730,11 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: @property def multi_modal_data(self) -> "MultiModalDataDict": - # All sequences in the group should have the same multi-modal data. - # We use the multi-modal data of an arbitrary sequence. - return self.seqs[0].multi_modal_data + return self.first_seq.multi_modal_data @property def mm_processor_kwargs(self) -> Dict[str, Any]: - # As with multi-modal data, all sequences in the group should have the - # same processor kwargs (i.e., mm_processor_kwargs are optionally - # provided per request; note that are independent of whether the model - # decoder-only or an encoder-decoder). - return self.seqs[0].mm_processor_kwargs + return self.first_seq.mm_processor_kwargs @property def lora_int_id(self) -> int: @@ -808,7 +799,7 @@ def maybe_set_first_token_time(self, time: float) -> None: # in TPOT, rather than recalculating TTFT (since from the ) # POV of the user, there is simply a long generation delay. if (self.metrics.first_token_time is None - and self.seqs[0].get_output_len() == 1): + and self.first_seq.get_output_len() == 1): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: @@ -825,18 +816,7 @@ def set_finished_time(self, time: Optional[float]) -> None: def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" - if self.sampling_params: - n = self.sampling_params.n - assert isinstance(n, int) - if n > self.num_seqs(): - # At prompt stage, the sequence group is not yet filled up - # and only have one sequence running. However, in the - # generation stage, we will have `n` sequences - # running. - return n - # At sampling stages, return the number of actual sequences - # that are not finished yet. 
- return self.num_unfinished_seqs() + return 0 if self.first_seq.is_finished() else 1 def get_seqs( self, @@ -845,10 +825,7 @@ def get_seqs( if status is None: return self.seqs - if self.is_single_seq: - return self.seqs if self.seqs[0].status == status else [] - - return [seq for seq in self.seqs if seq.status == status] + return self.seqs if self.first_seq.status == status else [] def is_encoder_decoder(self) -> bool: return self.encoder_seq is not None @@ -856,29 +833,20 @@ def is_encoder_decoder(self) -> bool: def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq - def get_unfinished_seqs(self) -> List[Sequence]: - if self.is_single_seq: - return self.seqs if not self.seqs[0].is_finished() else [] - - return [seq for seq in self.seqs if not seq.is_finished()] - def get_finished_seqs(self) -> List[Sequence]: - if self.is_single_seq: - return self.seqs if self.seqs[0].is_finished() else [] - - return [seq for seq in self.seqs if seq.is_finished()] + return self.seqs if self.first_seq.is_finished() else [] def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" - for seq in self.seqs: - if not seq.is_finished(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) + seq = self.first_seq + if not seq.is_finished(): + seq.data.update_num_computed_tokens(num_new_computed_tokens) def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 - for seq in self.seqs: - if not seq.is_finished(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + seq = self.first_seq + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: @@ -892,46 +860,14 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) - def num_unfinished_seqs(self) -> int: - if self.is_single_seq: - return 1 if not self.seqs[0].is_finished() else 0 - - return len(self.get_unfinished_seqs()) - def num_finished_seqs(self) -> int: - if self.is_single_seq: - return 1 if self.seqs[0].is_finished() else 0 - - return len(self.get_finished_seqs()) - - def find(self, seq_id: int) -> Sequence: - if seq_id not in self.seqs_dict: - raise ValueError(f"Sequence {seq_id} not found.") - return self.seqs_dict[seq_id] - - def add(self, seq: Sequence) -> None: - if seq.seq_id in self.seqs_dict: - raise ValueError(f"Sequence {seq.seq_id} already exists.") - self.seqs_dict[seq.seq_id] = seq - self.seqs.append(seq) - self.is_single_seq = len(self.seqs) == 1 - - def remove(self, seq_id: int) -> None: - seq = self.seqs_dict.pop(seq_id, None) - if seq is None: - raise ValueError(f"Sequence {seq_id} not found.") - self.seqs.remove(seq) - self.is_single_seq = len(self.seqs) == 1 + return 1 if self.first_seq.is_finished() else 0 def is_finished(self) -> bool: - if self.is_single_seq: - return self.seqs[0].is_finished() - - return all(seq.is_finished() for seq in self.seqs) + return self.first_seq.is_finished() def is_prefill(self) -> bool: - # Every sequence should be in the same stage. 
- return self.seqs[0].is_prefill() + return self.first_seq.is_prefill() def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " @@ -1455,7 +1391,7 @@ def add_request(request_id: str, engine, params, **kwargs): for i in range(original_params.n): request_id_i = f"{request_id}_parallel_sample_{i}" group.seq_id_to_index[request_id_i] = i - seq_group = engine.add_request( + seq_group = engine._add_processed_request( request_id_i, params=params, **kwargs,
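
The net effect of the vllm/sequence.py hunks above is that a SequenceGroup now wraps exactly one sequence (exposed as first_seq), so the scheduler-facing queries (get_max_num_running_seqs, get_seqs, is_finished, is_prefill, ...) collapse to trivial delegation. The following is a minimal, self-contained sketch of that invariant using hypothetical Toy* names; it is not the real vLLM classes, which keep their existing constructors and many more fields.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToySequence:
    # Hypothetical stand-in for vllm.sequence.Sequence.
    seq_id: int
    status: str = "RUNNING"  # stand-in for SequenceStatus
    output_token_ids: List[int] = field(default_factory=list)

    def is_finished(self) -> bool:
        return self.status.startswith("FINISHED")

    def is_prefill(self) -> bool:
        # A sequence stays in prefill until it has produced output tokens.
        return len(self.output_token_ids) == 0


class ToySeqGroup:
    """One request == one sequence, so every query delegates to first_seq."""

    def __init__(self, request_id: str, seqs: List[ToySequence]) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]

    def get_max_num_running_seqs(self) -> int:
        # Previously n-aware logic over all unfinished sequences; now 0 or 1.
        return 0 if self.first_seq.is_finished() else 1

    def get_seqs(self, status: Optional[str] = None) -> List[ToySequence]:
        if status is None:
            return self.seqs
        return self.seqs if self.first_seq.status == status else []

    def is_finished(self) -> bool:
        return self.first_seq.is_finished()

    def is_prefill(self) -> bool:
        return self.first_seq.is_prefill()


# Usage: the group's state simply mirrors its single sequence.
group = ToySeqGroup("req-0", [ToySequence(seq_id=0)])
assert group.is_prefill() and not group.is_finished()
group.first_seq.output_token_ids.append(42)
group.first_seq.status = "FINISHED_STOPPED"
assert group.get_max_num_running_seqs() == 0

This is also why the beam-search / best_of fork-and-free bookkeeping in single_step.py and the swap-heavy tests built around best_of=2 groups are removed: with one sequence per group, those code paths can no longer be reached.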
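
For requests with SamplingParams.n > 1, the expansion now happens at admission time: add_request returns None, and _add_processed_request delegates to ParallelSampleSequenceGroup.add_request, which (per the tail of the patch) creates one single-sequence group per sample through engine._add_processed_request, tracking a seq_id_to_index mapping so the per-sample outputs can presumably be reassembled into a single response. Below is a hedged sketch of that fan-out pattern only, with hypothetical ToyEngine/ToySamplingParams names rather than the real engine API; the reassembly side is not shown.

from dataclasses import dataclass, replace
from typing import Dict


@dataclass
class ToySamplingParams:
    # Hypothetical stand-in for vllm.SamplingParams.
    n: int = 1
    temperature: float = 1.0


class ToyEngine:
    def __init__(self) -> None:
        self.groups: Dict[str, dict] = {}

    def add_request(self, request_id: str, prompt: str,
                    params: ToySamplingParams) -> None:
        # Mirrors the new control flow: n > 1 is expanded here instead of
        # being handled by the output processor at decode time.
        if params.n > 1:
            self._fan_out(request_id, prompt, params)
            return
        self._add_processed_request(request_id, prompt, params)

    def _fan_out(self, request_id: str, prompt: str,
                 params: ToySamplingParams) -> None:
        # One child request per sample; each child owns a single sequence.
        child_params = replace(params, n=1)
        for i in range(params.n):
            child_id = f"{request_id}_parallel_sample_{i}"
            self._add_processed_request(child_id, prompt, child_params)

    def _add_processed_request(self, request_id: str, prompt: str,
                               params: ToySamplingParams) -> None:
        # Stand-in for creating and scheduling a single-sequence group.
        self.groups[request_id] = {"prompt": prompt, "params": params}


# Usage: one incoming request with n=3 becomes three independent groups.
engine = ToyEngine()
engine.add_request("req-0", "Hello", ToySamplingParams(n=3))
assert sorted(engine.groups) == [
    "req-0_parallel_sample_0",
    "req-0_parallel_sample_1",
    "req-0_parallel_sample_2",
]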