From 4e1b011e82ea2fd88e32a1f553f3875a8af6a08f Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Thu, 20 Jun 2024 20:23:12 -0400 Subject: [PATCH] [Model] MLPSpeculator speculative decoding support (#4947) Signed-off-by: Thomas Parnell Co-authored-by: Thomas Parnell Co-authored-by: Nick Hill Co-authored-by: Davis Wertheimer --- examples/offline_inference_mlpspeculator.py | 59 ++++++++ tests/spec_decode/test_spec_decode_worker.py | 8 +- tests/spec_decode/test_utils.py | 4 +- vllm/config.py | 54 +++++-- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/mlp_speculator.py | 143 ++++++++++++++++++ vllm/sequence.py | 46 ++++++ vllm/spec_decode/batch_expansion.py | 6 +- vllm/spec_decode/interfaces.py | 4 + vllm/spec_decode/mlp_speculator_worker.py | 87 +++++++++++ vllm/spec_decode/spec_decode_worker.py | 42 ++++- vllm/spec_decode/top1_proposer.py | 4 + vllm/spec_decode/util.py | 8 - vllm/transformers_utils/config.py | 18 ++- vllm/transformers_utils/configs/__init__.py | 2 + .../configs/mlp_speculator.py | 50 ++++++ vllm/worker/model_runner.py | 18 ++- vllm/worker/worker.py | 9 ++ 18 files changed, 523 insertions(+), 40 deletions(-) create mode 100644 examples/offline_inference_mlpspeculator.py create mode 100644 vllm/model_executor/models/mlp_speculator.py create mode 100644 vllm/spec_decode/mlp_speculator_worker.py create mode 100644 vllm/transformers_utils/configs/mlp_speculator.py diff --git a/examples/offline_inference_mlpspeculator.py b/examples/offline_inference_mlpspeculator.py new file mode 100644 index 000000000000..5448ec1f6208 --- /dev/null +++ b/examples/offline_inference_mlpspeculator.py @@ -0,0 +1,59 @@ +import gc +import time +from typing import List + +from vllm import LLM, SamplingParams + + +def time_generation(llm: LLM, prompts: List[str], + sampling_params: SamplingParams): + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + # Warmup first + llm.generate(prompts, sampling_params) + llm.generate(prompts, sampling_params) + start = time.time() + outputs = llm.generate(prompts, sampling_params) + end = time.time() + print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs])) + # Print the outputs. + for output in outputs: + generated_text = output.outputs[0].text + print(f"text: {generated_text!r}") + + +if __name__ == "__main__": + + template = ( + "Below is an instruction that describes a task. Write a response " + "that appropriately completes the request.\n\n### Instruction:\n{}" + "\n\n### Response:\n") + + # Sample prompts. + prompts = [ + "Write about the president of the United States.", + ] + prompts = [template.format(prompt) for prompt in prompts] + # Create a sampling params object. 
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=200) + + # Create an LLM without spec decoding + llm = LLM(model="meta-llama/Llama-2-13b-chat-hf") + + print("Without speculation") + time_generation(llm, prompts, sampling_params) + + del llm + gc.collect() + + # Create an LLM with spec decoding + llm = LLM( + model="meta-llama/Llama-2-13b-chat-hf", + speculative_model="ibm-fms/llama-13b-accelerator", + # These are currently required for MLPSpeculator decoding + use_v2_block_manager=True, + enforce_eager=True, + ) + + print("With speculation") + time_generation(llm, prompts, sampling_params) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index afaeffc9681c..a20c793c9bfd 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -456,7 +456,9 @@ def test_k_equals_zero(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + sampler_output = MagicMock(spec=SamplerOutput) + sampler_output.hidden_states = None + target_worker.execute_model.return_value = [sampler_output] draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -497,7 +499,9 @@ def test_empty_input_batch(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + sampler_output = MagicMock(spec=SamplerOutput) + sampler_output.hidden_states = None + target_worker.execute_model.return_value = [sampler_output] draft_worker.device = 'cuda' target_worker.device = 'cuda' diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 6b6f35a1a1d0..bccbf9a6aaae 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -2,8 +2,8 @@ import pytest -from vllm.sequence import SequenceGroupMetadata -from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len +from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids +from vllm.spec_decode.util import split_batch_by_proposal_len def test_get_all_seq_ids(): diff --git a/vllm/config.py b/vllm/config.py index 5de00d7d38d4..8d004902fe4f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -230,7 +230,8 @@ def verify_with_parallel_config( self, parallel_config: "ParallelConfig", ) -> None: - total_num_attention_heads = self.hf_text_config.num_attention_heads + total_num_attention_heads = getattr(self.hf_text_config, + "num_attention_heads", 0) tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: raise ValueError( @@ -238,7 +239,8 @@ def verify_with_parallel_config( " must be divisible by tensor parallel size " f"({tensor_parallel_size}).") - total_num_hidden_layers = self.hf_text_config.num_hidden_layers + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) pipeline_parallel_size = parallel_config.pipeline_parallel_size if total_num_hidden_layers % pipeline_parallel_size != 0: raise ValueError( @@ -341,8 +343,8 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: def get_num_attention_heads(self, parallel_config: "ParallelConfig") -> int: - return self.hf_text_config.num_attention_heads // \ - parallel_config.tensor_parallel_size + num_heads = getattr(self.hf_text_config, 
"num_attention_heads", 0) + return num_heads // parallel_config.tensor_parallel_size def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers @@ -818,7 +820,8 @@ def maybe_create_spec_config( speculative_model (Optional[str]): The name of the speculative model, if provided. num_speculative_tokens (Optional[int]): The number of speculative - tokens, if provided. + tokens, if provided. Will default to the number in the draft + model config if present, otherwise is required. speculative_max_model_len (Optional[int]): The maximum model len of the speculative model. Used when testing the ability to skip speculation for some sequences. @@ -841,24 +844,18 @@ def maybe_create_spec_config( the necessary conditions are met, else None. """ - if speculative_model is None and num_speculative_tokens is None: + if speculative_model is None: + if num_speculative_tokens is not None: + raise ValueError("num_speculative_tokens was provided without " + "speculative_model.") return None - if speculative_model is not None and num_speculative_tokens is None: - raise ValueError( - "Expected both speculative_model and " - "num_speculative_tokens to be provided, but found " - f"{speculative_model=} and {num_speculative_tokens=}.") - if (speculative_disable_by_batch_size is not None and speculative_disable_by_batch_size < 2): raise ValueError("Expect the batch size threshold of disabling " "speculative decoding is > 1, but got " f"{speculative_disable_by_batch_size=}") - assert (speculative_model is not None - and num_speculative_tokens is not None) - if enable_chunked_prefill: raise ValueError( "Speculative decoding and chunked prefill are " @@ -912,6 +909,27 @@ def maybe_create_spec_config( max_logprobs=target_model_config.max_logprobs, ) + if (draft_model_config.hf_config.model_type == "mlp_speculator" + and target_parallel_config.world_size != 1): + # MLPSpeculator TP support will be added very soon + raise ValueError( + "Speculative decoding with mlp_speculator models does not " + "yet support distributed inferencing (TP > 1).") + + n_predict = getattr(draft_model_config.hf_config, "n_predict", + None) + if n_predict is not None: + if num_speculative_tokens is None: + # Default to max value defined in draft model config. + num_speculative_tokens = n_predict + elif num_speculative_tokens > n_predict: + # Verify provided value doesn't exceed the maximum + # supported by the draft model. 
+ raise ValueError( + "Expected both speculative_model and " + "num_speculative_tokens to be provided, but found " + f"{speculative_model=} and {num_speculative_tokens=}.") + draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, @@ -923,6 +941,12 @@ def maybe_create_spec_config( SpeculativeConfig.create_draft_parallel_config( target_parallel_config)) + if num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative_model unless the draft model config contains an " + "n_predict parameter.") + return SpeculativeConfig( draft_model_config, draft_parallel_config, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f9ec7209689e..5afb2e1d44d3 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -60,6 +60,7 @@ "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py new file mode 100644 index 000000000000..b18269777cd0 --- /dev/null +++ b/vllm/model_executor/models/mlp_speculator.py @@ -0,0 +1,143 @@ +import math +from typing import Iterable, List, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import SamplerOutput + + +class MLPSpeculatorLayerNorm(nn.Module): + """ + A L2 normalization implementation + ... + Args + ---- + normalized_shape : int + Dimensionality of input data (size of final tensor axis) + eps : float + Safety term to prevent division by zero. Make sure the chosen value + fits in the range of your encoding scheme + (i.e. fp16 requires eps >= 6e-8). 
+ """ + + def __init__( + self, + normalized_shape, + eps=1e-06, + ): + super(MLPSpeculatorLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.empty(normalized_shape)) + self.bias = nn.Parameter(torch.empty(normalized_shape)) + self.eps = eps + + def forward(self, x): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + x = xf.type_as(x) + x = self.weight * x + x = x + self.bias + return x + + +class MLPSpeculator(nn.Module): + + def __init__(self, config, **kwargs) -> None: + super().__init__() + self.n_predict = config.n_predict + self.vocab_size = config.vocab_size + self.emb_dim = config.emb_dim + self.inner_dim = config.inner_dim if config.inner_dim != 0 \ + else config.emb_dim + + self.max_speculative_tokens = getattr(config, "max_speculative_tokens", + self.n_predict) + + self.emb = nn.ModuleList([ + VocabParallelEmbedding(config.vocab_size, + self.inner_dim, + org_num_embeddings=config.vocab_size) + for _ in range(self.max_speculative_tokens) + ]) + + self.proj = nn.ModuleList([ + nn.Linear((self.emb_dim if i == 0 else self.inner_dim), + self.inner_dim, + bias=False) for i in range(self.max_speculative_tokens) + ]) + + self.head = nn.ModuleList([ + nn.Linear(self.inner_dim, self.vocab_size, bias=False) + for _ in range(self.max_speculative_tokens) + ]) + self.ln = nn.ModuleList([ + MLPSpeculatorLayerNorm(self.inner_dim) + for _ in range(self.max_speculative_tokens) + ]) + + self.state_weight = 0.5**(0.5 / config.n_predict) + self.emb_weight = math.sqrt( + (1 - self.state_weight**2) * (self.inner_dim / 2)) + self.activation = nn.GELU() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + config.vocab_size, 1.0) + self.sampler = Sampler() + + def generate_proposals( + self, + input_ids: torch.Tensor, + previous_hidden_states: torch.Tensor, + num_predict_tokens: int, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + if num_predict_tokens > self.max_speculative_tokens: + raise ValueError(f"Max speculative tokens for model is " + f"{self.max_speculative_tokens}, but " + f"{num_predict_tokens} were requested") + + # b x 1 x d + previous_hidden_states = previous_hidden_states.unsqueeze(1) + + # b x 1 + last_tokens = input_ids.unsqueeze(1) + + next_tokens = [] + + for head_index in range(num_predict_tokens): + + # Project and predict + z = self.emb[head_index](last_tokens) # b k d + states = self.proj[head_index](previous_hidden_states) + + # Weighted add of state_weight*state and emb_weight*z + # Let subsequent LN take care of denominator + # state_weight is close to 1, so shouldn't be any precision issues + states.add_(z, alpha=self.emb_weight / self.state_weight) + + states = self.activation(self.ln[head_index](states)) # b k d + # TODO: not yet supporting top_k_tokens_per_head + previous_hidden_states = states + + logits = self.logits_processor(self.head[head_index].weight, + states, sampling_metadata) + + output = self.sampler(logits.flatten(0, 1), sampling_metadata) + last_tokens = output.sampled_token_ids + next_tokens.append(output) + + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict[name.replace("speculator.", "")] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index 38d3349f2ab4..287e1b9df616 100644 --- a/vllm/sequence.py +++ 
b/vllm/sequence.py @@ -794,6 +794,9 @@ class SamplerOutput: # Spec decode metrics populated by workers. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + # Optional last hidden states from the model. + hidden_states: Optional[torch.Tensor] = None + def __getitem__(self, idx: int): return self.outputs[idx] @@ -842,6 +845,46 @@ def __eq__(self, other: object): self.__class__) and self.outputs == other.outputs +def get_all_seq_ids( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all + sequence ids. + """ + return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] + + +class HiddenStates: + """Hidden states corresponding to in-progress sequences. + Used in speculative decoding to pass hidden states from + the target model to the proposer model in the subsequent step. + + seq_ids are the sequence ids of each entry of the batch + dimension of the hidden_states tensor""" + + def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata], + hidden_states: torch.Tensor): + assert len(seq_group_metadata_list) == len(hidden_states) + self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) + self.hidden_states: torch.Tensor = hidden_states + + def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], + hidden_states: torch.Tensor) -> None: + """Update hidden states from target model invocation.""" + assert len(seq_group_metadata_list) == len(hidden_states) + self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) + self.hidden_states = torch.cat([self.hidden_states, hidden_states]) + + def prune(self, + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + """Prune to provided list of sequence ids.""" + seq_ids = get_all_seq_ids(seq_group_metadata_list) + if seq_ids != self.seq_ids: + # Batch contents changed - prune removed sequences. + index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] + self.hidden_states = self.hidden_states[index] + self.seq_ids = seq_ids + + @dataclass class ExecuteModelRequest: """The model execution request.""" @@ -857,6 +900,8 @@ class ExecuteModelRequest: num_lookahead_slots: int = 0 # The number of requests in the running queue. running_queue_size: int = 0 + # Optional hidden states from prior step. 
+ previous_hidden_states: Optional[HiddenStates] = None def clone( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -869,4 +914,5 @@ def clone( blocks_to_copy=self.blocks_to_copy.copy(), num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, + previous_hidden_states=self.previous_hidden_states, ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 1bde042086f0..40516556344e 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -4,11 +4,10 @@ import torch from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, - SequenceGroupMetadata) + SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, - sampler_output_to_torch, +from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, split_batch_by_proposal_len) from vllm.worker.worker_base import WorkerBase @@ -98,6 +97,7 @@ def score_proposals( probs=all_probs, token_ids=all_tokens, logprobs=spec_logprobs, + hidden_states=target_sampler_output.hidden_states, ) def _expand_batch( diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 72d7818eb117..d236fc0f2cb6 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import Optional import torch @@ -46,6 +47,9 @@ class SpeculativeScores: # tokens and also non-speculative normal decoding. token_ids: torch.Tensor + # Optional last hidden states from the scoring model. + hidden_states: Optional[torch.Tensor] = None + def __repr__(self): return (f"SpeculativeScores(" f"probs={self.probs.shape}, " diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py new file mode 100644 index 000000000000..0926e13bedab --- /dev/null +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -0,0 +1,87 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.worker.model_runner import ModelInput + + +class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): + """Worker for MLPSpeculator models. + + Not currently compatible with LoRA or chunked prefill. + """ + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. + Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For mlp spec worker, this indicator shall be True. 
+ """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + (input_tokens, seq_lens, + query_lens) = self._prepare_input_tensors(seq_group_metadata_list) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory) + + model_outputs = self.model_runner.model.generate_proposals( + input_ids=input_tokens, + previous_hidden_states=execute_model_req.previous_hidden_states. + hidden_states, + num_predict_tokens=sample_len, + sampling_metadata=sampling_metadata) + + assert len(model_outputs) == sample_len + + return model_outputs, True + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, List[int], List[int]]: + if not seq_group_metadata_list: + return ModelInput.empty(self.device) + + input_tokens: List[int] = [] + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) + input_tokens.extend(tokens) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + input_tokens.append(seq_data.get_last_token_id()) + query_lens.append(1) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + return input_tokens_tensor, seq_lens, query_lens diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 03fad5663037..58d3461a2518 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -8,16 +8,18 @@ from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, - SamplerOutput, SequenceGroupMetadata) + HiddenStates, SamplerOutput, SequenceGroupMetadata, + get_all_seq_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.util import (create_sequence_group_output, - get_all_num_logprobs, get_all_seq_ids, + get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker import Worker @@ -104,6 +106,10 @@ def create_worker( proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) + elif draft_worker_kwargs[ + "model_config"].hf_config.model_type == "mlp_speculator": + proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) + disable_bonus_tokens = False else: proposer_worker = MultiStepWorker(**draft_worker_kwargs) @@ -155,6 +161,10 @@ def __init__( # Lazy initiazliation. 
self.scorer: SpeculativeScorer + # Hidden states from target model to pass to proposer + # in the subsequent step. + self.previous_hidden_states: Optional[HiddenStates] = None + def init_device(self) -> None: """Initialize both scorer and proposer models. """ @@ -337,6 +347,16 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, assert len(sampler_output) == 1 sampler_output = sampler_output[0] + # Store hidden states from target model execution. + hidden_states = sampler_output.hidden_states + if hidden_states is not None: + if self.previous_hidden_states is None: + self.previous_hidden_states = HiddenStates( + execute_model_req.seq_group_metadata_list, hidden_states) + else: + self.previous_hidden_states.update( + execute_model_req.seq_group_metadata_list, hidden_states) + # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. sampler_output.probs = None @@ -383,6 +403,10 @@ def _run_speculative_decoding_step( """ assert num_lookahead_slots == execute_model_req.num_lookahead_slots + # Pass last hidden states from target model to proposer + execute_model_req.previous_hidden_states = self.previous_hidden_states + self.previous_hidden_states = None + # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals(execute_model_req) @@ -466,6 +490,20 @@ def _verify_tokens( # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone() + hidden_states = proposal_scores.hidden_states + if hidden_states is not None: + # Contract hidden states based on accepted tokens + hs_size = hidden_states.shape[1] + hidden_states = hidden_states.reshape(-1, max_proposal_len + 1, + hs_size) + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 + accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) + index = accepted_index[:, None, None].expand(-1, 1, hs_size) + hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d + # Store hidden states from target model for subsequent decode step + self.previous_hidden_states = HiddenStates(seq_group_metadata_list, + hidden_states) + return accepted_token_ids, logprobs def _create_output_sampler_list( diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 278db94bfc0d..d3e280e6843b 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -65,9 +65,13 @@ def get_spec_proposals( # token_ids is like [batch] format in proposal_len size list, # while if it is false, the format would be [proposal_len] # in batch size list + hidden_states = execute_model_req.previous_hidden_states + if hidden_states is not None: + hidden_states.prune(nonzero_proposal_len_seqs) nonzero_execute_model_req = ExecuteModelRequest( seq_group_metadata_list=nonzero_proposal_len_seqs, num_lookahead_slots=proposal_len, + previous_hidden_states=hidden_states, ) maybe_sampler_output, transposed = self._worker.sampler_output( execute_model_req=nonzero_execute_model_req, diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 9bbe3f8d1611..80710419e602 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -10,14 +10,6 @@ SeqId = int -def get_all_seq_ids( - seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. 
- """ - return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] - - def get_all_num_logprobs( seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index ada84018212a..60fc756a12e3 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,3 +1,4 @@ +import contextlib from typing import Dict, Optional, Type from transformers import PretrainedConfig @@ -5,7 +6,13 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - JAISConfig, MPTConfig, RWConfig) + JAISConfig, MLPSpeculatorConfig, + MPTConfig, RWConfig) + +if VLLM_USE_MODELSCOPE: + from modelscope import AutoConfig +else: + from transformers import AutoConfig logger = init_logger(__name__) @@ -16,8 +23,13 @@ "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "jais": JAISConfig, + "mlp_speculator": MLPSpeculatorConfig, } +for name, cls in _CONFIG_REGISTRY.items(): + with contextlib.suppress(ValueError): + AutoConfig.register(name, cls) + def get_config(model: str, trust_remote_code: bool, @@ -26,10 +38,6 @@ def get_config(model: str, rope_scaling: Optional[dict] = None, rope_theta: Optional[float] = None) -> PretrainedConfig: try: - if VLLM_USE_MODELSCOPE: - from modelscope import AutoConfig - else: - from transformers import AutoConfig config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 0e486928824c..d8170858c2a9 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,6 +5,7 @@ # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig __all__ = [ @@ -13,4 +14,5 @@ "MPTConfig", "RWConfig", "JAISConfig", + "MLPSpeculatorConfig", ] diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py new file mode 100644 index 000000000000..dd1d92b861b8 --- /dev/null +++ b/vllm/transformers_utils/configs/mlp_speculator.py @@ -0,0 +1,50 @@ +from typing import List, Optional + +from transformers import PretrainedConfig + + +class MLPSpeculatorConfig(PretrainedConfig): + model_type = "mlp_speculator" + + attribute_map = { + "hidden_size": "emb_dim", + } + + def __init__(self, + vocab_size: int = 32000, + emb_dim: int = 4096, + inner_dim: int = 0, + n_predict: int = 3, + top_k_tokens_per_head: Optional[List[int]] = None, + n_candidates: int = 5, + **kwargs): + """ + Initialize an MLPSpeculatorConfig + + Args: + vocab_size: int + the model vocab size + emb_dim: int + the model embedding dimension + inner_dim: int + the inner dimension of the model. If 0, will be the emb_dim. + n_predict: int + the number of lookaheads for the speculator + top_k_tokens_per_head: List[int] + Number of tokens to consider from each head when forming the + candidate tree. + For each candidate branch in the tree, head n produces topk[n] + additional sub-branches. 
+ n_candidates: int + number of child candidates to create per sequence + """ + if top_k_tokens_per_head is None: + top_k_tokens_per_head = [5, 4, 3] + assert len(top_k_tokens_per_head) == n_predict + self.vocab_size = vocab_size + self.emb_dim = emb_dim + self.inner_dim = inner_dim + self.n_predict = n_predict + self.top_k_tokens_per_head = top_k_tokens_per_head + self.n_candidates = n_candidates + super().__init__(**kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0baa4337f84..e24835a1ea7f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -86,6 +86,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + return_hidden_states: bool = False, ): self.model_config = model_config self.parallel_config = parallel_config @@ -96,6 +97,7 @@ def __init__( self.load_config = load_config self.is_driver_worker = is_driver_worker self.vision_language_config = vision_language_config + self.return_hidden_states = return_hidden_states self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -116,15 +118,17 @@ def __init__( self.graph_block_tables = np.zeros( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), + num_attn_heads, self.model_config.get_head_size(), self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, - ) + ) if num_attn_heads else None # Create processor for multi-modal data if self.vision_language_config is not None: @@ -762,11 +766,19 @@ def execute_model( return None # Sample the next token. - output = self.model.sample( + output: SamplerOutput = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, ) + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert seq_group_metadata_list is not None + if seq_group_metadata_list[0].is_prompt: + hidden_states = hidden_states.index_select( + 0, sampling_metadata.selected_token_indices) + output.hidden_states = hidden_states + return output @torch.inference_mode() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f9b8a065a8b2..e334ffbb755b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -70,6 +70,14 @@ def __init__( assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type != + "mlp_speculator") else {"return_hidden_states": True} + ModelRunnerClass = (EmbeddingModelRunner if self.model_config.embedding_mode else ModelRunner) self.model_runner = ModelRunnerClass( @@ -83,6 +91,7 @@ def __init__( kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, + **speculative_args, ) # Uninitialized cache engine. Will be initialized by # initialize_cache.
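
Notes on the mechanics above, with small standalone sketches. These are illustrative Python/PyTorch snippets, not vLLM APIs; all names, shapes and values in the sketches are made up for the example.

The MLPSpeculatorLayerNorm added in vllm/model_executor/models/mlp_speculator.py is an RMS-style L2 normalization with a learned elementwise scale and bias. A minimal sketch of the same computation, assuming only PyTorch:

    import torch

    def mlp_speculator_layer_norm(x: torch.Tensor,
                                  weight: torch.Tensor,
                                  bias: torch.Tensor,
                                  eps: float = 1e-6) -> torch.Tensor:
        # Divide each vector by its root-mean-square over the last axis,
        # then apply the learned elementwise scale and shift.
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return weight * normed + bias

    x = torch.randn(2, 1, 8)            # (batch, k, inner_dim); sizes illustrative
    w, b = torch.ones(8), torch.zeros(8)
    print(mlp_speculator_layer_norm(x, w, b).shape)   # torch.Size([2, 1, 8])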
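
In MLPSpeculator.generate_proposals, the in-place states.add_(z, alpha=self.emb_weight / self.state_weight) realizes the weighted sum state_weight * states + emb_weight * z up to a common scalar factor of state_weight, which the following RMS-style layer norm effectively cancels (up to its eps term), as the in-code comment notes. A small numerical check of that identity, with illustrative weights:

    import torch

    state_weight, emb_weight = 0.89, 0.55   # illustrative values only
    states = torch.randn(2, 1, 8)           # projected previous hidden states
    z = torch.randn(2, 1, 8)                # token embedding for this head

    # What the fused in-place add computes:
    fused = states + z * (emb_weight / state_weight)

    # The explicit weighted sum it stands in for:
    explicit = state_weight * states + emb_weight * z

    # Identical up to the scalar factor state_weight.
    print(torch.allclose(fused * state_weight, explicit))   # True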
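
The new HiddenStates helper in vllm/sequence.py stores one hidden-state row per sequence id and, in prune(), re-indexes those rows when the batch composition changes between steps (for example when sequences finish or are skipped by the proposer). The core bookkeeping, sketched with plain lists and tensors rather than SequenceGroupMetadata:

    import torch

    seq_ids = [7, 8, 9, 10]                       # ids tracked alongside the tensor
    hidden = torch.randn(4, 16)                   # one row per sequence

    # Next step only sequences 7 and 9 remain, so keep their rows, in order.
    kept_ids = [7, 9]
    index = [seq_ids.index(i) for i in kept_ids]  # [0, 2]
    hidden = hidden[index]
    seq_ids = kept_ids
    print(seq_ids, tuple(hidden.shape))           # [7, 9] (2, 16)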
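
In SpecDecodeWorker._verify_tokens, the scored hidden states carry one row per (sequence, position) and are contracted to the hidden state of the last accepted token of each sequence. Rejected positions are marked with -1 in accepted_token_ids, so counting the non-zero entries of accepted_token_ids + 1 gives the number of accepted tokens per sequence. A standalone sketch of that gather, with made-up token ids:

    import torch

    b, k, d = 2, 3, 4                         # batch, proposal len, hidden size
    hidden = torch.randn(b, k + 1, d)         # one row per scored position

    # -1 marks rejected positions; sequence 0 accepted 2 tokens, sequence 1 all 4.
    accepted = torch.tensor([[11, 12, -1, -1],
                             [21, 22, 23, 24]])

    last = (accepted + 1).count_nonzero(dim=1) - 1    # tensor([1, 3])
    index = last[:, None, None].expand(-1, 1, d)
    picked = hidden.gather(1, index).squeeze(1)       # shape (b, d)

    print(torch.equal(picked[0], hidden[0, 1]))       # True
    print(torch.equal(picked[1], hidden[1, 3]))       # True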
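
SpeculativeConfig.maybe_create_spec_config now lets num_speculative_tokens default to the draft config's n_predict when that attribute exists, treats n_predict as an upper bound, and still requires an explicit value when the draft model exposes no n_predict. A condensed restatement of that resolution logic, not the vLLM API itself (the ValueError raised in the too-large branch of the patch reuses the wording of the removed both-or-neither check, so the sketch uses a message describing the actual condition):

    from typing import Optional

    def resolve_num_speculative_tokens(
            n_predict: Optional[int],
            num_speculative_tokens: Optional[int]) -> int:
        if n_predict is not None:
            if num_speculative_tokens is None:
                return n_predict              # default to the draft model's maximum
            if num_speculative_tokens > n_predict:
                raise ValueError(
                    f"num_speculative_tokens={num_speculative_tokens} must not "
                    f"exceed the draft model's n_predict={n_predict}")
            return num_speculative_tokens
        if num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with speculative_model "
                "unless the draft model config contains an n_predict parameter")
        return num_speculative_tokens

    print(resolve_num_speculative_tokens(3, None))   # 3 (mlp_speculator default)
    print(resolve_num_speculative_tokens(3, 2))      # 2 (explicit, within bound)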
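
MLPSpeculatorWorker._prepare_input_tensors flattens the batch into a single 1-D token tensor: for a prompt (prefill) sequence it appends the not-yet-computed slice of the prompt bounded by token_chunk_size, while for a decode sequence it appends only the last generated token with a query length of 1. For a single prefill sequence the slicing works out as follows (toy numbers, not taken from a real run):

    # Toy prefill example: 8 prompt tokens, 3 already computed, chunk size 4.
    token_ids = [101, 102, 103, 104, 105, 106, 107, 108]
    context_len, token_chunk_size = 3, 4

    seq_len = min(len(token_ids), context_len + token_chunk_size)   # 7
    input_tokens = token_ids[context_len:seq_len]                   # [104, 105, 106, 107]
    query_len = seq_len - context_len                               # 4
    print(input_tokens, seq_len, query_len)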