This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Upstream sync 2024 05 19 #249

Merged: 131 commits, Jun 3, 2024
Changes from 1 commit
1337ced
Disable cuda version check in vllm-openai image (#4530)
zhaoyang-star May 5, 2024
7d1afa9
[Bugfix] Fix `asyncio.Task` not being subscriptable (#4623)
DarkLight1337 May 6, 2024
76d1c0a
[CI] use ccache actions properly in release workflow (#4629)
simon-mo May 6, 2024
8c3136e
[CI] Add retry for agent lost (#4633)
cadedaniel May 6, 2024
5749888
Update lm-format-enforcer to 0.10.1 (#4631)
noamgat May 6, 2024
c6f73a2
[Kernel] Make static FP8 scaling more robust (#4570)
pcmoritz May 7, 2024
a542de1
[Core][Optimization] change python dict to pytorch tensor (#4607)
youkaichao May 7, 2024
a3ff2ae
[Build/CI] Fixing 'docker run' to re-enable AMD CI tests. (#4642)
Alexei-V-Ivanov-AMD May 7, 2024
e4ab5c6
[Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithL…
FurtherAI May 7, 2024
fd69572
[Core][Optimization] change copy-on-write from dict[int, list] to lis…
youkaichao May 7, 2024
8673ad0
[Bug fix][Core] fixup ngram not setup correctly (#4551)
leiwen83 May 7, 2024
3fc0fa0
[Core][Distributed] support cpu&device in broadcast tensor dict (#4660)
youkaichao May 8, 2024
43bc7e9
[Core] Optimize sampler get_logprobs (#4594)
rkooo567 May 8, 2024
01ad752
[Kernel] Make static FP8 scaling more robust (#4570)
rkooo567 May 8, 2024
f64e4e4
[Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi …
DefTruth May 8, 2024
e06c2d6
[Misc] Add `get_name` method to attention backends (#4685)
WoosukKwon May 8, 2024
01d4ceb
[Core] Faster startup for LoRA enabled models (#4634)
Yard1 May 8, 2024
8afd8f7
[Core][Optimization] change python dict to pytorch tensor for blocks …
youkaichao May 8, 2024
1fe8d9c
[CI/Test] fix swap test for multi gpu (#4689)
youkaichao May 8, 2024
b5967c4
[Misc] Use vllm-flash-attn instead of flash-attn (#4686)
WoosukKwon May 8, 2024
4a85263
[Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)
comaniac May 8, 2024
edd9e90
[Speculative decoding] [Bugfix] Fix overallocation in ngram + spec lo…
cadedaniel May 8, 2024
32314e5
[Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (…
alexm-redhat May 9, 2024
b0d3937
[Frontend] add tok/s speed metric to llm class when using tqdm (#4400)
MahmoudAshraf97 May 9, 2024
294e480
[Frontend] Move async logic outside of constructor (#4674)
DarkLight1337 May 9, 2024
04a0387
[Misc] Remove unnecessary ModelRunner imports (#4703)
WoosukKwon May 9, 2024
fff9c2c
[Misc] Set block size at initialization & Fix test_model_runner (#4705)
WoosukKwon May 9, 2024
396a546
[ROCm] Add support for Punica kernels on AMD GPUs (#3140)
kliuae May 9, 2024
0c85c21
[Bugfix] Fix CLI arguments in OpenAI server docs (#4709)
DarkLight1337 May 9, 2024
631605d
[Bugfix] Update grafana.json (#4711)
robertgshaw2-redhat May 9, 2024
d824ab8
[Bugfix] Add logs for all model dtype casting (#4717)
mgoin May 9, 2024
9b500f3
[Model] Snowflake arctic model implementation (#4652)
sfc-gh-hazhang May 9, 2024
56c100c
[Kernel] [FP8] Improve FP8 linear layer performance (#4691)
pcmoritz May 9, 2024
0b429b8
[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)
comaniac May 10, 2024
ca3311a
[Core][Distributed] refactor pynccl (#4591)
youkaichao May 10, 2024
4ea25ee
[Misc] Keep only one implementation of the create_dummy_prompt functi…
AllenDou May 10, 2024
cd151e1
chunked-prefill-doc-syntax (#4603)
simon-mo May 10, 2024
4b7644f
[Core]fix type annotation for `swap_blocks` (#4726)
jikunshang May 10, 2024
9aec672
[Misc] Apply a couple g++ cleanups (#4719)
stevegrubb May 10, 2024
65159a8
[Core] Fix circular reference which leaked llm instance in local dev …
rkooo567 May 10, 2024
2fc4bb4
[Bugfix] Fix CLI arguments in OpenAI server docs (#4729)
AllenDou May 10, 2024
f739bdb
[Speculative decoding] CUDA graph support (#4295)
heeju-kim2 May 10, 2024
20b780a
[CI] Nits for bad initialization of SeqGroup in testing (#4748)
robertgshaw2-redhat May 10, 2024
8a9d255
[Core][Test] fix function name typo in custom allreduce (#4750)
youkaichao May 10, 2024
9132d19
[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
CatherineSue May 11, 2024
18355a9
[Model] Add support for IBM Granite Code models (#4636)
yikangshen May 12, 2024
64367a0
[CI/Build] Tweak Marlin Nondeterminism Issues (#4713)
robertgshaw2-redhat May 13, 2024
fa95832
[CORE] Improvement in ranks code (#4718)
SwapnilDreams100 May 13, 2024
b5c4711
[Core][Distributed] refactor custom allreduce to support multiple tp …
youkaichao May 13, 2024
a92b874
[CI/Build] Move `test_utils.py` to `tests/utils.py` (#4425)
DarkLight1337 May 13, 2024
270c0c2
[Scheduler] Warning upon preemption and Swapping (#4647)
rkooo567 May 13, 2024
c944527
[Misc] Enhance attention selector (#4751)
WoosukKwon May 13, 2024
7dd2e73
[Frontend] [Core] perf: Automatically detect vLLM-tensorized model, u…
sangstar May 13, 2024
61e2bde
[Speculative decoding] Improve n-gram efficiency (#4724)
comaniac May 13, 2024
00d6bd6
[Kernel] Use flash-attn for decoding (#3648)
skrider May 13, 2024
81c2c05
[Bugfix] Fix dynamic FP8 quantization for Mixtral (#4793)
pcmoritz May 13, 2024
1d56497
[Doc] Shorten README by removing supported model list (#4796)
zhuohan123 May 13, 2024
2895ae9
[Doc] Add API reference for offline inference (#4710)
DarkLight1337 May 14, 2024
a1f43a0
[Doc] Add meetups to the doc (#4798)
zhuohan123 May 14, 2024
feed62d
[Core][Hash][Automatic Prefix caching] Accelerating the hashing funct…
KuntaiDu May 14, 2024
31c1cd3
[Bugfix][Doc] Fix CI failure in docs (#4804)
DarkLight1337 May 14, 2024
6838a99
[Core] Add MultiprocessingGPUExecutor (#4539)
njhill May 14, 2024
f246252
Add 4th meetup announcement to readme (#4817)
simon-mo May 14, 2024
bd73ad3
Revert "[Kernel] Use flash-attn for decoding (#3648)" (#4820)
rkooo567 May 15, 2024
30e935f
[Core][2/N] Model runner refactoring part 2. Combine prepare prefill …
rkooo567 May 15, 2024
e40b747
[CI/Build] Further decouple HuggingFace implementation from ours duri…
DarkLight1337 May 15, 2024
71c459f
[Bugfix] Properly set distributed_executor_backend in ParallelConfig …
zifeitong May 15, 2024
e6bc337
[Doc] Highlight the fourth meetup in the README (#4842)
zhuohan123 May 15, 2024
1b50825
[Frontend] Re-enable custom roles in Chat Completions API (#4758)
DarkLight1337 May 15, 2024
28f56b3
[Frontend] Support OpenAI batch file format (#4794)
wuisawesome May 15, 2024
e88dd2b
[Core] Implement sharded state loader (#4690)
aurickq May 16, 2024
3360031
[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)
comaniac May 16, 2024
0240ac9
Add marlin unit tests and marlin benchmark script (#4815)
alexm-redhat May 16, 2024
230af21
[Kernel] add bfloat16 support for gptq marlin kernel (#4788)
jinzhen-lin May 16, 2024
3426d29
[docs] Fix typo in examples filename openi -> openai (#4864)
wuisawesome May 16, 2024
de61ba7
[Frontend] Separate OpenAI Batch Runner usage from API Server (#4851)
wuisawesome May 16, 2024
28f605c
[Bugfix] Bypass authorization API token for preflight requests (#4862)
dulacp May 16, 2024
cf4926d
Add GPTQ Marlin 2:4 sparse structured support (#4790)
alexm-redhat May 16, 2024
40ce57a
Add JSON output support for benchmark_latency and benchmark_throughpu…
simon-mo May 16, 2024
273b3fe
[ROCm][AMD][Bugfix] adding a missing triton autotune config (#4845)
hongxiayang May 16, 2024
1589d50
[Core][Distributed] remove graph mode function (#4818)
youkaichao May 16, 2024
3ced8d0
[Misc] remove old comments (#4866)
youkaichao May 16, 2024
1a745a3
[Kernel] Add punica dimension for Qwen1.5-32B LoRA (#4850)
Silencioo May 16, 2024
7f372fb
[Kernel] Add w8a8 CUTLASS kernels (#4749)
tlrmchlsmth May 16, 2024
69ac7b4
[Bugfix] Fix FP8 KV cache support (#4869)
WoosukKwon May 16, 2024
e4b31f6
Support to serve vLLM on Kubernetes with LWS (#4829)
kerthcet May 16, 2024
3bf9ee0
[Frontend] OpenAI API server: Do not add bos token by default when en…
bofenghuang May 17, 2024
f2b3686
[Build/CI] Extending the set of AMD tests with Regression, Basic Corr…
Alexei-V-Ivanov-AMD May 17, 2024
3b9b8e5
[Bugfix] fix rope error when load models with different dtypes (#4835)
jinzhen-lin May 17, 2024
96e8baa
Sync huggingface modifications of qwen Moe model (#4774)
eigen2017 May 17, 2024
7af0041
[Doc] Update Ray Data distributed offline inference example (#4871)
Yard1 May 17, 2024
b1a73b5
[Bugfix] Relax tiktoken to >= 0.6.0 (#4890)
mgoin May 17, 2024
3bbe65e
[ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if…
alexeykondrat May 18, 2024
670a8b8
[Lora] Support long context lora (#4787)
rkooo567 May 18, 2024
c79bcb7
[Bugfix][Model] Add base class for vision-language models (#4809)
DarkLight1337 May 19, 2024
7b70de3
./format
May 19, 2024
1689026
added skips to lora long context
May 19, 2024
1e984b1
format
May 19, 2024
9aef71f
added missed files
May 19, 2024
774ba57
updates check_logprobs_close.py
May 19, 2024
ab7274f
fixed tensorizer
May 19, 2024
2c8f45a
skip mosaic in strict correctness test
May 19, 2024
85ec849
format
May 19, 2024
296861b
Merge branch 'main' into upstream-sync-2024-05-19
robertgshaw2-redhat May 22, 2024
688ef6f
skipped sharded state loader
May 22, 2024
4c437ba
Merge branch 'main' into upstream-sync-2024-05-19
robertgshaw2-redhat May 27, 2024
2059e61
skip shared state loader
May 27, 2024
9642aef
updated build test to use 4 nvcc threads by default. We previously, w…
May 28, 2024
2dad479
tweaked to fix benchmark
May 28, 2024
3bdfeb4
updated workflow to run longer
May 28, 2024
3800a1c
Merge branch 'main' into upstream-sync-2024-05-19
robertgshaw2-redhat May 28, 2024
f1199dc
updated skip lists to skip sharded state loader
May 29, 2024
ee7e65a
verified that test multiproc workers is passing locally
May 29, 2024
b73a142
fixed the sampling params issue
May 29, 2024
8225ddd
fixed other sampling_params issue
May 29, 2024
c386e32
Merge branch 'main' into upstream-sync-2024-05-19
May 29, 2024
098e08a
format
May 29, 2024
7d32b8a
confirmed basic correctness test working
May 30, 2024
748d0e1
updated score for marlin 2:4
May 30, 2024
cd648c6
Merge branch 'main' into upstream-sync-2024-05-19
May 30, 2024
9785c41
Disable flaky marlin model
dbarbuzzi May 30, 2024
3507552
Increase benchmark server timeout to 15 minutes
dbarbuzzi May 30, 2024
1d6af5a
Merge branch 'main' into upstream-sync-2024-05-19
robertgshaw2-redhat May 30, 2024
96fbf17
Merge branch 'main' into upstream-sync-2024-05-19
robertgshaw2-redhat May 30, 2024
db69b5c
reduce number of prompts and models in basic server correctness
Jun 1, 2024
0654a43
Merge branch 'nm-vllm-main' into upstream-sync-2024-05-19
Jun 1, 2024
3ba575c
fixed workflows
Jun 1, 2024
43c0adc
removed basic server correctness from release
Jun 2, 2024
50ac573
Update test_compressed.py
robertgshaw2-redhat Jun 2, 2024
1802833
Update test_compressed.py (#277)
robertgshaw2-redhat Jun 2, 2024
2c52fee
nit in setup.py
Jun 3, 2024
[Dynamic Spec Decoding] Auto-disable by the running queue size (vllm-project#4592)

Co-authored-by: Cade Daniel <edacih@gmail.com>
2 people authored and Robert Shaw committed May 19, 2024
commit 4a85263089d394aba17c6e5d4c2735f0c7f05c30
13 changes: 9 additions & 4 deletions tests/samplers/test_rejection_sampler.py
@@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(k: int, last_accepted_indices: torch.Tensor,
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
-def test_correct_output_format(which_tokens_accepted: str, seed: int,
+def test_correct_output_format(which_tokens_accepted: str,
+                               disable_bonus_tokens: bool, seed: int,
device: str):
"""Verify the output has correct format given predetermined accepted matrix.
"""
@@ -86,7 +88,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size=(batch_size, 1),
dtype=torch.int64)

-rejection_sampler = RejectionSampler()
+rejection_sampler = RejectionSampler(
+    disable_bonus_tokens=disable_bonus_tokens)
device_rank = int(device[-1])
rejection_sampler.init_gpu_tensors(rank=device_rank)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
@@ -96,9 +99,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids,
)

-# Bonus tokens are currently disabled. Verify they're set to -1.
+expected_bonus_token_ids = bonus_token_ids.clone()
+# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
-expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
+if disable_bonus_tokens:
+    expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1

if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
34 changes: 34 additions & 0 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when all sequences disable speculation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_ngram_correctness.py
@@ -57,7 +57,7 @@
@pytest.mark.parametrize("output_len", [
256,
])
-@pytest.mark.parametrize("batch_size", [1, 64])
+@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
77 changes: 77 additions & 0 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .utils import create_batch, mock_worker


@pytest.mark.parametrize('queue_size', [2, 4])
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3

draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
rejection_sampler=rejection_sampler,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)

exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)

with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)

# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens

draft_worker.sampler_output.side_effect = ValueError(exception_secret)

proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)

if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
assert proposals.proposal_lens.tolist() == [0] * batch_size
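The gating that this new test pins down can be summarized with a small standalone sketch (not part of the commit; the helper name and signature are hypothetical): once the running queue reaches the configured threshold, each request is given zero speculative tokens and the draft model is skipped for that step.

from typing import Optional

def speculative_tokens_for_step(running_queue_size: int,
                                disable_by_batch_size: Optional[int],
                                num_speculative_tokens: int) -> int:
    # Hypothetical helper mirroring the behavior asserted above.
    if (disable_by_batch_size is not None
            and running_queue_size >= disable_by_batch_size):
        return 0  # speculation disabled for this step; proposal length is 0
    return num_speculative_tokens

assert speculative_tokens_for_step(2, 3, 5) == 5  # queue below threshold
assert speculative_tokens_for_step(4, 3, 5) == 0  # queue at/above threshold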
29 changes: 24 additions & 5 deletions vllm/config.py
@@ -736,6 +736,7 @@ def maybe_create_spec_config(
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
) -> Optional["SpeculativeConfig"]:
@@ -764,6 +765,9 @@ def maybe_create_spec_config(
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
@@ -774,7 +778,7 @@ def maybe_create_spec_config(
the necessary conditions are met, else None.
"""

-if (speculative_model is None and num_speculative_tokens is None):
+if speculative_model is None and num_speculative_tokens is None:
return None

if speculative_model is not None and num_speculative_tokens is None:
@@ -783,6 +787,12 @@ def maybe_create_spec_config(
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")

if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")

assert (speculative_model is not None
and num_speculative_tokens is not None)

@@ -851,6 +861,7 @@ def maybe_create_spec_config(
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
)
@@ -920,8 +931,9 @@ def __init__(
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
-ngram_prompt_lookup_max: int,
-ngram_prompt_lookup_min: int,
+speculative_disable_by_batch_size: Optional[int],
+ngram_prompt_lookup_max: Optional[int],
+ngram_prompt_lookup_min: Optional[int],
):
"""Create a SpeculativeConfig object.

@@ -930,12 +942,19 @@ def __init__(
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
-self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
-self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+self.speculative_disable_by_batch_size = \
+    speculative_disable_by_batch_size
+self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
+self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0

self._verify_args()
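As a quick standalone illustration of the validation added above (a sketch, not the vLLM API itself), thresholds below 2 are rejected, since a value of 1 would disable speculation as soon as a single request is enqueued:

from typing import Optional

def check_disable_threshold(threshold: Optional[int]) -> None:
    # Mirrors the new check in maybe_create_spec_config: None means the
    # feature is unused; otherwise the threshold must be greater than 1.
    if threshold is not None and threshold < 2:
        raise ValueError("Expect the batch size threshold of disabling "
                         "speculative decoding is > 1, but got "
                         f"{threshold=}")

check_disable_threshold(None)  # OK: auto-disable not configured
check_disable_threshold(4)     # OK
# check_disable_threshold(1)   # raises ValueError, matching the diff above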

10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -87,6 +87,7 @@ class EngineArgs:
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None

@@ -482,6 +483,13 @@ def add_cli_args(
'draft model. Sequences over this length will skip '
'speculation.')

parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')

parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
@@ -575,6 +583,8 @@ def create_engine_config(self, ) -> EngineConfig:
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
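Putting the new engine argument together with the existing speculative-decoding flags, a usage sketch could look like the following; the model names and threshold are illustrative and mirror the kwargs exercised in the e2e test above, and the same setting is exposed on the CLI as --speculative-disable-by-batch-size:

from vllm import LLM, SamplingParams

# Illustrative configuration only; values mirror the e2e test kwargs above.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,            # required for spec decode
    speculative_disable_by_batch_size=4,  # fall back to plain decoding once
                                          # 4 or more requests are enqueued
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))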
2 changes: 2 additions & 0 deletions vllm/executor/gpu_executor.py
@@ -93,6 +93,8 @@ def _init_spec_worker(self):
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
)

assert self.parallel_config.world_size == 1, (
11 changes: 9 additions & 2 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf.
"""

-def __init__(self, strict_mode: bool = False):
+def __init__(self,
+             disable_bonus_tokens: bool = True,
+             strict_mode: bool = False):
"""Create a rejection sampler.

Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Required when bonus tokens would corrupt the KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode

# NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -312,7 +318,8 @@ def _create_output(
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
-output_with_bonus_tokens[:, -1] = -1
+if self._disable_bonus_tokens:
+    output_with_bonus_tokens[:, -1] = -1

# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
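A minimal, self-contained illustration (not from the commit) of what the guarded branch above does when disable_bonus_tokens is True: the last column of the sampler output holds the bonus token for each sequence, and it is overwritten with -1 so no bonus token is emitted while the proposer's KV cache has no entry for it.

import torch

batch_size, k = 3, 4
# Fake sampler output: k accepted/recovered tokens plus one bonus slot.
output_with_bonus_tokens = torch.randint(0, 100, (batch_size, k + 1))

disable_bonus_tokens = True
if disable_bonus_tokens:
    # Same masking as in RejectionSampler._create_output above.
    output_with_bonus_tokens[:, -1] = -1

print(output_with_bonus_tokens[:, -1])  # tensor([-1, -1, -1])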
6 changes: 6 additions & 0 deletions vllm/sequence.py
@@ -612,6 +612,12 @@ def __init__(
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample

# The number of speculative tokens adopted in this request.
# None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reason.
# TODO: We should maintain this state outside of the sequence group.
self.num_speculative_tokens = None

if self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len()