Skip to content

Commit

Permalink
[Speculative Decoding] MLPSpeculator Tensor Parallel support (1/2) (#…
Browse files Browse the repository at this point in the history
…6050)

Co-authored-by: Sirej Dua <sirej.dua@databricks.com>
Co-authored-by: Sirej Dua <Sirej Dua>
  • Loading branch information
sirejdua and sirejdua-db authored Jul 2, 2024
1 parent 31354e5 commit 15aba08
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
36 changes: 24 additions & 12 deletions tests/spec_decode/e2e/test_integration_dist_tp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
Expand All @@ -88,15 +84,31 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
},
])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs, test_llm_kwargs",
[
(
{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a
# tokenizer.
"model": "JackFram/llama-68m",
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
}),
({
"model": "ibm-granite/granite-3b-code-instruct",
}, {
"speculative_model":
"ibm-granite/granite-3b-code-instruct-accelerator",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
})
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
Expand Down
6 changes: 0 additions & 6 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,12 +957,6 @@ def maybe_create_spec_config(
)

draft_hf_config = draft_model_config.hf_config
if (draft_hf_config.model_type == "mlp_speculator"
and target_parallel_config.world_size != 1):
# MLPSpeculator TP support will be added very soon
raise ValueError(
"Speculative decoding with mlp_speculator models does not "
"yet support distributed inferencing (TP > 1).")

if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
Expand Down
18 changes: 11 additions & 7 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,28 @@ def create_worker(
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

disable_bonus_tokens = True

if ngram_prompt_lookup_max > 0:
disable_bonus_tokens = False
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "mlp_speculator":
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
disable_bonus_tokens = False
else:
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'parallel_config']
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size

if draft_tp == 1:
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_worker_kwargs[
"model_config"].hf_config.model_type == "mlp_speculator":
disable_bonus_tokens = False
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
proposer_worker = MultiStepWorker(**draft_worker_kwargs)

proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)

Expand Down

0 comments on commit 15aba08

Please sign in to comment.