[Spec Decode] (1/2) Remove batch expansion #8839
Merged
LiuXiaoxuanPKU merged 33 commits into vllm-project:main from LiuXiaoxuanPKU:remove_batch_expansion on Oct 1, 2024 (+531 −99)
Commits (33)
All 33 commits were authored by LiuXiaoxuanPKU:

c4c5dab  draft
cb08091  w/o cuda graph support
8c10b11  args and tests
44930fb  disable mqa for ngram and format
e64c61b  clean up and tests
541b767  revert example
b6c1de3  minor
5824b78  minor
07aebc0  fix tests -- chunked prefill and hidden states in spec dec
d6cb1cc  fix
bcc1fe9  minor
b036d06  fix
b93694d  Merge branch 'main' into remove_batch_expansion
35750a6  fix sampler for beam search
741068a  revert num compute tokens
71be340  disable mqa scorer when draft model and target model have different m…
cff6b0f  disable mqa for cuda graph
f4fb00b  fix partial comments
b3e8691  fix comments
238e5a0  fix sampler and spec dec tests
5063c95  remove backend
70662b0  more test fix
878d2da  Merge branch 'main' into remove_batch_expansion
0e32744  fix num_compute_token
7ee2998  clean up
d39c8a9  more fix for num_compute_token
79ac29c  change log condition
6f3388b  add comments
1425332  query len for multi-step, specify ci backend
e5702a9  fix ci
8e27664  fix
3f3c222  format
2707422  context_len for multi-step and encoder decoder, fix decode_len
New test file (@@ -0,0 +1,65 @@):

```python
import pytest
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.worker import Worker

from .utils import create_batch, create_worker


def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
                    device: str) -> SpeculativeProposals:
    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
                                device=device)
    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
    return SpeculativeProposals(proposal_token_ids, proposal_probs,
                                proposal_lens)


def assert_score_equal(score1: SpeculativeScores,
                       score2: SpeculativeScores) -> None:
    assert torch.allclose(score1.probs, score2.probs)
    assert torch.allclose(score1.logprobs, score2.logprobs)
    assert torch.equal(score1.token_ids, score2.token_ids)


@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scorer(model_name: str, batch_size: int, propose_len: int,
                device: str) -> None:
    """
    Verify that the batch expansion scorer and the MQA scorer return
    the same scores.
    """
    seed = 0
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    scorer_worker = create_worker(Worker, model_name, block_size,
                                  num_gpu_blocks, seed)
    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
    scorer_worker.model_runner.model.sampler.\
        should_modify_greedy_probs_inplace = True

    vocab_size = scorer_worker.vocab_size
    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
    seq_group_metadatalist, _, _ = create_batch(batch_size,
                                                propose_len,
                                                block_size=block_size,
                                                num_gpu_blocks=num_gpu_blocks)
    requests = ExecuteModelRequest(seq_group_metadatalist,
                                   num_lookahead_slots=propose_len)

    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                      vocab_size)
    batch_expansion_score = batch_expansion_scorer.score_proposals(
        requests, proposals)

    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
    mqa_score = mqa_scorer.score_proposals(requests, proposals)

    assert_score_equal(batch_expansion_score, mqa_score)
```
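The equivalence this test checks can be pictured in terms of batch shapes. A minimal sketch (hypothetical helper names, not vLLM APIs) of the two scoring layouts, assuming k proposal tokens per request plus one bonus position:

```python
def batch_expansion_layout(batch_size: int, k: int) -> tuple[int, int]:
    """Batch expansion: each request is expanded into k + 1 single-query
    sequences (one per prefix length), enlarging the batch."""
    return batch_size * (k + 1), 1


def mqa_layout(batch_size: int, k: int) -> tuple[int, int]:
    """MQA scoring: one sequence per request, with all k + 1 positions
    fed as query tokens of that single sequence."""
    return batch_size, k + 1


# Both layouts score the same total number of query tokens; only the
# batch shape (and hence the runtime cost of expansion) differs.
be_seqs, be_queries = batch_expansion_layout(8, 5)
mqa_seqs, mqa_queries = mqa_layout(8, 5)
assert be_seqs * be_queries == mqa_seqs * mqa_queries == 48
```

Since the total set of (sequence, position) pairs is identical, both scorers should produce identical probabilities and token ids, which is exactly what `assert_score_equal` verifies.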
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):

```python
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = None

    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
```
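A hypothetical helper (not part of vLLM) illustrating the intent of the `decode_query_len` field: in ordinary autoregressive decoding each request contributes one query token per step, while with speculative decoding the k proposed tokens plus the bonus token are scored together, so the assumed relation is k + 1.

```python
from typing import Optional


def expected_decode_query_len(num_speculative_tokens: Optional[int]) -> int:
    """Sketch of how decode_query_len relates to speculative decoding.

    Assumption (not from vLLM source): with k speculative tokens, the
    target model scores the k proposals plus one bonus token per step.
    """
    if not num_speculative_tokens:
        return 1  # normal decode: one query token per request per step
    return num_speculative_tokens + 1  # k proposals + 1 bonus token


assert expected_decode_query_len(None) == 1
assert expected_decode_query_len(3) == 4
```

This is why the comment above requires all requests to share the same number of decode query tokens: the MQA scorer batches them as equal-length query groups.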
Review comment:
Do we need to specify num_speculative_tokens here?
Reply:
The test verifies the correctness of MLPSpeculator, which uses additional heads to make proposals. We don't need to specify num_speculative_tokens here because it reads the number of tokens from the model config.