Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] move parallel sampling out from vllm core #9302

Merged
merged 17 commits into from
Oct 22, 2024

Conversation

youkaichao
Copy link
Member

try to hide seq group from the core, by handling parallel sampling in llm engine.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Copy link
Member Author

youkaichao commented Oct 12, 2024

the caveat is that we cannot support streaming when n > 1 . I think people don't use streaming when n > 1, and it is not clearly defined.

say we have n = 5, and the first stream gives 5 tokens, and then sequence 2 finish, do we send 5 outputs with the 2nd as empty? or send 4 outputs and let users mantain the status?

the openai api behavior is:

every sequence in parallel sampling will be assigned a unique index, and then the stream is flattened, one token at a time. it does not need to have n tokens at a time.

this is the test script:

from openai import OpenAI
api_key = ''
client = OpenAI(
    api_key=api_key,
)

stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Repeat after me: apple."}],
    stream=True,
    max_tokens=5,
    n=1,
)
for chunk in stream:
    print(chunk)

and the output:

ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None, refusal=None), finish_reason=None, index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content='Apple', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content='.', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')

when I use n=2:

ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None, refusal=None), finish_reason=None, index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='Apple', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None, refusal=None), finish_reason=None, index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='Apple', function_call=None, role=None, tool_calls=None), finish_reason=None, index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='.', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='.', function_call=None, role=None, tool_calls=None), finish_reason=None, index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')

I get two tokens from sequence 0 at first, and then two tokens from sequence 1.

@youkaichao youkaichao marked this pull request as draft October 12, 2024 03:00
@youkaichao youkaichao marked this pull request as ready for review October 12, 2024 05:35
@youkaichao
Copy link
Member Author

youkaichao commented Oct 12, 2024

@robertgshaw2-neuralmagic can you help take a look? I met a strange error:

pytest -v -s tests/entrypoints/openai/test_completion.py::test_guided_json_completion[-outlines]

will fail in this implementation. lm-format-enforcer works well, and --disable-frontend-multiprocessing also works. only the combination of the mqllmengine and outlines does not work.

it is surprising that ci actually passes ... it errors in my local dev machine.

@youkaichao youkaichao added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 12, 2024
@youkaichao youkaichao changed the title [draft] try to remove seq group inside core [core] try to remove seq group from core Oct 12, 2024
vllm/engine/llm_engine.py Outdated Show resolved Hide resolved
vllm/outputs.py Outdated Show resolved Hide resolved
@afeldman-nm
Copy link
Contributor

afeldman-nm commented Oct 17, 2024

@youkaichao how would this PR impact best_of > 1 requests? Is best_of functionality still within the engine, or is it moved outside the engine as has been done for beam search? @robertgshaw2-neuralmagic @njhill

@youkaichao
Copy link
Member Author

@afeldman-nm best_of > 1 is already converted to parallel sampling in #9261

@youkaichao youkaichao changed the title [core] try to remove seq group from core [core] move parallel sampling out from vllm core Oct 21, 2024
@youkaichao youkaichao enabled auto-merge (squash) October 21, 2024 23:40
@youkaichao youkaichao merged commit 76a5e13 into vllm-project:main Oct 22, 2024
60 checks passed
@youkaichao youkaichao deleted the rm_n branch October 22, 2024 00:34
charlifu pushed a commit to charlifu/vllm that referenced this pull request Oct 23, 2024
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Oct 23, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
@rand-fly
Copy link

Could you explain the benefit of doing so? It seems that with this change, the scheduler can no longer make decisions based on the number of sequences within a SequenceGroup.

MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
Signed-off-by: Erkin Sagiroglu <erkin@infra-aipipeline-1-at1-prox-prod-a.ipa.corp.telnyx.com>
@youkaichao
Copy link
Member Author

Could you explain the benefit of doing so? It seems that with this change, the scheduler can no longer make decisions based on the number of sequences within a SequenceGroup.

yes, the scheduler will only process single sequence in the future, to make the core code simple.

garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
@rand-fly
Copy link

This modification makes the "fork" mechanism of vLLM completely unused. Previously, for a request with n > 1, its prompt was prefilled only once, and then the sequence was "forked" into n sequences to avoid redundant computation. After this modification, a request with n > 1 has to prefill its prompt n times.
A small experiment code can be used to verify this. Notice how it has become much slower after this squashed commit.

from vllm import LLM, SamplingParams
import time

# Sample prompts.
prompts = [
    "Once upon a time, there was a king.",
]
# Create a sampling params object.
sampling_params = SamplingParams(seed=42, temperature=0.1, max_tokens=1, n=100)

# Create an LLM.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")

# warm up
outputs = llm.generate(prompts, sampling_params) 

begin_time = time.time()
outputs = llm.generate(prompts, sampling_params)
end_time = time.time()

print(f"{end_time - begin_time}s")

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

@youkaichao
Copy link
Member Author

This modification makes the "fork" mechanism of vLLM completely unused. Previously, for a request with n > 1, its prompt was prefilled only once, and then the sequence was "forked" into n sequences to avoid redundant computation. After this modification, a request with n > 1 has to prefill its prompt n times.

Yes, this is intended. Please use prefix caching to speed up and share the prefill. All the sharing will not be hardcoded in the scheduler, and will only happen through prefix caching.

I'm not sure if prefix caching currently supports sharing in the same batch. If you want optimal performance, I would suggest running a n=1 request first, and then run another n=n-1 request.

FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
Signed-off-by: qishuai <ferdinandzhong@gmail.com>
@rand-fly
Copy link

This modification makes the "fork" mechanism of vLLM completely unused. Previously, for a request with n > 1, its prompt was prefilled only once, and then the sequence was "forked" into n sequences to avoid redundant computation. After this modification, a request with n > 1 has to prefill its prompt n times.

Yes, this is intended. Please use prefix caching to speed up and share the prefill. All the sharing will not be hardcoded in the scheduler, and will only happen through prefix caching.

I'm not sure if prefix caching currently supports sharing in the same batch. If you want optimal performance, I would suggest running a n=1 request first, and then run another n=n-1 request.

Thank you for clarifying. Prefix caching does support sharing in the same batch, though the performance gain is not as much as using the "fork" mechanism.

sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants