diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py
index df7434240..b8e8e5589 100644
--- a/aphrodite/common/config.py
+++ b/aphrodite/common/config.py
@@ -923,7 +923,8 @@ def __init__(self,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
-                 preemption_mode: Optional[str] = None) -> None:
+                 preemption_mode: Optional[str] = None,
+                 num_scheduler_steps: int = 1) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
@@ -952,6 +953,7 @@ def __init__(self,
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
+        self.num_scheduler_steps = num_scheduler_steps

         self._verify_args()
@@ -978,6 +980,16 @@ def _verify_args(self) -> None:
                 f"({self.num_lookahead_slots}) must be greater than or "
                 "equal to 0.")

+        if self.num_scheduler_steps < 1:
+            raise ValueError(
+                "num_scheduler_steps "
+                f"({self.num_scheduler_steps}) must be greater than or "
+                "equal to 1.")
+
+    @property
+    def is_multi_step(self) -> bool:
+        return self.num_scheduler_steps > 1
+


 class DeviceConfig:
diff --git a/aphrodite/common/sequence.py b/aphrodite/common/sequence.py
index 67671180d..bf7c4438c 100644
--- a/aphrodite/common/sequence.py
+++ b/aphrodite/common/sequence.py
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, cast

+import numpy
 import torch

 from aphrodite.common.pooling_params import PoolingParams
@@ -474,6 +475,19 @@ def __repr__(self) -> str:
                 f"num_blocks={self.n_blocks}, ")


+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # for multi-step decoding
+    num_steps: int = 1
+    current_step: int = 0
+
+    @property
+    def remaining_steps(self) -> int:
+        return self.num_steps - self.current_step
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.

@@ -516,6 +530,7 @@ def __init__(
                                       time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -569,6 +584,10 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
             if self.prompt_adapter_request else 0

+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -735,6 +754,7 @@ class SequenceGroupMetadata:
         token_chunk_size: The number of tokens to be processed (per sequence).
             None if chunking is not required.
         lora_request: LoRA request.
+        state: Internal state tied to this sequence group.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
         multi_modal_data: Multi modal data.
@@ -762,6 +782,7 @@ def __init__(
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -777,6 +798,7 @@ def __init__(
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -815,6 +837,10 @@ def token_chunk_size(self) -> int:
         assert self._token_chunk_size is not None
         return self._token_chunk_size

+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+

 class SequenceOutput:
     """The model output associated with a sequence.
@@ -952,6 +978,7 @@ class SamplerOutput:

     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None

     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -1086,6 +1113,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO: make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO: make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO: make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step

     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1102,4 +1156,5 @@ def clone(
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
             finished_requests_ids=self.finished_requests_ids,
-        )
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
diff --git a/aphrodite/engine/args_tools.py b/aphrodite/engine/args_tools.py
index e8a192b80..f996cee05 100644
--- a/aphrodite/engine/args_tools.py
+++ b/aphrodite/engine/args_tools.py
@@ -111,6 +111,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'outlines'
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
+    num_scheduler_steps: int = 1
     # Speculative Decoding Options
     num_lookahead_slots: int = 0
     speculative_model: Optional[str] = None
@@ -617,6 +618,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            help="Category: API Options\n"
            "maximum number of sequences per iteration",
        )
+        parser.add_argument('--num-scheduler-steps',
+                            type=int,
+                            default=1,
+                            help=('Maximum number of forward steps per '
+                                  'scheduler call.'))
        # Speculative Decoding Options
        parser.add_argument("--num-lookahead-slots",
                            type=int,
@@ -970,19 +976,35 @@ def create_engine_config(self, ) -> EngineConfig:
             disable_logprobs=self.disable_logprobs_during_spec_decoding,
         )

+        if self.num_scheduler_steps > 1:
+            raise NotImplementedError("Multi-step is not yet supported.")
+            if speculative_config is not None:
+                raise ValueError("Speculative decoding is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+            if self.enable_chunked_prefill:
+                raise ValueError("Chunked prefill is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+
+        # make sure num_lookahead_slots is set the higher value depending on
+        # if we are using speculative decoding or multi-step
+        num_lookahead_slots = max(self.num_lookahead_slots,
+                                  self.num_scheduler_steps - 1)
+        num_lookahead_slots = num_lookahead_slots \
+            if speculative_config is None \
+            else speculative_config.num_lookahead_slots
+
         scheduler_config = SchedulerConfig(
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
             is_attention_free=model_config.is_attention_free(),
             use_v2_block_manager=self.use_v2_block_manager,
-            num_lookahead_slots=(self.num_lookahead_slots
-                                 if speculative_config is None else
-                                 speculative_config.num_lookahead_slots),
+            num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
             preemption_mode=self.preemption_mode,
+            num_scheduler_steps=self.num_scheduler_steps,
         )

         lora_config = LoRAConfig(
diff --git a/aphrodite/processing/scheduler.py b/aphrodite/processing/scheduler.py
index e44da1701..393d7e93a 100644
--- a/aphrodite/processing/scheduler.py
+++ b/aphrodite/processing/scheduler.py
@@ -803,6 +803,9 @@ def _schedule_prefills(
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)
+            seq_group.init_multi_step(
+                num_scheduler_steps=self._get_num_lookahead_slots(
+                    is_prefill=True) + 1)
             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
                                        token_chunk_size=num_new_tokens))
@@ -1105,6 +1108,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 computed_block_nums=common_computed_block_nums,
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
+                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1170,6 +1174,7 @@ def _append_slots(
                 slots.
         """
         num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
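
Note (not part of the patch): a minimal, standalone sketch of how the step bookkeeping introduced above is meant to behave. It re-declares the SequenceGroupState dataclass locally so the snippet runs on its own; the run_steps loop stands in for the worker-side driver that a follow-up change would add and is not an API from this diff.

    from dataclasses import dataclass


    @dataclass
    class SequenceGroupState:
        """Local mirror of the dataclass added in aphrodite/common/sequence.py."""

        # for multi-step decoding
        num_steps: int = 1
        current_step: int = 0

        @property
        def remaining_steps(self) -> int:
            return self.num_steps - self.current_step


    def run_steps(state: SequenceGroupState, num_scheduler_steps: int) -> None:
        # Mirrors SequenceGroup.init_multi_step(): one scheduler call grants
        # this many forward passes before control returns to the scheduler.
        state.num_steps = num_scheduler_steps
        state.current_step = 0

        while state.remaining_steps > 0:
            # Equivalents of ExecuteModelRequest.is_first_multi_step / is_last_step.
            is_first = state.current_step == 0
            is_last = state.remaining_steps == 1
            print(f"step {state.current_step}: first={is_first}, last={is_last}")
            # Mirrors SequenceGroupMetadata.finish_step(): advance after each pass.
            assert state.current_step < state.num_steps
            state.current_step += 1


    # With --num-scheduler-steps 4 the scheduler grants 3 lookahead slots per
    # decode, so init_multi_step(num_lookahead_slots + 1) recovers 4 steps and
    # is_last_step becomes true on the final pass of the round.
    run_steps(SequenceGroupState(), num_scheduler_steps=4)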