[Frontend][Misc] Don't Repeat Yourself (DRY) Sampling #11368
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can comment /ready on the PR. 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
""" | ||
Apply Don't Repeat Yourself (DRY) sampling to the logits. | ||
|
||
Reference: https://github.com/PygmalionAI/aphrodite-engine/blob/a3c03db7355b33c0dfd670b084e827d0aa7442d1/aphrodite/modeling/layers/sampler.py#L621-L702 |
Probably makes more sense to link to the original PR (oobabooga/text-generation-webui#5677), as that contains a detailed description and discussion, whereas this is just code similar to the code below.
""" | ||
|
||
VOCAB_SIZE = logits.size(-1) | ||
MAX_NGRAM = 100 |
50 is more than enough. In practice, top logits tend to be on the order of 10^2, and 1.75^50 is already 10^12. This will double performance for pathological inputs.
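For intuition, a quick back-of-the-envelope check (illustrative only, using the common DRY default base of 1.75):

base = 1.75
print(base ** 50)    # ~1.4e12: already dwarfs logits on the order of 1e2
print(base ** 100)   # ~2.0e24: no practical difference for the resulting argmax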
# Find all instances of the last token - potential ngrams!
endpoint_indexes = torch.nonzero(token_seq == last_token,
                                 as_tuple=True)[0].tolist()
# NOTE(alpin): This seems like the slow part.
Here's an idea: instead of matching a single token above with nonzero, try to match the last allowed_length tokens. Shorter matches don't need to be penalized by definition, and this should drastically cut down on the number of matches that must be looped over. I'm not sure how to perform subsequence matching in PyTorch, but there should be a way to do this as a vectorized operation.
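A minimal sketch of such a vectorized match, assuming token_seq is a 1-D tensor of prompt plus output tokens and pattern holds the last allowed_length tokens (the helper name find_subsequence_endpoints is illustrative, not part of this PR):

import torch

def find_subsequence_endpoints(token_seq: torch.Tensor,
                               pattern: torch.Tensor) -> torch.Tensor:
    """Return the indices where `pattern` ends inside `token_seq`."""
    k = pattern.numel()
    if token_seq.numel() < k:
        return torch.empty(0, dtype=torch.long, device=token_seq.device)
    # All sliding windows of length k: shape (len(token_seq) - k + 1, k)
    windows = token_seq.unfold(0, k, 1)
    # A window matches when every position equals the pattern
    matches = (windows == pattern).all(dim=-1)
    # Convert window start indices to inclusive endpoint indices
    return torch.nonzero(matches, as_tuple=True)[0] + (k - 1)

Each returned endpoint could then be extended backwards exactly as the current loop does, but only positions that already match the last allowed_length tokens would be visited.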
# Convert ngram lengths to penalty exponents
penalty_mask = ngram_lens > 0
scales = bases[irow]**(ngram_lens[penalty_mask] - min_ngram)
What if ngram_lens[penalty_mask] - min_ngram < 0? Perhaps I'm misunderstanding the implementation, but that case doesn't appear to be accounted for.
+1 on confirming this.
If I understand correctly, we want something like the following to filter out the cases when ngram_lens[penalty_mask] - min_ngram < 0:
# Apply penalties
penalty_mask = ngram_lens > 0
exponent = ngram_lens[penalty_mask] - min_ngram
# Ensure that when exponent < 0, scales is set to 0
valid_exponent_mask = exponent >= 0
scales = torch.zeros_like(exponent, dtype=torch.float32)
scales[valid_exponent_mask] = bases[irow] ** exponent[valid_exponent_mask]
# Apply penalties
logits[irow][penalty_mask] -= multipliers[irow] * scales
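A more compact equivalent, purely as a sketch (variable names are taken from the suggestion above and not verified against the PR's final code):

exponent = (ngram_lens[penalty_mask] - min_ngram).to(torch.float32)
# Zero the penalty wherever the matched ngram is shorter than min_ngram
scales = torch.where(exponent >= 0,
                     bases[irow] ** exponent,
                     torch.zeros_like(exponent))
logits[irow][penalty_mask] -= multipliers[irow] * scales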
FIX #8581
This PR adds support for the DRY sampler, a modern sequence-level repetition penalty that prevents repetitive outputs and performs much better than older penalty methods.
Showcase:
Prompt:
No DRY output:
Output with dry_multiplier=1.0:
Currently, using DRY can incur an overhead of up to ~25%. We can probably optimize this.
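For reference, a rough usage sketch from the Python API, assuming this PR exposes DRY knobs on SamplingParams with names mirroring the text-generation-webui implementation (dry_multiplier, dry_base, dry_allowed_length); check the diff for the actual parameter names:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    temperature=0.8,
    max_tokens=256,
    dry_multiplier=1.0,    # strength of the DRY penalty (0 disables it)
    dry_base=1.75,         # exponential growth rate of the penalty
    dry_allowed_length=2,  # ngrams up to this length are not penalized
)
outputs = llm.generate(["Tell me a story about a robot."], params)
print(outputs[0].outputs[0].text)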