Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] [Spec decode]: Combine chunked prefill with speculative decoding #9291

Open
wants to merge 36 commits into
base: main
Choose a base branch
from

Conversation

NickLucche
Copy link
Contributor

@NickLucche NickLucche commented Oct 11, 2024

Hey, this PR implements #5016.

The main idea is to make use of the current Speculative Decoder workflow and integrate it with mixed prefill-decode batches.
In particular, we can run the batched prefills and decodes together through the scorer (with the usual prefill|decode layout supported by backend), while the proposer can sync its KV cache on prefills only.

image

Current attention kernel implementation still doesn't make full use of the prefill|decode, but once the MQA integration is finalized we can get an easy speedup by running the batch in a single forward.

Current implementation on main already is (to some extent) prefill aware, so I was able to re-use a good chunk of the logic and the changes aren't (purposely) drastic.
On the other hand, one could prioritize optimizations more and I am open to any suggestion on how to best implement the approach, even at the cost
of re-writing more parts and making the PR more invasive (ie breaking some of the interfaces to avoid duplication).

TODO:

  • benchmark on A/H100
  • expand test coverage with prefill chunking enabled
  • test with new mqa_scorer, current implementation was rebased from v0.6.2
  • fix speculative methods requiring return_hidden_states EDIT: on second thought, I believe this would be better addressed in a separate PR
  • disable_logprobs_during_spec_decoding compatibility

Update:

We add support for chunk prefill and spec decoding with the workflow depicted above, unless the proposer requires final hidden state from the target model (MLPSpeculator/Medusa): this will require supporting chunked hidden states too as input x is now split into blocks x1|x2..|xn, so this definitely needs its own PR if we want to include it.

mqa_scorer is set to supersede BatchExpansion* thanks to the great work by @LiuXiaoxuanPKU, so we add support to that scorer directly in this PR!
Incidentally, this means enabling backend with flash_attn_varlen_func to take in any "mixed prefill-decode batch" in a single kernel call (so no more decoupled prefix-decode calls), which should also boost performance in "vanilla" chunked prefill scheduling policy (no spec).

Many thanks to @sroy745 for benchmarking the BatchExpansionTop1Scorer approach here (MQA to follow)!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@NickLucche NickLucche marked this pull request as draft October 11, 2024 17:04
Copy link
Contributor

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pr. Left some comments. PTAL

# TODO skip this if chunking is not enabled
if len(non_spec_indices):
all_hidden_states = proposal_scores.hidden_states
# TODO fix `return_hidden_states`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you clarify more on this TODO about return_hidden_states?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here you have a hidden state entry even for non-terminal chunks, while the LogitsProcessor only selects and returns the indices that needs sampling; hence we need to use the indices prior to filtering based on do_sample to get the right hidden states

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw I am planning to have that case covered too

(seq_id, seq_data) for sg in \
execute_model_req.seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items()
)
if sg.do_sample # ignore empty token sequences
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this going to change the order of entries in seq_data_entries and seq_output_prompt_logprobs ? In the loop in L542 and L543 can we use the same value of output_index to access the seq_data_entries and seq_output_prompt_logprobs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

relative order won't change, input is guaranteed to be prefill|decodes, so you have something like

seq1: chunk to sample | chunk no sample | chunk to sample | decode |... | decode
filtered to
seq2; chunk to sample | chunk to sample | decode |... | decode
so seq2 is a subset of seq1.

We used to iterate on filtered sequences seq2 (we had no chunks), now we iterate on seq1 to account for empty outputs and keep the old index as output_index, (only increment on seq1 elements) so the order is maintained

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please resolve the above comment if its not applicable?

vllm/spec_decode/batch_expansion.py Outdated Show resolved Hide resolved
vllm/spec_decode/batch_expansion.py Outdated Show resolved Hide resolved
vllm/config.py Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Show resolved Hide resolved
vllm/worker/model_runner.py Outdated Show resolved Hide resolved
@arashsadrieh
Copy link

arashsadrieh commented Oct 15, 2024

@NickLucche Thanks for the great work and understand that is WIP, just small note while you are working on this piece

We tried this PR with tensor parallelism and we found that it throughs the following exception when we activate tensor parallelism:

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8083 --model /8b/  --speculative_model /1b/  --served-model-name SpeculativeLLM --tensor-parallel-size 4  --max-model-len 34336  --max-num-seqs 128  --enable-prefix-caching  --disable-log-requests --use-v2-block-manager --seed 42 --num_speculative_tokens 5  --spec-decoding-acceptance-method typical_acceptance_sampler  --enable_chunked_prefill

Here is the exception:

Exception in worker VllmWorkerProcess while processing method start_worker_execution_loop: 'num_seq_groups', Traceback (most recent call last):
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/executor/multiproc_worker_utils.py", line 224, in _run_worker_process
     output = executor(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/opt/conda/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
     return func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/spec_decode/spec_decode_worker.py", line 459, in start_worker_execution_loop
     while self._run_non_driver_rank():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/spec_decode/spec_decode_worker.py", line 649, in _run_non_driver_rank
     self.proposer_worker.execute_model()
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 308, in execute_model
     inputs = self.prepare_input(execute_model_req)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 298, in prepare_input
     return self._get_worker_input_from_broadcast()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 240, in _get_worker_input_from_broadcast
     worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/home/ec2-user/tengfei_workspace/vllm/vllm/worker/worker_base.py", line 151, in from_broadcasted_tensor_dict
     num_seq_groups=tensor_dict.pop("num_seq_groups"),
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 KeyError: 'num_seq_groups'

The following command works normally

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8083 --model /home/ec2-user/tengfei_workspace/output/8b-aio-20240923-3/merged/ --speculative_model /home/ec2-user/tengfei_workspace/output/1b-aio-20240923-3/merged/ --served-model-name SpeculativeLLM --tensor-parallel-size 1 --max-model-len 34336 --max-num-seqs 128 --enable-prefix-caching --disable-log-requests --use-v2-block-manager --seed 42 --num_speculative_tokens 5  --spec-decoding-acceptance-method typical_acceptance_sampler --enable_chunked_prefill --tensor-parallel-size 1

Thanks again and appreciate your work/ VLLM community

@NickLucche
Copy link
Contributor Author

NickLucche commented Oct 15, 2024

Thanks for testing that, will look right into it!
Might actually be related to prefix_caching, which I haven't taken into account yet (I know there's been some recent work on that too).

@NickLucche
Copy link
Contributor Author

Update on mqa_scorer integration: enable_chunked_prefill with changes in this PR appears to work fine with the flash_attn kernel prior to the optimized one introduced here #9298 (so flash_attn_with_kvcache instead of flash_attn_varlen_func). I will sync with @LiuXiaoxuanPKU on this.

@NickLucche NickLucche marked this pull request as ready for review October 17, 2024 15:45
vllm/config.py Outdated Show resolved Hide resolved
if (decode_meta and prefill_meta
and (pq := prefill_meta.query_start_loc)
and (dq := decode_meta.query_start_loc)):
combined_loc = torch.cat([pq, dq[1:]], axis=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why 1 here?
Also curious is attention_meta.query_start_loc == combined_loc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah you're right these methods are useless, will remove them, thanks!

Copy link
Contributor

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pr. Left a few comments. PTAL.

vllm/attention/backends/flash_attn.py Outdated Show resolved Hide resolved
vllm/config.py Outdated
if disable_logprobs is not None and enable_chunked_prefill:
raise ValueError("Chunked prefill and"
"`disable-logprobs-during-spec-decoding` are "
"not yet compatible.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - not yet compatible -> not compatible. Same comment for L1285

vllm/config.py Outdated
"Speculative decoding and chunked prefill are "
f"currently mutually exclusive ({enable_chunked_prefill=}).")

if disable_logprobs is not None and enable_chunked_prefill:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need to check if disable_logprobs evaluates to true or false? My understanding is that if we just specify --disable-logprobs-during-spec-decoding this variable will be set to True. In that case do we want to check the value of disable_logprobs in addition to it not being None?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these not compatible?

prompt_logprobs = [
create_logprobs_output(
token_id=p_token_id,
output_index = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these changes needed only when disable_logprobs is True and enable_chunked_prefill is true? Currently are we allowing both to be set to true. If not are these changes needed?

(seq_id, seq_data) for sg in \
execute_model_req.seq_group_metadata_list \
for seq_id, seq_data in sg.seq_data.items()
)
if sg.do_sample # ignore empty token sequences
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please resolve the above comment if its not applicable?

create_logprobs_output(
token_id=p_token_id,
output_index = 0
# Make sure the even prefill chunks are still aligned with their own
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - consider rewording to Make sure the even prefill chunks -> Make sure the non-terminal prefill chunks are still aligned with ...

token_id=p_token_id,
output_index = 0
# Make sure the even prefill chunks are still aligned with their own
# empty output. One single samplerout to avoid
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate on the " One single samplerout to avoid" comment?

tests/spec_decode/test_spec_decode_worker.py Outdated Show resolved Hide resolved
@@ -21,6 +21,11 @@ def score_proposals(
all_proposal_lengths = proposals.proposal_lens.tolist()
for i, seq_group_metadata in enumerate(
execute_model_req.seq_group_metadata_list):
if all_proposal_lengths[i] == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please consider adding a test similar to this https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/tests/spec_decode/test_scorer.py?L49 with the request containing both prefill and decodes.

tests/spec_decode/test_spec_decode_worker.py Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants