
Commit b12518d

Authored by JRosenkranz, tdoublep, njhill, and daviswer
[Model] MLPSpeculator speculative decoding support (#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
1 parent: 6c5b7af

18 files changed: +523 -40 lines
New file (+59 lines): offline example comparing generation with and without an MLPSpeculator draft model

@@ -0,0 +1,59 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
        enforce_eager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
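
The value the example prints is average wall-clock seconds per generated token (elapsed time divided by the total number of generated tokens). A small companion sketch, assuming the same outputs, start, and end values as inside time_generation, that reports throughput as tokens per second instead (the helper name is hypothetical):

from typing import List

from vllm import RequestOutput


def tokens_per_second(outputs: List[RequestOutput], start: float,
                      end: float) -> float:
    # Sum the tokens actually generated across all requests and divide by the
    # elapsed wall-clock time of the timed llm.generate() call.
    total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    return total_tokens / (end - start)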

tests/spec_decode/test_spec_decode_worker.py (+6 -2)
@@ -456,7 +456,9 @@ def test_k_equals_zero(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]

     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'

@@ -497,7 +499,9 @@ def test_empty_input_batch(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    sampler_output = MagicMock(spec=SamplerOutput)
+    sampler_output.hidden_states = None
+    target_worker.execute_model.return_value = [sampler_output]

     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
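
Both tests now pin hidden_states to None because MagicMock(spec=...) hands back a child mock for any attribute it is asked for, while the worker's new hidden-state plumbing branches on whether that attribute is None. A minimal standalone sketch of the pattern, using a hypothetical stand-in class rather than vLLM's SamplerOutput:

from unittest.mock import MagicMock


class FakeSamplerOutput:
    """Hypothetical stand-in for a sampler output carrying hidden states."""
    hidden_states = None


# An attribute that is never assigned comes back as a child MagicMock, which
# is neither None nor a tensor, so code branching on `is None` misbehaves.
sampler_output = MagicMock(spec=FakeSamplerOutput)
print(sampler_output.hidden_states is None)  # False

# Pinning the attribute makes the mock follow the intended code path.
sampler_output.hidden_states = None
print(sampler_output.hidden_states is None)  # True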

tests/spec_decode/test_utils.py (+2 -2)
@@ -2,8 +2,8 @@

 import pytest

-from vllm.sequence import SequenceGroupMetadata
-from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len
+from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
+from vllm.spec_decode.util import split_batch_by_proposal_len


 def test_get_all_seq_ids():
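
The import move reflects get_all_seq_ids now living in vllm.sequence. As a rough, hedged sketch (not the vLLM implementation) of what the helper provides, it flattens the sequence ids of a batch of SequenceGroupMetadata objects in order; the stub class below is hypothetical:

from typing import Dict, List


class SequenceGroupMetadataStub:
    """Hypothetical stand-in exposing only the seq_data mapping."""

    def __init__(self, seq_data: Dict[int, object]) -> None:
        self.seq_data = seq_data


def get_all_seq_ids_sketch(
        seq_group_metadata_list: List[SequenceGroupMetadataStub]) -> List[int]:
    # Collect every sequence id, preserving batch order.
    return [
        seq_id for metadata in seq_group_metadata_list
        for seq_id in metadata.seq_data
    ]


print(get_all_seq_ids_sketch([
    SequenceGroupMetadataStub({0: None}),
    SequenceGroupMetadataStub({1: None, 2: None}),
]))  # -> [0, 1, 2]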

vllm/config.py (+39 -15)
@@ -230,15 +230,17 @@ def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
     ) -> None:
-        total_num_attention_heads = self.hf_text_config.num_attention_heads
+        total_num_attention_heads = getattr(self.hf_text_config,
+                                            "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(
                 f"Total number of attention heads ({total_num_attention_heads})"
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

-        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        total_num_hidden_layers = getattr(self.hf_text_config,
+                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(
@@ -341,8 +343,8 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:

     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
-        return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
+        return num_heads // parallel_config.tensor_parallel_size

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
@@ -818,7 +820,8 @@ def maybe_create_spec_config(
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
-               tokens, if provided.
+               tokens, if provided. Will default to the number in the draft
+               model config if present, otherwise is required.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
@@ -841,24 +844,18 @@
            the necessary conditions are met, else None.
        """

-        if speculative_model is None and num_speculative_tokens is None:
+        if speculative_model is None:
+            if num_speculative_tokens is not None:
+                raise ValueError("num_speculative_tokens was provided without "
+                                 "speculative_model.")
             return None

-        if speculative_model is not None and num_speculative_tokens is None:
-            raise ValueError(
-                "Expected both speculative_model and "
-                "num_speculative_tokens to be provided, but found "
-                f"{speculative_model=} and {num_speculative_tokens=}.")
-
         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
             raise ValueError("Expect the batch size threshold of disabling "
                              "speculative decoding is > 1, but got "
                              f"{speculative_disable_by_batch_size=}")

-        assert (speculative_model is not None
-                and num_speculative_tokens is not None)
-
         if enable_chunked_prefill:
             raise ValueError(
                 "Speculative decoding and chunked prefill are "
@@ -912,6 +909,27 @@
            max_logprobs=target_model_config.max_logprobs,
        )

+        if (draft_model_config.hf_config.model_type == "mlp_speculator"
+                and target_parallel_config.world_size != 1):
+            # MLPSpeculator TP support will be added very soon
+            raise ValueError(
+                "Speculative decoding with mlp_speculator models does not "
+                "yet support distributed inferencing (TP > 1).")
+
+        n_predict = getattr(draft_model_config.hf_config, "n_predict",
+                            None)
+        if n_predict is not None:
+            if num_speculative_tokens is None:
+                # Default to max value defined in draft model config.
+                num_speculative_tokens = n_predict
+            elif num_speculative_tokens > n_predict:
+                # Verify provided value doesn't exceed the maximum
+                # supported by the draft model.
+                raise ValueError(
+                    f"Requested num_speculative_tokens="
+                    f"{num_speculative_tokens} exceeds the maximum "
+                    f"n_predict={n_predict} supported by the draft model.")
+
         draft_model_config.max_model_len = (
             SpeculativeConfig._maybe_override_draft_max_model_len(
                 speculative_max_model_len,
@@ -923,6 +941,12 @@
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))

+        if num_speculative_tokens is None:
+            raise ValueError(
+                "num_speculative_tokens must be provided with "
+                "speculative_model unless the draft model config contains an "
+                "n_predict parameter.")
+
         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
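
With the relaxed validation above, num_speculative_tokens may be omitted whenever the draft model's HF config carries n_predict, which MLPSpeculator configs do; values larger than n_predict are rejected, and a missing value with no n_predict still raises. A minimal sketch using the same model names as the example script in this commit:

from vllm import LLM

# num_speculative_tokens is not passed here: it defaults to the draft model's
# n_predict. Passing a larger value would fail at config time.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_model="ibm-fms/llama-13b-accelerator",
    # Currently required for MLPSpeculator decoding (see the example above).
    use_v2_block_manager=True,
    enforce_eager=True,
)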

vllm/model_executor/models/__init__.py (+1 -0)
@@ -60,6 +60,7 @@
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
+    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }

 _EMBEDDING_MODELS = {
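
The registry maps the HF architectures string found in a checkpoint's config to a (module_name, class_name) pair under vllm/model_executor/models. A rough illustration of how such a mapping can be resolved lazily (a sketch, not vLLM's actual loader; resolve_model_class is a hypothetical helper):

import importlib

# Mirrors only the entry added by this commit.
_MODELS_SKETCH = {
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}


def resolve_model_class(architecture: str,
                        package: str = "vllm.model_executor.models"):
    # Import the module on demand and pull the model class out of it.
    module_name, class_name = _MODELS_SKETCH[architecture]
    module = importlib.import_module(f"{package}.{module_name}")
    return getattr(module, class_name)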
New file (+143 lines): the MLPSpeculator draft model implementation

@@ -0,0 +1,143 @@
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x


class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
                                              self.n_predict)

        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
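
A worked example of the residual weighting used in generate_proposals, assuming hypothetical config values n_predict = 3 and inner_dim = 4096 (real checkpoints carry their own): state_weight keeps most of the previous hidden state, emb_weight rescales the token embedding, and their ratio is the alpha folded into states.add_, with the following LayerNorm-style normalization absorbing the overall scale.

import math

# Hypothetical illustrative values; real MLPSpeculator configs define these.
n_predict = 3
inner_dim = 4096

state_weight = 0.5**(0.5 / n_predict)
emb_weight = math.sqrt((1 - state_weight**2) * (inner_dim / 2))

print(f"{state_weight:.4f}")                # ~0.8909
print(f"{emb_weight:.2f}")                  # ~20.55
print(f"{emb_weight / state_weight:.2f}")   # ~23.07, the alpha passed to add_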
