[Core] Consolidate prompt arguments to LLM engines #4328

Merged · 57 commits merged into vllm-project:main from llm-inputs on May 28, 2024

Conversation

DarkLight1337 (Member) commented Apr 24, 2024:

Currently, LLM.generate (and the corresponding methods in LLMEngine and AsyncLLMEngine) accepts prompt, prompt_token_ids and multi_modal_data as separate arguments. This PR consolidates them into PromptInputs so that only a single argument has to be passed in, using type annotations to enforce a consistent format. This reduces the chance of users accidentally passing in mismatched lengths of prompt, prompt_token_ids and multi_modal_data (the related checks have been removed to avoid redundant code). On the other hand, sampling_params remains untouched because it is common to pass only a single instance even for multiple prompts.

This would also make it easier to define the interface for processing the inputs using HuggingFace processor, as mentioned in #4194.
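
For illustration, the consolidated input can be thought of as either a plain string or a dictionary keyed by the names used in the examples below. This is a minimal sketch with a hypothetical class name; the exact definitions in vllm/inputs.py may differ:

    from typing import List, TypedDict, Union

    class PromptInputsDict(TypedDict, total=False):
        prompt: str                  # text to be tokenized
        prompt_token_ids: List[int]  # pre-tokenized prompt
        multi_modal_data: object     # e.g. a MultiModalData instance

    # A bare string is still accepted and treated as a text prompt.
    PromptInputs = Union[str, PromptInputsDict]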

API changes

The APIs of LLM.generate and LLM.encode have been updated: the parameters prompt, prompt_token_ids and multi_modal_data are replaced with a single inputs parameter. For now, the old API is still maintained, but it may be deprecated in a future release. Users may update their code as follows:

Single prompt:

# No change required since the prompt is passed positionally
llm.generate("Hello, my name is")

- llm.generate(prompt="Hello, my name is")
+ llm.generate("Hello, my name is")

- llm.generate(prompt_token_ids=[1, 2, 3])
+ llm.generate({"prompt_token_ids": [1, 2, 3]})

# image is a tensor in NCHW format where N=1
- llm.generate("Hello, my name is", multi_modal_data=MultiModalData(type=..., data=image))
+ llm.generate({"prompt": "Hello, my name is", "multi_modal_data": MultiModalData(type=..., data=image)})

Multiple prompts:

# No change required since the prompts are passed positionally
llm.generate(["Hello, my name is", "The future of AI is"])

- llm.generate(prompt=["Hello, my name is", "The future of AI is"])
+ llm.generate(["Hello, my name is", "The future of AI is"])

- llm.generate(prompt_token_ids=[[1, 2, 3], [4, 5, 6]])
+ llm.generate([{"prompt_token_ids": [1, 2, 3]}, {"prompt_token_ids": [4, 5, 6]}])

# images is a tensor in NCHW format where N=len(prompts)
- prompts = ["Hello, my name is", "The future of AI is"]
- llm.generate(prompts, multi_modal_data=MultiModalData(type=..., data=images))
+ llm.generate([
+    {"prompt": prompt, "multi_modal_data": MultiModalData(type=..., data=images[i:i+1])}
+    for i, prompt in enumerate(prompts)
+ ])

Based on the examples in the documentation, most users should already prefer the first way of calling LLM.generate; those users need not make any changes.

Other changes

By setting gpu_memory_utilization to a smaller value, we can now run multiple LLM instances and OpenAI servers at the same time. To better manage GPU utilization between entrypoints tests, I have grouped the tests using pytest markers so we no longer have to refer to the file name directly. This makes it easier to edit the commands used to test multiple files in the same pytest session.
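
As a rough sketch of the marker-based grouping (the marker names and the test below are illustrative, not necessarily what the PR adds): markers are registered once, tests are tagged, and CI selects a group with pytest -m instead of listing file names.

    import pytest

    # In conftest.py: register the markers so pytest does not warn about them.
    def pytest_configure(config):
        config.addinivalue_line("markers", "openai: tests that launch an OpenAI-compatible server")
        config.addinivalue_line("markers", "llm: tests that build an offline LLM instance")

    # In a test file: tag tests so CI can run `pytest -m openai tests/entrypoints`.
    @pytest.mark.openai
    def test_chat_completion_smoke():
        ...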

DarkLight1337 changed the title from "[Core] Combine prompt inputs to LLM engines" to "[Core] Combine prompt arguments to LLM engines" on Apr 24, 2024
DarkLight1337 changed the title from "[Core] Combine prompt arguments to LLM engines" to "[Core] Consolidate prompt arguments to LLM engines" on Apr 25, 2024
DarkLight1337 (Member, Author) commented:

@ywang96 Any thoughts about this?

ywang96 (Member) commented Apr 29, 2024:

Hey @DarkLight1337! Sorry I've been a bit busy lately, but I will surely take a look in the upcoming week! Apologies for the delay!

ywang96 (Member) left a review comment:

Thank you very much for the contribution @DarkLight1337 and sorry for the delayed review. I left a few comments and questions. Please take a look!

(Resolved review threads: vllm/entrypoints/llm.py, vllm/engine/llm_engine.py ×4)
On this code:

    # Create the sequence group.
    seq_group = SequenceGroup(request_id, [seq], sampling_params,
                              arrival_time, lora_request, multi_modal_data)
    processed_inputs = self.encode_request(request_id=request_id,

Member:
Right before we call encode_request is where I think we could combine prompt, prompt_token_ids and multi_modal_data into PromptInputs.

Member Author:

I think we should do this directly on the API class (LLM) so that we do not have to involve the internals when finally deprecating the old API.
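
A rough sketch of that approach (hypothetical helper name, simplified signature): the LLM entrypoint maps the legacy keyword arguments onto the new inputs format before anything reaches the engine internals.

    from typing import List, Optional

    def _convert_v1_inputs(prompt: Optional[str] = None,
                           prompt_token_ids: Optional[List[int]] = None,
                           multi_modal_data: Optional[object] = None) -> dict:
        """Map the old keyword arguments onto a single inputs dict (sketch)."""
        inputs: dict = {}
        if prompt is not None:
            inputs["prompt"] = prompt
        if prompt_token_ids is not None:
            inputs["prompt_token_ids"] = prompt_token_ids
        if multi_modal_data is not None:
            inputs["multi_modal_data"] = multi_modal_data
        return inputs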

(Resolved review threads: vllm/inputs.py, vllm/engine/llm_engine.py ×2)
- To facilitate equality tests, `CompletionOutput` is now a dataclass
DarkLight1337 force-pushed the llm-inputs branch 6 times, most recently from 0c57f28 to 37baf38 on May 3, 2024
DarkLight1337 (Member, Author) commented:

> I think we should keep the prompts/prompt_token_ids inputs for backward compatibility?

These arguments are still being maintained. I just decorated the relevant methods so we can deprecate them simply by turning on the corresponding flag.
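
The decorator itself is not shown in this thread; a minimal sketch of the idea (hypothetical names), which warns about the legacy keywords only while a deprecation flag is enabled, might look like:

    import functools
    import warnings

    def deprecate_kwargs(*names, is_deprecated=lambda: True):
        """Warn when any of the named keyword arguments is passed (sketch)."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                used = sorted(set(names) & kwargs.keys())
                if used and is_deprecated():
                    warnings.warn(
                        f"{used} are deprecated; please pass 'inputs' instead.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                return fn(*args, **kwargs)
            return wrapper
        return decorator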

DarkLight1337 (Member, Author) commented:

Pretty sure the CI will pass now.

njhill (Member) left a review comment:

Thanks @DarkLight1337, I added some comments inline

On this code:

    if isinstance(inputs, str):
        inputs = {"prompt": inputs}

    if "prompt_token_ids" not in inputs:
Member:

nit: better to do

prompt_token_ids = inputs.get("prompt_token_ids")
if prompt_token_ids is None:
   # ...

Member Author:

Doing so actually fails the type checker (I'm using Pyright on VSCode), so I'm reluctant to change it to this form.
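
For context, a sketch of the pattern in question (types simplified; the real definitions may differ). With a total=False TypedDict, Pyright narrows the key to "present" after a membership check, whereas .get() is typed as returning an Optional value that then needs an extra check or annotation:

    from typing import Callable, List, TypedDict

    class TokensPrompt(TypedDict, total=False):
        prompt: str
        prompt_token_ids: List[int]

    def resolve_token_ids(inputs: TokensPrompt, text: str,
                          tokenize: Callable[[str], List[int]]) -> List[int]:
        # Pyright narrows the TypedDict on the membership check, so the
        # subscript below type-checks as List[int].
        if "prompt_token_ids" not in inputs:
            return tokenize(text)
        return inputs["prompt_token_ids"]

    # By contrast, inputs.get("prompt_token_ids") is typed as
    # "List[int] | None", which then needs a None check or an annotation/cast.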

Member:

What does it complain about? Couldn't additional hints be added for that e.g.

prompt_token_ids: List[int] = inputs.get("prompt_token_ids")

On this code:

    if isinstance(inputs, str):
        inputs = {"prompt": inputs}

    if "prompt_token_ids" not in inputs:
Member:

suggest

prompt_token_ids = inputs.get("prompt_token_ids")
if prompt_token_ids is not None:
    # ...

Member Author:

As above

On this code:

        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> AsyncStream:
        if self.log_requests:
Member:
I guess this is not related to this PR, but I feel like the body of this if should go into a separate function.

DarkLight1337 (Member, Author) commented May 25, 2024:

This bothers me as well since it's otherwise redundant to pass both prompt and prompt_token_ids at the same time. Perhaps it would be better to move this to OpenAI server in another PR?

Member:

Sure

(Resolved review thread: vllm/engine/llm_engine.py)
On this code:

        eos_token_id = self.tokenizer.get_lora_tokenizer(
            lora_request).eos_token_id
    else:
        logger.warning("Use None for EOS token id because tokenizer is "
Member:
I'm not sure that we want to log a warning on every request. This will be a lot of log output for the no-tokenizer cases.

DarkLight1337 (Member, Author) commented May 25, 2024:

Actually, this came from #3748; this PR only moved the code from add_request into _add_processed_request.

@ywang96 Since you reviewed that PR, can you explain the reason behind this? Is it still necessary?

Member:

Hmm - this is probably an oversight from me. I thought this warning was added at engine initialization time, not inference time. IMO we should move it to initialization time for sure.

DarkLight1337 (Member, Author) commented May 25, 2024:

It may not be a good idea to warn during initialization time if the user indeed intends the tokenizer to not be used.

I have updated the code so that the warning still occurs on request but is only emitted at most once during the lifetime of the engine.

Member:

> It may not be a good idea to warn during initialization time if the user indeed intends the tokenizer to not be used.

I'm not sure that I understand this; I think a single warning is OK in this case since it's only a warning. And a single warning will still be logged with this latest change, right? Just at the time of the first request rather than at LLM engine construction.

DarkLight1337 (Member, Author) commented May 25, 2024:

> It may not be a good idea to warn during initialization time if the user indeed intends the tokenizer to not be used.

> I'm not sure that I understand this; I think a single warning is OK in this case since it's only a warning. And a single warning will still be logged with this latest change, right? Just at the time of the first request rather than at LLM engine construction.

If the user didn't intend to use the tokenizer, then the warning might cause some confusion during startup. True that it's not a big deal either way though.

Member:

I personally vote for moving this to engine initialization time since that will make our codebase cleaner. (no need for another attribute of engine just for warning count)

DarkLight1337 (Member, Author) commented May 26, 2024:

Upon inspection of the docs, it seems that manually suppressing repeated warnings is actually unnecessary:

> Warning messages are normally written to sys.stderr, but their disposition can be changed flexibly, from ignoring all warnings to turning them into exceptions. The disposition of warnings can vary based on the warning category, the text of the warning message, and the source location where it is issued. Repetitions of a particular warning for the same source location are typically suppressed.

Therefore, we can keep the warning in the request-handling code.
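
A small illustration of the behavior the quoted docs describe, assuming the warning is emitted through Python's warnings module (a plain logging call would not get this deduplication):

    import warnings

    def eos_token_id_or_none(tokenizer):
        if tokenizer is None:
            # Under the default filters, repeats of the same warning from the
            # same source location are suppressed after the first emission.
            warnings.warn("Using None for EOS token id because tokenizer is not initialized")
            return None
        return tokenizer.eos_token_id

    eos_token_id_or_none(None)  # warning printed
    eos_token_id_or_none(None)  # suppressed on repeat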

(Resolved review threads: vllm/entrypoints/llm.py, vllm/entrypoints/openai/serving_chat.py, vllm/engine/llm_engine.py ×2)
DarkLight1337 (Member, Author) commented:

@njhill I have responded to your comments.

ywang96 (Member) commented May 25, 2024:

@DarkLight1337 I just went through this PR again and made a change to move the offline API reference under the developer doc. #4710 was a great addition, but I think we should link from the examples to the developer doc instead of putting the API reference there directly.

ywang96 (Member) left a review comment:

I'm going to give this PR a greenlight, and thank you @DarkLight1337 for working on this PR and addressing the comments from reviewers.

@njhill @rkooo567 If you have any other concerns, feel free to leave a comment; otherwise I'd appreciate an approval from either of you as well. Thanks!

(P.S. @DarkLight1337, in the future it would be nice to break a big PR like this down further into smaller PRs; that way it's easier for us to review and get things merged progressively.)

njhill (Member) left a review comment:

Thanks again @DarkLight1337 @ywang96!

ywang96 merged commit 5ae5ed1 into vllm-project:main on May 28, 2024
63 checks passed
DarkLight1337 deleted the llm-inputs branch on May 29, 2024