[Model] MLPSpeculator speculative decoding support (vllm-project#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
4 people authored Jun 21, 2024
1 parent a61199a commit 4e1b011
Showing 18 changed files with 523 additions and 40 deletions.
59 changes: 59 additions & 0 deletions examples/offline_inference_mlpspeculator.py
@@ -0,0 +1,59 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
        enforce_eager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
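
The number printed by time_generation is wall-clock time divided by the count of generated tokens, i.e. seconds per output token, so a smaller value for the speculative run indicates a speedup. A minimal sketch (not part of this commit, reusing the imports above) of a variant that returns the metric so the two runs can be compared directly:

def seconds_per_token(llm: LLM, prompts: List[str],
                      sampling_params: SamplingParams) -> float:
    # Warm up once so one-time initialization does not skew the measurement.
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    elapsed = time.time() - start
    num_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    return elapsed / num_tokens

# Hypothetical usage with the two LLM instances built in the example:
#   base = seconds_per_token(llm_baseline, prompts, sampling_params)
#   spec = seconds_per_token(llm_speculative, prompts, sampling_params)
#   print(f"estimated speedup: {base / spec:.2f}x")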
8 changes: 6 additions & 2 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -456,7 +456,9 @@ def test_k_equals_zero(k: int, batch_size: int):
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'
@@ -497,7 +499,9 @@ def test_empty_input_batch(k: int, batch_size: int):
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'
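
The extra line in both tests reflects that the target model's SamplerOutput now carries an optional hidden_states attribute (the hidden states are what feed an MLPSpeculator draft model). Setting it to None on the mock keeps SpecDecodeWorker on the non-speculator path; an auto-created MagicMock attribute would otherwise be truthy and trigger the new hidden-state handling.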
4 changes: 2 additions & 2 deletions tests/spec_decode/test_utils.py
@@ -2,8 +2,8 @@

import pytest

-from vllm.sequence import SequenceGroupMetadata
-from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len
+from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
+from vllm.spec_decode.util import split_batch_by_proposal_len


def test_get_all_seq_ids():
54 changes: 39 additions & 15 deletions vllm/config.py
@@ -230,15 +230,17 @@ def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
-        total_num_attention_heads = self.hf_text_config.num_attention_heads
+        total_num_attention_heads = getattr(self.hf_text_config,
+                                            "num_attention_heads", 0)
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

-        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        total_num_hidden_layers = getattr(self.hf_text_config,
+                                          "num_hidden_layers", 0)
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
@@ -341,8 +343,8 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:

    def get_num_attention_heads(self,
                                parallel_config: "ParallelConfig") -> int:
-        return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
+        return num_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
@@ -818,7 +820,8 @@ def maybe_create_spec_config(
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
-                tokens, if provided.
+                tokens, if provided. Will default to the number in the draft
+                model config if present, otherwise is required.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
@@ -841,24 +844,18 @@
            the necessary conditions are met, else None.
        """

-        if speculative_model is None and num_speculative_tokens is None:
+        if speculative_model is None:
+            if num_speculative_tokens is not None:
+                raise ValueError("num_speculative_tokens was provided without "
+                                 "speculative_model.")
            return None

-        if speculative_model is not None and num_speculative_tokens is None:
-            raise ValueError(
-                "Expected both speculative_model and "
-                "num_speculative_tokens to be provided, but found "
-                f"{speculative_model=} and {num_speculative_tokens=}.")
-
        if (speculative_disable_by_batch_size is not None
                and speculative_disable_by_batch_size < 2):
            raise ValueError("Expect the batch size threshold of disabling "
                             "speculative decoding is > 1, but got "
                             f"{speculative_disable_by_batch_size=}")

-        assert (speculative_model is not None
-                and num_speculative_tokens is not None)
-
        if enable_chunked_prefill:
            raise ValueError(
                "Speculative decoding and chunked prefill are "
@@ -912,6 +909,27 @@ def maybe_create_spec_config(
                max_logprobs=target_model_config.max_logprobs,
            )

+            if (draft_model_config.hf_config.model_type == "mlp_speculator"
+                    and target_parallel_config.world_size != 1):
+                # MLPSpeculator TP support will be added very soon
+                raise ValueError(
+                    "Speculative decoding with mlp_speculator models does not "
+                    "yet support distributed inferencing (TP > 1).")
+
+            n_predict = getattr(draft_model_config.hf_config, "n_predict",
+                                None)
+            if n_predict is not None:
+                if num_speculative_tokens is None:
+                    # Default to max value defined in draft model config.
+                    num_speculative_tokens = n_predict
+                elif num_speculative_tokens > n_predict:
+                    # Verify provided value doesn't exceed the maximum
+                    # supported by the draft model.
+                    raise ValueError(
+                        "This speculative model supports a maximum of "
+                        f"num_speculative_tokens={n_predict}, but "
+                        f"{num_speculative_tokens=} was provided.")
+
            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
@@ -923,6 +941,12 @@
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config))

+        if num_speculative_tokens is None:
+            raise ValueError(
+                "num_speculative_tokens must be provided with "
+                "speculative_model unless the draft model config contains an "
+                "n_predict parameter.")
+
        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
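
Taken together, the additions above resolve num_speculative_tokens against the draft model's n_predict attribute: n_predict supplies both the default and the upper bound when present, and otherwise the value must be passed explicitly. A standalone sketch of that resolution rule (an illustration, not the actual vLLM code; n_predict stands for the attribute read from the draft model's HF config):

from typing import Optional


def resolve_num_speculative_tokens(requested: Optional[int],
                                   n_predict: Optional[int]) -> int:
    # n_predict, when defined by the draft model config, is both the
    # default and the maximum for the number of speculative tokens.
    if n_predict is not None:
        if requested is None:
            return n_predict
        if requested > n_predict:
            raise ValueError(f"draft model supports at most {n_predict} "
                             f"speculative tokens, got {requested}")
        return requested
    # Without n_predict the caller must choose a value explicitly.
    if requested is None:
        raise ValueError("num_speculative_tokens must be provided when the "
                         "draft model config has no n_predict attribute")
    return requested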
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -60,6 +60,7 @@
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
+    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_EMBEDDING_MODELS = {
143 changes: 143 additions & 0 deletions vllm/model_executor/models/mlp_speculator.py
@@ -0,0 +1,143 @@
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    A L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x


class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
                                              self.n_predict)

        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
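
A note on the in-place update states.add_(z, alpha=self.emb_weight / self.state_weight): the intended combination is state_weight * states + emb_weight * z, and the cheaper form differs from it only by the constant factor 1 / state_weight. MLPSpeculatorLayerNorm divides its input by its RMS before applying the learned weight and bias, so that constant factor cancels (up to the small eps term), which is what the "Let subsequent LN take care of denominator" comment relies on. A quick numerical check of that assumption (illustration only; the weights here are arbitrary values, not ones read from a real config):

import torch


def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # The normalization step of MLPSpeculatorLayerNorm.forward, without the
    # learned weight/bias (those are applied after normalization).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


state_weight, emb_weight = 0.8, 0.6
states = torch.randn(2, 1, 16)
z = torch.randn(2, 1, 16)

intended = rms_normalize(state_weight * states + emb_weight * z)
in_place = rms_normalize(states + (emb_weight / state_weight) * z)

print(torch.allclose(intended, in_place, atol=1e-5))  # True (up to eps)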