
[Spec Decode] (1/2) Remove batch expansion #8839

Merged

Commits (33 total; changes shown from 25 commits)
- c4c5dab  draft (LiuXiaoxuanPKU, Sep 25, 2024)
- cb08091  w/o cuda graph support (LiuXiaoxuanPKU, Sep 25, 2024)
- 8c10b11  args and tests (LiuXiaoxuanPKU, Sep 25, 2024)
- 44930fb  disable mqa for ngram and format (LiuXiaoxuanPKU, Sep 26, 2024)
- e64c61b  clean up and tests (LiuXiaoxuanPKU, Sep 26, 2024)
- 541b767  revert example (LiuXiaoxuanPKU, Sep 26, 2024)
- b6c1de3  minor (LiuXiaoxuanPKU, Sep 26, 2024)
- 5824b78  minor (LiuXiaoxuanPKU, Sep 26, 2024)
- 07aebc0  fix tests -- chunked prefill and hiddens states in spec dec (LiuXiaoxuanPKU, Sep 26, 2024)
- d6cb1cc  fix (LiuXiaoxuanPKU, Sep 26, 2024)
- bcc1fe9  minor (LiuXiaoxuanPKU, Sep 26, 2024)
- b036d06  fix (LiuXiaoxuanPKU, Sep 27, 2024)
- b93694d  Merge branch 'main' into remove_batch_expansion (LiuXiaoxuanPKU, Sep 27, 2024)
- 35750a6  fix sampler for beam search (LiuXiaoxuanPKU, Sep 28, 2024)
- 741068a  revert num compute tokens (LiuXiaoxuanPKU, Sep 28, 2024)
- 71be340  disbale mqa scorer when draft model and target model have different m… (LiuXiaoxuanPKU, Sep 28, 2024)
- cff6b0f  diable mqa for cuda graph (LiuXiaoxuanPKU, Sep 28, 2024)
- f4fb00b  fix partial comments (LiuXiaoxuanPKU, Sep 28, 2024)
- b3e8691  fix comments (LiuXiaoxuanPKU, Sep 28, 2024)
- 238e5a0  fix sampler and spec dec tests (LiuXiaoxuanPKU, Sep 28, 2024)
- 5063c95  remove backend (LiuXiaoxuanPKU, Sep 29, 2024)
- 70662b0  more test fix (LiuXiaoxuanPKU, Sep 30, 2024)
- 878d2da  Merge branch 'main' into remove_batch_expansion (LiuXiaoxuanPKU, Oct 1, 2024)
- 0e32744  fix num_compute_token (LiuXiaoxuanPKU, Oct 1, 2024)
- 7ee2998  clean up (LiuXiaoxuanPKU, Oct 1, 2024)
- d39c8a9  more fix for num_compute_token (LiuXiaoxuanPKU, Oct 1, 2024)
- 79ac29c  change log condition (LiuXiaoxuanPKU, Oct 1, 2024)
- 6f3388b  add comments (LiuXiaoxuanPKU, Oct 1, 2024)
- 1425332  query len for multi-step, specify ci backend (LiuXiaoxuanPKU, Oct 1, 2024)
- e5702a9  fix ci (LiuXiaoxuanPKU, Oct 1, 2024)
- 8e27664  fix (LiuXiaoxuanPKU, Oct 1, 2024)
- 3f3c222  format (LiuXiaoxuanPKU, Oct 1, 2024)
- 2707422  context_len for multi-step and encoder decoder, fix decode_len (LiuXiaoxuanPKU, Oct 1, 2024)
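For orientation before the file diffs: this PR adds an MQAScorer (vllm.spec_decode.mqa_scorer, exercised by the new tests/spec_decode/test_scorer.py below) as an alternative to the existing BatchExpansionTop1Scorer. The following Python sketch is only a rough, hypothetical illustration of how the two scorers lay out their scoring work per request; the function names and data shapes are invented and do not mirror vLLM's internal API.

# Illustrative sketch only, not vLLM code. Batch expansion scores k proposed
# tokens by materializing several single-token decode requests, while MQA
# scoring keeps one request whose decode query carries all proposed tokens.
from typing import List, Tuple


def batch_expansion_requests(
        context: List[int],
        proposals: List[int]) -> List[Tuple[List[int], int]]:
    # One scoring request per proposal position: (expanded context, query token).
    requests = []
    running = list(context)
    for tok in proposals:
        requests.append((list(running), tok))
        running.append(tok)
    return requests


def mqa_request(context: List[int],
                proposals: List[int]) -> Tuple[List[int], List[int]]:
    # A single scoring request whose query contains every proposed token.
    return list(context), list(proposals)


ctx, props = [11, 12, 13], [7, 8, 9]
print(len(batch_expansion_requests(ctx, props)))  # 3 expanded requests
print(mqa_request(ctx, props))                    # 1 request with a 3-token query

Avoiding that per-token duplication of each sequence's context is the point of removing batch expansion; the e2e tests below only check that both scorers produce identical greedy outputs.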
2 changes: 1 addition & 1 deletion tests/samplers/test_sampler.py
@@ -434,7 +434,7 @@ def run_test_case(*, expected_penalization: List[bool],
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens=seq_lens if seq_lens else None,
-        query_lens=seq_lens if seq_lens else None,
+        query_lens=seq_lens if seq_lens else [1] * batch_size,
device=device,
pin_memory=is_pin_memory_available())
# the logits tensor is modified in-place by the sampler
44 changes: 44 additions & 0 deletions tests/spec_decode/e2e/test_integration.py
@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
max_output_len=32,
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
49 changes: 49 additions & 0 deletions tests/spec_decode/e2e/test_medusa_correctness.py
@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)


if __name__ == "__main__":
import pytest
pytest.main([__file__])
43 changes: 43 additions & 0 deletions tests/spec_decode/e2e/test_mlp_correctness.py
@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len,
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,
Review thread at this point in the diff:

Contributor:
> Do we need to specify num_speculative_tokens here?

Collaborator (Author):
> The test verifies the correctness of MLPSpeculator, which uses additional heads to make proposals. We don't need to specify num_speculative_tokens here because it will read the number of tokens from the model config.

(A hypothetical sketch of this config-driven fallback follows this file's diff.)
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
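As the review exchange above notes, MLPSpeculator proposes tokens with its own prediction heads, so the number of speculative tokens can come from the draft model's config rather than an explicit num_speculative_tokens setting. The snippet below is a purely hypothetical illustration of that kind of fallback; the class and field names are invented and are not part of vLLM.

from dataclasses import dataclass
from typing import Optional


@dataclass
class HypotheticalSpeculatorConfig:
    n_predict: int  # how many tokens the speculator's extra heads propose


def resolve_num_speculative_tokens(
        user_value: Optional[int],
        draft_config: HypotheticalSpeculatorConfig) -> int:
    # Prefer an explicit user setting; otherwise fall back to the model config.
    return user_value if user_value is not None else draft_config.n_predict


assert resolve_num_speculative_tokens(None, HypotheticalSpeculatorConfig(3)) == 3
assert resolve_num_speculative_tokens(5, HypotheticalSpeculatorConfig(3)) == 5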
46 changes: 46 additions & 0 deletions tests/spec_decode/e2e/test_ngram_correctness.py
@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len,
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
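The tests above configure speculative_model="[ngram]" with ngram_prompt_lookup_max=3, so proposals come from matching recent n-grams against earlier tokens instead of running a draft model. The sketch below shows only the general prompt-lookup idea; it is not vLLM's n-gram proposer, and the helper name is made up.

from typing import List


def ngram_propose(tokens: List[int], n_max: int, k: int) -> List[int]:
    # Take the longest n-gram (n <= n_max) that ends the sequence, look for an
    # earlier occurrence of it, and propose the k tokens that followed there.
    for n in range(n_max, 0, -1):
        tail = tokens[-n:]
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == tail:
                return tokens[start + n:start + n + k]
    return []


print(ngram_propose([5, 6, 7, 8, 5, 6, 7], n_max=3, k=5))  # -> [8, 5, 6, 7]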
1 change: 0 additions & 1 deletion tests/spec_decode/test_multi_step_worker.py
@@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
block_size,
num_gpu_blocks,
seed,
-        model_runner_cls=TP1DraftModelRunner,
)

worker = create_worker(
65 changes: 65 additions & 0 deletions tests/spec_decode/test_scorer.py
@@ -0,0 +1,65 @@
import pytest
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.worker import Worker

from .utils import create_batch, create_worker


def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
device: str) -> SpeculativeProposals:
proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
device=device)
proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
return SpeculativeProposals(proposal_token_ids, proposal_probs,
proposal_lens)


def assert_score_equal(score1: SpeculativeScores,
score2: SpeculativeScores) -> None:
assert torch.allclose(score1.probs, score2.probs)
assert torch.allclose(score1.logprobs, score2.logprobs)
assert torch.equal(score1.token_ids, score2.token_ids)


@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scorer(model_name: str, batch_size: int, propose_len: int,
device: str) -> None:
"""
Verify that the batch expansion scorer and the MQA scorer return the same scores.
"""
seed = 0
block_size = 32
num_gpu_blocks = 2048 // block_size
scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed)
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\
should_modify_greedy_probs_inplace = True

vocab_size = scorer_worker.vocab_size
proposals = create_proposal(batch_size, propose_len, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=propose_len)

batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
vocab_size)
batch_expansion_score = batch_expansion_scorer.score_proposals(
requests, proposals)

mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
mqa_score = mqa_scorer.score_proposals(requests, proposals)

assert_score_equal(batch_expansion_score, mqa_score)
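The speculative_disable_mqa_scorer flag that the e2e tests pass as an LLM kwarg can presumably also be set when constructing an engine directly, forcing the older batch-expansion scorer. Below is a minimal sketch under that assumption, reusing the kwargs seen in the tests above; the target model name is a placeholder, not a recommendation.

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",            # placeholder target model
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    enforce_eager=True,
    # Fall back to the batch-expansion scorer instead of the MQA scorer.
    speculative_disable_mqa_scorer=True,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)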
9 changes: 5 additions & 4 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
-def test_correctly_calls_target_model(k: int, batch_size: int,
-                                      acceptance_sampler_method: str):
+def test_batch_expansion_correctly_calls_target_model(
+        k: int, batch_size: int, acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct
-    inputs. Everything else is mocked out.
+    inputs with batch expansion. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
-        metrics_collector=metrics_collector)
+        metrics_collector=metrics_collector,
+        disable_mqa_scorer=True)
worker.init_device()

vocab_size = 32_000
29 changes: 16 additions & 13 deletions tests/spec_decode/utils.py
@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
for i, final_len in enumerate(final_prompt_lens)
}

-    return [
-        SequenceGroupMetadata(
-            request_id=str(i),
-            is_prompt=len(cont_token_ids) == 0,
-            seq_data={
-                i: SequenceData.from_seqs(prompt_token_ids[:],
-                                          cont_token_ids[:]),
-            },
-            sampling_params=SamplingParams(temperature=0.0, ),
-            block_tables={i: block_allocations[i][:]},
-        ) for i, (prompt_token_ids,
-                  cont_token_ids) in enumerate(zip(prompts, continuations))
-    ]
+    seq_group_metadata_list = []
+    for i, (prompt_token_ids,
+            cont_token_ids) in enumerate(zip(prompts, continuations)):
+        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
+        data.update_num_computed_tokens(
+            len(prompt_token_ids) + len(cont_token_ids) - 1)
+        seq_data = {i: data}
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=str(i),
+                is_prompt=len(cont_token_ids) == 0,
+                seq_data=seq_data,
+                sampling_params=SamplingParams(temperature=0.0),
+                block_tables={i: block_allocations[i][:]},
+            ))
+    return seq_group_metadata_list


def assert_logprobs_dict_allclose(
6 changes: 6 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None

Review thread on the comment above:

Contributor:
> nit: speculavie -> speculative

_cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[
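A hedged illustration of the decode_query_len field documented in the hunk above: with regular decoding each request contributes one query token per step, while with speculative decoding the decode query is longer. The helper below is hypothetical; treating the length as num_speculative_tokens + 1 (the proposed tokens plus the bonus token) is an assumption about how an MQA-style scorer sizes its queries, not a quote of vLLM's code.

from typing import Optional


def expected_decode_query_len(num_speculative_tokens: Optional[int]) -> int:
    # Assumption: k proposed tokens plus one bonus token are scored per step.
    if not num_speculative_tokens:
        return 1  # ordinary decoding: one query token per request
    return num_speculative_tokens + 1


assert expected_decode_query_len(None) == 1
assert expected_decode_query_len(3) == 4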