[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (vllm-project#6485)
sroy745 authored Jul 21, 2024
1 parent d7f4178 commit 14f91fe
Showing 8 changed files with 332 additions and 63 deletions.
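
The commit threads a new engine option, disable_logprobs_during_spec_decoding, from the CLI and EngineArgs through SpeculativeConfig and into SpecDecodeWorker. A minimal usage sketch (not part of this diff; the target model name is a placeholder, and it assumes LLM forwards extra keyword arguments to EngineArgs the same way the updated tests below do):

    from vllm import LLM, SamplingParams

    # Keep logprobs during speculative decoding; the new engine-level default is
    # to drop them (disable_logprobs_during_spec_decoding defaults to True).
    llm = LLM(
        model="JackFram/llama-68m",                  # placeholder target model
        speculative_model="JackFram/llama-160m",     # draft model used in the tests
        num_speculative_tokens=3,
        disable_logprobs_during_spec_decoding=False,
    )
    outputs = llm.generate(
        ["San Francisco is"],
        SamplingParams(temperature=0.0, max_tokens=16, logprobs=1))
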
48 changes: 29 additions & 19 deletions tests/spec_decode/e2e/test_logprobs.py
@@ -22,10 +22,12 @@
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@@ -59,10 +61,12 @@ def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
@@ -99,13 +103,16 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@@ -143,6 +150,7 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
@@ -181,10 +189,12 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
1 change: 1 addition & 0 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -32,6 +32,7 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)

20 changes: 13 additions & 7 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -381,6 +381,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()

@@ -479,7 +480,8 @@ def test_k_equals_zero(k: int, batch_size: int,

worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)

seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@@ -490,9 +492,10 @@
out = worker.execute_model(execute_model_req=execute_model_req)

assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
0].sampled_token_ids is None, "expect gpu tensor references to be None"

draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@@ -524,7 +527,8 @@ def test_empty_input_batch(k: int, batch_size: int,

worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)

seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@@ -535,9 +539,10 @@
out = worker.execute_model(execute_model_req=execute_model_req)

assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"
0].sampled_token_ids is None, "expect gpu tensor references to be None"

draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@@ -556,7 +561,7 @@ def test_init_device(acceptance_sampler_method: str):
metrics_collector = MagicMock(spec=AsyncMetricsCollector)

worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector)
False, metrics_collector)
worker.init_device()

draft_worker.init_device.assert_called_once()
@@ -707,6 +712,7 @@ def test_populate_seq_ids_with_bonus_tokens():
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional
17 changes: 17 additions & 0 deletions vllm/config.py
@@ -894,6 +894,7 @@ def maybe_create_spec_config(
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
@@ -943,6 +944,11 @@ def maybe_create_spec_config(
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs (Optional[bool]): If set to True, token log
probabilities are not returned during speculative decoding.
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@@ -1055,6 +1061,8 @@ def maybe_create_spec_config(
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None:
disable_logprobs = True

return SpeculativeConfig(
draft_model_config,
@@ -1068,6 +1076,7 @@ def __init__(
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs
)

@staticmethod
@@ -1152,6 +1161,7 @@ def __init__(
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
):
"""Create a SpeculativeConfig object.
@@ -1178,6 +1188,12 @@ def __init__(
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
@@ -1191,6 +1207,7 @@ def __init__(
typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs

self._verify_args()

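
The docstrings above make the flag opt-out: when disable_logprobs is left unspecified, maybe_create_spec_config resolves it to True and logprobs are skipped during speculative decoding. A tiny sketch of that resolution rule (hypothetical helper, mirroring the default shown in the diff):

    def resolve_disable_logprobs(disable_logprobs):
        # None (not specified) -> True: skip logprob serialization for spec decoding.
        # Explicit False -> logprobs follow the SamplingParams settings.
        return True if disable_logprobs is None else disable_logprobs

    assert resolve_disable_logprobs(None) is True
    assert resolve_disable_logprobs(False) is False
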
14 changes: 14 additions & 0 deletions vllm/engine/arg_utils.py
@@ -110,6 +110,7 @@ class EngineArgs:
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None

otlp_traces_endpoint: Optional[str] = None

@@ -592,6 +593,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')

parser.add_argument(
'--disable-logprobs-during-spec-decoding',
type=bool,
default=EngineArgs.disable_logprobs_during_spec_decoding,
help='If set to True, token log probabilities are not returned '
'during speculative decoding. If set to False, log probabilities '
'are returned according to the settings in SamplingParams. If '
'not specified, it defaults to True. Disabling log probabilities '
'during speculative decoding reduces latency by skipping logprob '
'calculation in proposal sampling, target sampling, and after '
'accepted tokens are determined.')

parser.add_argument('--model-loader-extra-config',
type=nullable_str,
default=EngineArgs.model_loader_extra_config,
@@ -736,6 +749,7 @@ def create_engine_config(self, ) -> EngineConfig:
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)

scheduler_config = SchedulerConfig(
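
Besides the new --disable-logprobs-during-spec-decoding CLI flag, the option can be set programmatically through EngineArgs, which forwards it into the SpeculativeConfig in create_engine_config. A sketch, not code from the diff; model names are placeholders:

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="JackFram/llama-68m",                  # placeholder target model
        speculative_model="JackFram/llama-160m",
        num_speculative_tokens=3,
        disable_logprobs_during_spec_decoding=True,  # trade logprobs for lower latency
    )
    # create_engine_config() passes the flag through as
    # SpeculativeConfig.disable_logprobs (see the config.py change above).
    engine_config = engine_args.create_engine_config()
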
(Diffs for the remaining 3 changed files are not shown here.)
