
[Model][Speculative Decoding] Add EAGLE-style MTP module reference code for DeepSeek-R1 #12915

Open
wants to merge 1 commit into base: main

Conversation


@benchislett commented Feb 7, 2025

This is @CentML's implementation of DeepSeek MTP modules that enable speculative decoding for DeepSeek-R1.

The changes in this branch enable --speculative-model DeepSeekV3MTP to register itself as an alternate implementation of DeepSeek-R1 that loads only the MTP weights and uses modified model code to invoke only the MTP layer. This model code resides in deepseek_mtp.py.
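
For illustration, a minimal offline-inference sketch of how the draft model is selected (the engine arguments follow this branch; tensor_parallel_size and num_speculative_tokens are placeholders, not recommendations):

from vllm import LLM, SamplingParams

# Minimal sketch: select the MTP draft model via the speculative-decoding
# engine arguments used on this branch. The parallelism degree and number
# of speculative tokens below are illustrative placeholders.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1",
    tensor_parallel_size=8,
    trust_remote_code=True,
    speculative_model="DeepSeekV3MTP",
    num_speculative_tokens=2,
)

outputs = llm.generate(
    ["Explain speculative decoding in one sentence."],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)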

An additional change moves the RMSNorm application of the final hidden states into the compute_logits function, so that the previous_hidden_states taken from the output of model(...) are un-normalized. Our experiments show a small increase in predictive accuracy from this change. Intuitively, this makes sense because the hidden states are then treated as they would be for an additional transformer layer, while the output norm is treated separately as part of the output head.
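
Conceptually, the change looks like the following sketch (class and attribute names are illustrative, not the actual diff; nn.RMSNorm assumes torch >= 2.4):

import torch
from torch import nn

class OutputHeadSketch(nn.Module):
    # Illustrative only: the final RMSNorm is applied inside compute_logits,
    # so model(...) returns un-normalized hidden states that the MTP draft
    # module can consume as previous_hidden_states.
    def __init__(self, hidden_size: int, vocab_size: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.RMSNorm(hidden_size, eps=eps)          # output norm
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalization now lives with the output head, not with model().
        return self.lm_head(self.norm(hidden_states))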

This code also enables the EAGLE code path for reusing previous_hidden_states across TP workers in the base model runner code, enabling (single-step) multi-GPU draft model execution.

While it is possible that the hidden states could be reused by each worker and would not need to be broadcast to the TP workers between iterations, we choose to follow the EAGLE convention and respect the abstraction boundary between the draft worker and the target worker for speculative decoding.

Notably, this implementation does not mask the zero-th position of the input embeddings into the MTP module, though the existing EAGLE models do. This is because of an unknown interaction between CUDA graphs and MLA attention that degrades accuracy when the masking is applied. Our testing shows much improved acceptance rates when this problematic masking is omitted.
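
For context, the masking in question amounts to something like the following sketch of what the existing EAGLE models do (function name is illustrative); this PR deliberately skips it:

import torch

def mask_first_position(inputs_embeds: torch.Tensor,
                        positions: torch.Tensor) -> torch.Tensor:
    # Sketch of the EAGLE-style masking this PR omits: zero the input
    # embedding wherever the position index is 0.
    keep = (positions != 0).unsqueeze(-1).to(inputs_embeds.dtype)
    return inputs_embeds * keep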

Note that this is not the first vLLM PR to implement MTP module support for DeepSeekV3 models; it serves primarily as a reference implementation for validation purposes. See vllm#12755.


github-actions bot commented Feb 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@benchislett changed the title from "Add EAGLE-style MTP module reference code for DeepSeek-R1 and example usage script" to "[Model][Speculative Decoding] Add EAGLE-style MTP module reference code for DeepSeek-R1" on Feb 7, 2025
@Neo9061

Neo9061 commented Feb 8, 2025

Hi @benchislett, does that mean your implementation works like an EAGLE head and predicts multiple tokens (k > 1) by reusing MTP module 1? I'm trying to understand your statement: "This code also enables the EAGLE code path for reusing previous_hidden_states across TP workers in the base model runner code, enabling (single-step) multi-GPU draft model execution."

@benchislett
Author

Hi @Neo9061, that is correct. In the existing EAGLE implementation (limited to single-GPU TP=1), the hidden states from the output of the draft model are reused as inputs for multi-token drafting. I extended this functionality to the model_runner path (which supports multi-GPU TP=N) to unlock multi-token prediction from a single MTP module. This is not entirely future-proof, since a future release of additional MTP module weights could not trivially integrate with this strategy, but right now only the k=1 module is available, and drafting multiple tokens this way allows for much more effective speculative decoding.
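
Schematically, the reuse looks like this rough sketch (not the actual model-runner code; the callables and greedy selection are only there to keep the example short):

import torch

def draft_k_tokens(mtp_module, compute_logits, last_token, hidden, k):
    # A single MTP module drafts k tokens by feeding its own output hidden
    # state back in on each step, mirroring the EAGLE multi-step draft loop.
    drafted = []
    for _ in range(k):
        hidden = mtp_module(last_token, previous_hidden_states=hidden)
        last_token = compute_logits(hidden).argmax(dim=-1)
        drafted.append(last_token)
    return drafted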

@LiuXiaoxuanPKU
Collaborator

LiuXiaoxuanPKU commented Feb 10, 2025

Hi @benchislett, have you tested this PR? How does the speed look?

@benchislett
Author

Hi @LiuXiaoxuanPKU, the performance of this implementation in practice is quite good: approximately a 2x speedup for single-request inference of DeepSeek-R1 on 8xH200, and a significant improvement across nearly all batch sizes.

The acceptance rate for k=2 using this implementation is about 73%, with <2ms for drafting each token and ~30ms for scoring (single-request). This gives a theoretical speedup of 1.997x, which we do see in practice.
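
For anyone checking the arithmetic, a back-of-the-envelope reproduction of the 1.997x figure (assuming independent per-token acceptance and the usual 1 + a + a^2 expected-length formula for k = 2 draft tokens):

accept = 0.73      # measured acceptance rate
draft_ms = 2.0     # drafting cost per token
score_ms = 30.0    # scoring cost per step (single request)

expected_tokens = 1 + accept + accept**2   # ~2.26 tokens accepted per step
step_ms = score_ms + 2 * draft_ms          # 34 ms per speculative step
speedup = (expected_tokens / step_ms) * score_ms
print(f"{speedup:.3f}x")                   # ~1.997x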

@DragonFive

@benchislett I want to know whether it is possible to run the R1 model with pp=2 and tp=8 while running this draft model with tp=8. I have no 8xH200 node, so I use 2 nodes with 8xH800 each.
I got an error that says "NotImplementedError: Pipeline parallelism is not supported for this model. Supported models implement the SupportsPP interface".

@benchislett
Author


It is my understanding that PP and speculative decoding are incompatible, so for this PR I didn't worry about removing the SupportsPP flag. If this has changed, I can make sure PP is functional on this branch.


mergify bot commented Feb 12, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Feb 12, 2025
# (~80% for token 1, ~60% for token 2 due to accuracy decay)
python3 \
-m vllm.entrypoints.openai.api_server \
--disable-log-requests \

Wonder which dataset is used in this testing?

Author

Sample requests from ShareGPT are used.

@Pokemons386

Pokemons386 commented Feb 18, 2025

and a significant improvement across nearly all batch sizes

What is the specific range of "all batch sizes"? Does the speedup fall off rapidly as batch size grows?

@Neo9061

Neo9061 commented Feb 18, 2025

Hi @benchislett, I am using your code but hitting an error loading the MTP head; the error is shown below. Do you have any insight into where the problem might be? I am using two nodes of 8 H100s with a Ray cluster as the backend. The error seems to be related to the implementation.

(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575] Error executing method 'init_device'. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575] Traceback (most recent call last):
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 567, in execute_method
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     return run_method(target, method, args, kwargs)
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2208, in run_method
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     return func(*args, **kwargs)
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]            ^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/spec_decode/spec_decode_worker.py", line 326, in init_device
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     self.proposer_worker.load_model()
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 182, in load_model
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     self.model_runner.load_model()
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1113, in load_model
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     self.model = get_model(vllm_config=self.vllm_config)
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     return loader.load_model(vllm_config=vllm_config)
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 382, in load_model
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     loaded_weights = model.load_weights(
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]                      ^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_mtp.py", line 182, in load_weights
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]     param = params_dict[name]
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575]             ~~~~~~~~~~~^^^^^^
(RayWorkerWrapper pid=22477, ip=) ERROR 02-18 15:16:57 worker_base.py:575] KeyError: 'transformer.mlp.gate.e_score_correction_bias'
(RayWorkerWrapper pid=115975) MTP module init is completed!!!!!! [repeated 6x across cluster]
(RayWorkerWrapper pid=115975) INFO 02-18 15:13:54 model_runner.py:1116] Loading model weights took 42.5946 GB [repeated 6x across cluster]
(RayWorkerWrapper pid=115975) INFO 02-18 15:13:54 model_runner.py:1111] Starting to load model DeepSeekV3MTP... [repeated 6x across cluster]
(RayWorkerWrapper pid=115975) INFO 02-18 15:13:54 weight_utils.py:251] Using model weights format ['*.safetensors'] [repeated 7x across cluster]
ERROR 02-18 15:17:07 worker_base.py:574] Error executing method 'init_device'. This might cause deadlock in distributed execution.

This is my launch command:

python3 -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 8000 \
    --model /root/models/DeepSeekV3/DeepSeek-R1 \
    --tensor-parallel-size 16 \
    --seed 42 \
    --swap-space 0 \
    --block-size 32 \
    --speculative-model DeepSeekV3MTP \
    --trust-remote-code \
    --num-speculative-tokens 5 \
    --gpu-memory-utilization 0.8 2>&1 | tee log_ONLINE.log

@benchislett
Author


Hi @Neo9061, this PR serves as a reference implementation, so several things are hacked together. One such thing is that there are manual checks for model == 'deepseek-ai/DeepSeek-R1'. Could you set the model name to this, and use an environment variable like HF_HUB_CACHE to redirect the model loading to your desired path?

If that doesn't change things, perhaps inspect the model loading and weight files. Is this a standard download of the DeepSeek-R1 weights?

@Neo9061

Neo9061 commented Feb 19, 2025


Thanks @benchislett for sharing insights. I have done what you suggested: 1) changed the model name to exactly deepseek-ai/DeepSeek-R1 and set the HF_HUB_CACHE env variable; 2) verified the model is exactly the same as the HF DeepSeek-R1. However, I still face the same error.

The KeyError term transformer.mlp.gate.e_score_correction_bias seems to relate to this line of your code. The corresponding weight does exist in model.safetensors.index.json, but is named "model.layers.61.mlp.gate.e_score_correction_bias": "model-00160-of-000163.safetensors".

Update: the error happens at this line, and expert_params_mapping uses the function here. I wonder if we need to extend it to include gate.e_score_correction_bias?

Also, this line never takes effect, as there is no layer name ending with .bias. I wonder if that is intended?

@benchislett
Author

@Neo9061

I am not sure why this is happening, as I am unable to reproduce the issue. Do you have the same issue with #12755? If not, you should be able to use that branch going forward; it was recently merged.

@Neo9061

Neo9061 commented Feb 19, 2025

@Neo9061

I am not sure why this is happening, as I am unable to reproduce this issue. Do you have the same issue with #12755 ? If not, you should be able to use this branch going forward. It was recently merged.

Thanks @benchislett. I haven't tried that recently merged PR, as my top priority is to be able to use more than one speculative token. I am also checking whether that is feasible in #12755.

@benchislett
Author

@Neo9061 please prioritize testing with the already-merged PR. I will help enable k>1 there, similarly to this PR, going forward.

If you find that there are no issues with the other branch, please let me know. Otherwise, it may just be an issue with the multi-node configuration.

@Neo9061

Neo9061 commented Feb 19, 2025


Thanks @benchislett. Will prioritize testing it. I also left a comment in their PR asking this question. Feel free to chime in if you have any thoughts; even a hacky solution is appreciated.

@benchislett
Author

In essence, this PR is an example of one such hacky solution. For a simpler modification, you could try forcing spec_step_idx to always be 0 during inference (see here):

https://github.com/vllm-project/vllm/pull/12755/files#diff-5f6148d0c4c01c76d240579b59a85252cc8edeed6aff5fe59dd26a41e35b893fR131

However, you may encounter the fact that MLA attention is currently (to my knowledge) incompatible with multi-step scheduling, so running k>1 through the TP==1 code path in that PR will likely fail. You will probably also need to modify the code here to forward the hidden states:
https://github.com/vllm-project/vllm/pull/12755/files#diff-3766440dade6e366a77da065708995820749f18b3ecc33febdd7c00f5e928dd5R702

as I have done here: https://github.com/vllm-project/vllm/pull/12915/files#diff-4b4a724a124ddb123bcb688b15678ad862815bb769880f74affbefed82d4354bR99
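
To make the first suggestion concrete, here is a hypothetical sketch of "force spec_step_idx to 0" (class, attribute, and argument names are assumptions for illustration, not taken from that PR):

import torch
from torch import nn

class SingleModuleMTPSketch(nn.Module):
    # Hypothetical: whatever step index the scheduler passes, always dispatch
    # to MTP module 0, since DeepSeek-R1 ships only one MTP module.
    def __init__(self, mtp_module: nn.Module):
        super().__init__()
        self.mtp_layers = nn.ModuleList([mtp_module])

    def forward(self, hidden_states: torch.Tensor,
                spec_step_idx: int = 0) -> torch.Tensor:
        spec_step_idx = 0  # reuse the single module for every draft step
        return self.mtp_layers[spec_step_idx](hidden_states)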

@Neo9061

Neo9061 commented Feb 20, 2025


Hey @benchislett, I finally got your solution working with distributed inference and speculation length k > 1. Because I have to use Docker for the distributed Ray cluster setup, the vLLM installed there is not exactly the base your PR is built on. :)

Thank you so much for the patient reply.

@benchislett
Author

Hi @Neo9061, please see my latest PR. I hope this might unlock better performance for your use case:
#13626

Feedback is welcome.

@Neo9061

Neo9061 commented Feb 20, 2025


Thanks @benchislett! Will look into your new PR. Just wondering, did you see improved performance compared to your current PR?

@benchislett
Author

The performance is the same.
