[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker #5348

Merged

Commits (51)
5650b95  Merge pull request #1 from vllm-project/main  (sroy745, May 29, 2024)
8f36146  Merge branch 'vllm-project:main' into main  (sroy745, Jun 3, 2024)
9e75057  Merge branch 'vllm-project:main' into main  (sroy745, Jun 3, 2024)
bbf1484  Integrate Typical Acceptance Sampler into spec decode worker  (sroy745, Jun 7, 2024)
db2c679  Merge branch 'vllm-project:main' into main  (sroy745, Jun 7, 2024)
3495673  Fixing tests  (sroy745, Jun 9, 2024)
26c7c57  adding missing commit  (sroy745, Jun 10, 2024)
090f0bf  reverting changes to conftest  (sroy745, Jun 10, 2024)
733cc6e  reverting changes to conftest  (sroy745, Jun 10, 2024)
19ca0c9  Merge branch 'main' into spec_decode_integrate_accpetance_sampler  (sroy745, Jun 10, 2024)
acf8d2c  Dummy commit  (sroy745, Jun 10, 2024)
2d2b02b  Merge branch 'spec_decode_integrate_accpetance_sampler' of https://gi…  (sroy745, Jun 10, 2024)
2010b35  Revert unnecessary commits  (sroy745, Jun 10, 2024)
8d7512c  Merge branch 'vllm-project:main' into main  (sroy745, Jun 10, 2024)
7fa64b6  Merge remote-tracking branch 'origin/main' into spec_decode_integrate…  (sroy745, Jun 10, 2024)
dea6fbd  Pass only one sampler which can either be the RejectionSampler of the…  (sroy745, Jun 10, 2024)
c3383db  Fix test scripture  (sroy745, Jun 10, 2024)
b15abba  Fix tests  (sroy745, Jun 11, 2024)
6ca731c  Fix tests  (sroy745, Jun 11, 2024)
483c671  Pass only 1 verification_sampler which can either be rejectionSampler…  (sroy745, Jun 11, 2024)
2c6d06c  Update metrics.py to take the base sampler class  (sroy745, Jun 11, 2024)
027b485  Fix tests and comments  (sroy745, Jun 11, 2024)
ded92ac  Fix test fixture and default values of args  (sroy745, Jun 11, 2024)
738871e  Small misc fixes  (sroy745, Jun 11, 2024)
50e8771  Fix spec_decode/test_metrics.py  (sroy745, Jun 11, 2024)
101611e  Merge branch 'main' into spec_decode_integrate_accpetance_sampler  (sroy745, Jun 11, 2024)
5e6638b  Merge branch 'main' into spec_decode_integrate_accpetance_sampler  (sroy745, Jun 25, 2024)
cc760a0  Make rejection_sampler.py and typical_acceptance_sampler.py implement…  (sroy745, Jun 25, 2024)
360ce0b  Raise exception instead of returning None for invalid sampler name  (sroy745, Jun 25, 2024)
6572ba4  Adding log about type of sampler  (sroy745, Jun 25, 2024)
be85f07  Misc comment fixes  (sroy745, Jun 26, 2024)
6dc9efe  Misc fixes  (sroy745, Jun 26, 2024)
512fad9  Misc fixes  (sroy745, Jun 26, 2024)
b1d510c  Misc fixes  (sroy745, Jun 26, 2024)
f4b9e4d  Misc fixes  (sroy745, Jun 26, 2024)
0ea9408  Documentation  (sroy745, Jun 26, 2024)
5772d04  Fix comments  (sroy745, Jun 26, 2024)
b7254e7  Fix arg name  (sroy745, Jun 26, 2024)
ef93081  Fixing a test  (sroy745, Jun 26, 2024)
0165842  Fix comment  (sroy745, Jun 26, 2024)
510974b  Fix formatting  (sroy745, Jun 26, 2024)
396fa54  Fixing tests and lint failures  (sroy745, Jun 26, 2024)
f8cc895  Removing e2e test for TypicalAcceptanceSampler from test_ngram_correc…  (sroy745, Jun 27, 2024)
439117d  Fix a comment  (sroy745, Jun 27, 2024)
75f034f  Dummy commit  (sroy745, Jun 27, 2024)
a0f5ade  Merge pull request #2 from vllm-project/main  (sroy745, Jun 27, 2024)
3082255  Fix format error  (sroy745, Jun 27, 2024)
4e7f51a  Merge pull request #3 from vllm-project/main  (sroy745, Jun 28, 2024)
d26c624  Dummy fix  (sroy745, Jun 29, 2024)
98d5f92  Merge branch 'main' into spec_decode_integrate_accpetance_sampler  (sroy745, Jun 29, 2024)
f186844  Update test_multistep_correctness.py  (sroy745, Jun 29, 2024)
96 changes: 64 additions & 32 deletions tests/samplers/test_typical_acceptance_sampler.py
@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
return draft_token_ids


def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
disable_bonus_tokens, strict_mode)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
different combinations of k, vocab_size, batch_size and num devices.
"""
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler()
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
@@ -76,7 +89,10 @@
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids[0][0] = rogue_token_id

with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs, bonus_token_ids,
draft_token_ids)
typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size = 30_000
torch.set_default_device(device)

typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target probabilities
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
# 1.0 tokens in the target distribution we will reject all of them and
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == -1)
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# For sequences 0 and 2 set the distribution to a temperature
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# verify the shape of output_token_ids
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Create a temperature zero target probability distribution and ensure
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
posterior_threshold=0.0,
posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
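The recurring change in this file is the call convention: every invocation of the sampler now passes draft_probs=None and a keyword draft_token_ids, matching the shared interface this PR gives the rejection and typical-acceptance samplers. The sketch below condenses that convention; the import path, the CUDA device, and the toy tensor shapes are illustrative assumptions, not part of the diff.

```python
# Sketch of the updated sampler call convention exercised by the tests above.
# Assumes a CUDA device and that TypicalAcceptanceSampler lives at the module
# path below; adjust both if your checkout differs.
import torch

from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

torch.set_default_device("cuda:0")
batch_size, k, vocab_size = 2, 3, 30_000

sampler = TypicalAcceptanceSampler(posterior_threshold=0.03,
                                   posterior_alpha=0.9,
                                   disable_bonus_tokens=False,
                                   strict_mode=False)
sampler.init_gpu_tensors(rank=0)

target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)

# draft_probs is accepted (and passed as None here) so that both samplers
# expose the same signature; the typical-acceptance tests never use it.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
assert output_token_ids.shape == (batch_size, k + 1)
```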
54 changes: 53 additions & 1 deletion tests/spec_decode/e2e/test_multistep_correctness.py
@@ -11,9 +11,15 @@
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.

At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.

For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptanceSampler, we rely on unit tests to validate the temp>0
cases.

NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
}
# Try a range of common k.
for k in [1, 2, 3]
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
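The new e2e test drives the sampler choice through the engine arguments in test_llm_kwargs. Assuming those kwargs are forwarded to the LLM constructor the same way the other parametrized tests in this file are wired up, a standalone run of the same configuration would look roughly like the sketch below; the prompt and output handling are illustrative, not from the diff.

```python
# Rough standalone equivalent of the configuration the new e2e test sweeps
# over (k in {1, 2, 3}, batch sizes 1 and 32). Assumes the test's llm_kwargs
# map directly onto LLM constructor arguments.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    # Select the typical acceptance sampler instead of the default
    # rejection sampler for verifying draft tokens.
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    use_v2_block_manager=True,  # required for spec decode
    enforce_eager=True,         # skip CUDA graph capture, as the test does
)

# temp=0 (greedy), where the test expects exact equality with the
# non-speculative baseline.
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["San Francisco is a"], params)
print(outputs[0].outputs[0].text)
```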
12 changes: 7 additions & 5 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -3,33 +3,35 @@
import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker


@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3

draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
rejection_sampler=rejection_sampler,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)

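SpecDecodeWorker now takes a single spec_decode_sampler in place of the dedicated rejection_sampler argument, and the test obtains it from mock_spec_decode_sampler in .test_utils, whose body is not shown in this diff. Based on the commit messages ("Raise exception instead of returning None for invalid sampler name"), a plausible reconstruction of that helper might look like the following; treat it as an illustrative sketch rather than the actual implementation.

```python
# Hypothetical reconstruction of tests/spec_decode/test_utils.py's
# mock_spec_decode_sampler helper (its real body is not part of this diff).
from unittest.mock import MagicMock

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)


def mock_spec_decode_sampler(acceptance_sampler_method: str) -> MagicMock:
    """Return a mock for the requested draft-token acceptance sampler.

    The real helper may also configure attributes (e.g. dtypes) that
    SpecDecodeWorker reads from the sampler.
    """
    if acceptance_sampler_method == "rejection_sampler":
        return MagicMock(spec=RejectionSampler)
    if acceptance_sampler_method == "typical_acceptance_sampler":
        return MagicMock(spec=TypicalAcceptanceSampler)
    raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
```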