Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/n] Triton sampling kernel #3186

Merged
merged 20 commits into from
Mar 20, 2024
Merged

[1/n] Triton sampling kernel #3186

merged 20 commits into from
Mar 20, 2024

Conversation

Yard1
Copy link
Collaborator

@Yard1 Yard1 commented Mar 4, 2024

This PR is the first one in a series of PRs.

This PR adds a custom triton sampling kernel, giving us the following benefits:

  • sampling from both greedy and random sequences in the same kernel
  • batched deterministic sampling with per-sequence seeds
  • potentially fusing other operations like logprob gather

Currently the codepath using the triton kernel is disabled due to the following issues:

  • Triton JIT has a large kernel launch overhead which is noticeable for small models. Potential solution would be to compile the kernels ahead of time - we have a pipeline for that internally
  • We need to call the kernel multiple times for models with very large vocabulary (eg. gemma). This should be possible to solve in the kernel itself.
  • The sampling code in general is unoptimized and adds overhead on top of the kernel. It is non-trivial to simplify it due to the beam search code. Next PR will try to separate out the beam search sampling code from the rest of the sampling.

tests/kernels/test_sampler.py Outdated Show resolved Hide resolved
@Yard1 Yard1 changed the title [WIP] Triton sampling kernel [1/n] Triton sampling kernel Mar 5, 2024
@Yard1 Yard1 marked this pull request as ready for review March 5, 2024 22:44
@Yard1 Yard1 requested a review from simon-mo March 5, 2024 22:56
Copy link
Member

@ywang96 ywang96 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few comments & questions and hope you don't mind them!

vllm/model_executor/layers/sampler.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/sampler.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/triton_kernel/rand.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/triton_kernel/sample.py Outdated Show resolved Hide resolved
vllm/model_executor/sampling_metadata.py Outdated Show resolved Hide resolved
vllm/model_executor/sampling_metadata.py Outdated Show resolved Hide resolved
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Yard1!


_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558


class SamplingMetadata:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we skip any of the new operations in this class in the case that no seeds are in use? (which I expect would be very common).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question - I think there are three considerations:

  • Skipping seeds could bring a little better performance.
  • Skipping seeds introduces more special cases (undesirable).
  • Not skipping seeds allows for request-level reproducibility on the server side, which could be useful for debugging model behavior.

Aside from those, triton random operations require some sort of a seed, so generating one would be necessary regardless.

"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
if not is_greedy:
if seed is None:
randint_fn = random.randint
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's effectively no overhead of seeded vs non-seeded random sampling, a nice feature would be to treat random.randint here equivalent to a passed-in seed, and then always return this seed in the API response.

This allows users to use the returned seed to reproduce the same output, if it happened to be something they particularly liked for example (without them having to provide a seed explicitly up-front).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree! That's one of the advantages of always generating the seed. I think it would be good to include it in a followup (ideally once we are using just the kernel so the logic is consistent).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it would be costly to do this in the non-kernel case.

@njhill
Copy link
Member

njhill commented Mar 19, 2024

  • sampling from both greedy and random sequences in the same kernel
  • batched deterministic sampling with per-sequence seeds
  • potentially fusing other operations like logprob gather

These are only benefits if they translate to non-negligible end-to-end performance improvements right? Curious what the speedup looks like as a proportion of total TPOT? I guess it depends on the mix of parameters and in particular if there are many seeded requests (presumably uncommon) and/or mix of greed, random, seeded random in the same batch (presumably more common).

I guess this question might be more important here given the nontrivial amount of new code introduced for this specific optimization.

  • The sampling code in general is unoptimized and adds overhead on top of the kernel. It is non-trivial to simplify it due to the beam search code. Next PR will try to separate out the beam search sampling code from the rest of the sampling.

Would these optimizations be applicable whether or not the dedicated kernel is used?

Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamp. please address @njhill's comment before merge.

@Yard1
Copy link
Collaborator Author

Yard1 commented Mar 19, 2024

@njhill We are seeing ~10% reduction in sampler time in our fork, but that will require more work to achieve (two next PRs required are Triton AOT compilation for those kernels and refactor of the sampler code to avoid unnecessary operations). This PR only adds the kernel to streamline the review process. Furthermore, once we can fully move to the kernel, we'll be able to remove the existing torch-based sampling code (not including the logit processing code).

Would these optimizations be applicable whether or not the dedicated kernel is used?

I think they would make the sampler code easier to work with, though they would be tailored for the kernel. In general, the introduction of this kernel will allow us to push code complexity away from the sampler and into the kernel.

@njhill
Copy link
Member

njhill commented Mar 19, 2024

We are seeing ~10% reduction in sampler time in our fork,

@Yard1 do you have a rough sense of what percentage of TPOT sampler time accounts for? (I know as a proportion it would vary based on model size) .. e.g. if that is <10% then I guess this would translate to <1%?

@Yard1
Copy link
Collaborator Author

Yard1 commented Mar 19, 2024

@njhill You are correct it's not that noticeable in normal usage, but we are seeing large gains in draft model speculative decoding, where the draft model is CPU bound. It can reduce ITL by several ms in that case.

@Yard1 Yard1 merged commit 426ec4e into vllm-project:main Mar 20, 2024
30 checks passed
tjohnson31415 added a commit to tjohnson31415/vllm that referenced this pull request Mar 21, 2024
* upstream/main:
  [Misc] Bump up transformers to v4.39.0 & Remove StarCoder2Config (vllm-project#3551)
  [Misc][Log] Add log for tokenizer length not equal to vocabulary size (vllm-project#3500)
  [🚀 Ready to be merged] Added support for Jais models (vllm-project#3183)
  Fix 1D query issue from `_prune_hidden_states` (vllm-project#3539)
  [PREFIX CACHING FOLLOW UP] OrderedDict-based evictor (vllm-project#3431)
  [BugFix] Hot fix in setup.py for neuron build (vllm-project#3537)
  Migrate `logits` computation and gather to `model_runner` (vllm-project#3233)
  [1/n][Chunked Prefill] Refactor input query shapes (vllm-project#3236)
  [1/n] Triton sampling kernel (vllm-project#3186)
  [Bugfix] Fix ROCm support in CMakeLists.txt (vllm-project#3534)
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants