
Asynchronous tokenization #2879

Merged: 18 commits into vllm-project:main on Mar 15, 2024

Conversation

@Yard1 (Collaborator) commented Feb 14, 2024

Currently, vLLM tokenizes incoming requests synchronously inside the Engine. This has a detrimental effect on serving with AsyncLLMEngine, as tokenization of long prompts will block the event loop, causing both token generation and request handling to slow down, especially in high-QPS scenarios.
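To make the problem concrete, here is a minimal, self-contained sketch (not vLLM code; the tokenizer, timings, and function names are stand-ins) showing how a synchronous tokenization call starves other coroutines, while offloading it to an executor keeps the event loop responsive:

```python
import asyncio
import time

def slow_tokenize(prompt: str) -> list[str]:
    # Stand-in for tokenizing a very long prompt; a whitespace split is
    # only illustrative, not vLLM's tokenizer.
    time.sleep(0.2)
    return prompt.split()

async def heartbeat(counter: list) -> None:
    # Simulates other engine work (scheduling, token generation).
    while True:
        counter[0] += 1
        await asyncio.sleep(0.02)

async def ticks_during_tokenization(offload: bool) -> int:
    counter = [0]
    task = asyncio.create_task(heartbeat(counter))
    await asyncio.sleep(0)  # let the heartbeat start
    before = counter[0]
    if offload:
        # Awaiting the executor future yields control back to the loop,
        # just as awaiting a Ray future does in this PR.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, slow_tokenize, "long prompt")
    else:
        slow_tokenize("long prompt")  # blocks the whole event loop
    ticks = counter[0] - before
    task.cancel()
    return ticks

if __name__ == "__main__":
    blocked = asyncio.run(ticks_during_tokenization(offload=False))
    free = asyncio.run(ticks_during_tokenization(offload=True))
    print(blocked, free)  # blocked is 0; free is roughly 10
```

While the synchronous call runs, the heartbeat coroutine cannot make progress at all; when the same work is awaited on an executor, it keeps ticking throughout.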

This PR introduces an optional Ray-based TokenizerGroupPool that maintains a pool of Ray actors to perform tokenization. Since tokenization now runs in separate processes, the event loop is not blocked (Ray futures can simply be awaited). This removes the bottleneck described above.

Note that detokenization is unchanged, as the serialization/deserialization overheads would be too great there. For tokenization, they are negligible.

@njhill (Member) commented Feb 14, 2024

@Yard1 did you consider using a ThreadPoolExecutor of some size in conjunction with loop.run_in_executor()? I expect this would give similar benefits without the IPC and serialization/deserialization overhead. The tokenization itself should not hold the GIL, since it's in Rust for most tokenizers. I expect the modifications needed would also be much simpler.
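A minimal sketch of this suggestion (names like ThreadedTokenizer and encode_async are illustrative, not vLLM's API; the encode body is a stand-in for a Hugging Face fast tokenizer, whose Rust core releases the GIL during tokenization):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

class ThreadedTokenizer:
    """Runs tokenization on a small thread pool so the event loop stays free."""

    def __init__(self, num_threads: int = 4) -> None:
        self._executor = ThreadPoolExecutor(max_workers=num_threads)

    def encode(self, prompt: str) -> list[int]:
        # Stand-in for a real tokenizer's encode; real fast tokenizers
        # release the GIL here, so pool threads run in parallel.
        return [ord(c) for c in prompt]

    async def encode_async(self, prompt: str) -> list[int]:
        # loop.run_in_executor returns an awaitable future, so awaiting it
        # never blocks the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, self.encode, prompt)

if __name__ == "__main__":
    print(asyncio.run(ThreadedTokenizer().encode_async("hi")))  # [104, 105]
```

Compared with the Ray-actor approach, this keeps everything in-process, so there is no IPC or serialization cost, at the price of relying on the tokenizer releasing the GIL.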

@Yard1 (Collaborator, Author) commented Feb 19, 2024

@njhill I agree that a thread-based solution should work in principle for the most popular models; it would be good to confirm that, though. Would you be interested in trying it out using the API here?

@njhill (Member) commented Feb 19, 2024

@Yard1 sure!

@njhill (Member) commented Mar 5, 2024

@Yard1 sorry for the delay with this, I've now opened #3206.

@zhuohan123 zhuohan123 self-assigned this Mar 6, 2024
@cadedaniel cadedaniel self-assigned this Mar 13, 2024
Review comments (now resolved) were left on: vllm/engine/llm_engine.py, vllm/transformers_utils/tokenizer.py, vllm/engine/arg_utils.py, tests/async_engine/test_api_server.py
@Yard1 Yard1 requested review from cadedaniel and njhill March 13, 2024 23:16
@Yard1 (Collaborator, Author) commented Mar 13, 2024

@njhill @cadedaniel updated, ptal

@njhill (Member) left a comment:

Thanks @Yard1. I can rebase #3206 when this is ready, would be good to get your thoughts on that too!

@cadedaniel (Collaborator) left a comment:

LGTM, one design comment and some nits

Review comments (now resolved) were left on: vllm/config.py, vllm/engine/arg_utils.py, vllm/transformers_utils/tokenizer.py
@Yard1 Yard1 requested review from cadedaniel and njhill March 14, 2024 17:32
@cadedaniel (Collaborator) left a comment:

Small nits, looks great!

@simon-mo simon-mo self-assigned this Mar 15, 2024
@simon-mo (Collaborator) left a comment:

some nits around code style

Comment on lines +58 to +59:

    if not lora_request or not self.enable_lora:
        return self.tokenizer
Collaborator: Readability-wise, it would be helpful to move this check up into the corresponding encode function.

Yard1 (Collaborator, Author): Those are not private methods; I think it makes sense to keep this logic here, as it's relevant.

Comment on lines +72 to +73:

    if not lora_request or not self.enable_lora:
        return self.tokenizer
Collaborator: Same as above.

Comment on lines +36 to +48:

    @abstractmethod
    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Get a tokenizer for a LoRA request."""
        pass

    @abstractmethod
    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Get a tokenizer for a LoRA request."""
        pass
Collaborator: Are these methods called externally at all? If not, I would not put them in the base class.

Yard1 (Collaborator, Author): They are called in LLMEngine.

@simon-mo simon-mo enabled auto-merge (squash) March 15, 2024 21:54
@simon-mo (Collaborator):

If you can address the code style nits, that would be great. Auto-merge is enabled; once tests pass (they should if you merge main), it will be merged.

@njhill (Member) left a comment:

LGTM

@simon-mo simon-mo merged commit fb96c1e into vllm-project:main Mar 15, 2024
23 of 24 checks passed
@Yard1 Yard1 deleted the async_tokenization branch March 15, 2024 23:41
@flexwang commented:

Very cool! Any benchmarks for the improvement?

njhill added a commit to njhill/vllm that referenced this pull request Mar 19, 2024
vllm-project#2879 added support for using ray to offload tokenization from the asyncio event loop.

This PR extends that to support using a thread pool instead of Ray, and makes that the default, with the default pool size determined from the number of available CPU cores and the tensor parallel size.

The main thing to note is that separate tokenizer instances are used per thread. This is because officially the HF tokenizers are not thread-safe. In practice I think they are unless you're making use of padding/truncation, which we aren't currently but may want to soon (see for example vllm-project#3144).

Also includes some type hint additions to related parts of the code.

This replaces the original PR vllm-project#3206 from before vllm-project#2879 was reworked and merged.
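The per-thread tokenizer instances mentioned in the commit message can be sketched with threading.local(); the make_tokenizer factory below is a stand-in for the real tokenizer constructor (e.g. AutoTokenizer.from_pretrained), an assumption for illustration only:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Each pool thread lazily builds and caches its own tokenizer, since
# HF tokenizers are not officially thread-safe.
_local = threading.local()

def make_tokenizer() -> dict:
    # Stand-in for the real tokenizer constructor; here just a dict
    # recording which thread owns it.
    return {"owner": threading.get_ident()}

def get_thread_tokenizer() -> dict:
    tok = getattr(_local, "tokenizer", None)
    if tok is None:
        tok = make_tokenizer()  # built at most once per thread
        _local.tokenizer = tok
    return tok

if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=2) as pool:
        ids = set(pool.map(lambda _: id(get_thread_tokenizer()), range(8)))
    # At most one tokenizer instance exists per worker thread.
    print(len(ids) <= 2)  # True
```

This trades a little memory (one tokenizer copy per thread) for safety, which matters once padding/truncation state comes into play as noted above.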
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
6 participants