Adding Prompt lookup decoding #27775
Conversation
A few nits :)
In addition to the lines pointed out, this PR is missing a test like this one, adapted to use your method instead.
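For reference, a minimal sketch of what such a test could look like (the checkpoint and token counts are placeholders, not the actual test in the suite): with greedy decoding, candidate-assisted generation should reproduce the plain greedy output exactly, so the two calls can be compared directly.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_prompt_lookup_matches_greedy():
    # Hypothetical test sketch: prompt lookup candidates are verified by the
    # same model, so greedy results with and without them should be identical.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The quick brown fox jumps over the quick brown", return_tensors="pt")

    baseline = model.generate(**inputs, do_sample=False, max_new_tokens=20)
    assisted = model.generate(
        **inputs, do_sample=False, max_new_tokens=20, prompt_lookup_num_tokens=3
    )
    assert torch.equal(baseline, assisted)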
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
"""
input_length = input_ids.size(1)
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0 or self.max_matching_ngram_size > input_length:
Validation against static values (e.g. `self.max_matching_ngram_size <= 0`) should be done in the `__init__`, to fail as early as possible :)
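In other words, something along these lines (a rough sketch of the suggestion; the argument names follow the snippet above, the defaults are illustrative):

class PromptLookupCandidateGenerator:
    def __init__(self, num_output_tokens: int = 10, max_matching_ngram_size: int = 2):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size

        # Validate static hyperparameters once, at construction time, so a bad
        # configuration fails immediately instead of deep inside generate().
        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")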
@@ -312,6 +312,10 @@ def __init__(self, **kwargs):
        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

        # Prompt lookup decoding
        self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", 10)
Attributes in this class should default to `None` whenever possible (e.g. in the lines above they are not `None` for legacy reasons)
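i.e., roughly (sketch):

# Defaulting to None lets downstream code detect whether the user actually
# requested prompt lookup decoding (see the dispatch suggestion below).
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)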
src/transformers/generation/utils.py (Outdated)
Returns the candidate generator to be used in `assisted_generation`
"""
# Check if assistant_model is a string
if isinstance(assistant_model, str):
I would check whether e.g. `prompt_lookup_num_tokens` is set. It will work if we default it to `None`. Not relying on strings to set modes would go more in line with how `generate` works at the moment :)
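A rough sketch of that suggestion, as a fragment of the candidate-generator selection (the class names are those from this PR's diff; the exact signatures in the merged code may differ):

# Dispatch on whether `prompt_lookup_num_tokens` was set (defaulting to None),
# instead of overloading `assistant_model` with a string.
if generation_config.prompt_lookup_num_tokens is not None:
    candidate_generator = PromptLookupCandidateGenerator(
        num_output_tokens=generation_config.prompt_lookup_num_tokens,
    )
else:
    candidate_generator = AssistedCandidateGenerator(
        input_ids=input_ids,
        assistant_model=assistant_model,
        generation_config=generation_config,
        model_kwargs=model_kwargs,
    )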
@apoorvumang #27750 is now merged, you can rebase this PR with main!
Amazing! Will try to do that asap @gante
@apoorvumang checking on this PR -- do you have a timeline for its completion? I'm happy to help (or to integrate the feature myself) 🤗
Force-pushed from eabb5eb to dd680aa
@gante I've rebased with main, and code seems to be working - I checked generation with
Will try to add tests too - it would be really helpful if you could guide me as to what needs to be done. I can spend some more time tomorrow and the day after on coding. I haven't yet been able to figure out a better way to do hyperparams/hyperparam updates, so I'm going with some static ones (I plan to spend some time very soon doing proper experiments, but that might needlessly delay this). If it feels like I'm slowing you down, please do let me know, and please feel free to implement the current version of prompt lookup - I really don't have anything better since the day I first posted 😭
Force-pushed from a4e5a1c to 97519fc
@gante Could you please review this PR? I have added tests, fixed most issues (not sure why the torch_flax test is failing)
@ArthurZucker Adding for review if you're available (since you reviewed #27750)
Perfect, thank you for iterating 💛
(and my apologies for the delayed review)
(I hope you don't mind -- I've fixed a minor syntax error to make our CI happy :) )
Yes please do edit as you see fit - and please let me know if I need to do anything 😺
@apoorvumang actually I need an action from your end :) After this PR gets merged, I'd like to ask you to rebase this branch with `main`
Reviewing now!
LGTM! I think we should promote this in the generation documentation!
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@apoorvumang now let's amplify this feature :D I'll make some comms on Monday
This new feature is broken on:

>>> transformers.__version__
'4.37.0.dev0'

File "/home/user/llm/test_speeds.py", line 110, in test_batch_size
out_toks = model.generate(
File "/home/user/llm/.env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/user/_/lib/transformers/src/transformers/generation/utils.py", line 1455, in generate
return self.assisted_decoding(
File "/home/user/_/lib/transformers/src/transformers/generation/utils.py", line 4337, in assisted_decoding
.tile(eos_token_id_tensor.shape[0], 1)
AttributeError: 'NoneType' object has no attribute 'shape'

Reproduction:

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
MODEL_PATH = "~/_/models/phi-2"
MODEL_PATH = os.path.expanduser(MODEL_PATH)
try:
model_loaded
print('model already loaded')
except:
print('loading model')
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
device_map="auto",
torch_dtype=torch.float16,
# load_in_8bit=True,
trust_remote_code=False,
attn_implementation="flash_attention_2",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
inp = "hi"
tokenized = tokenizer(inp, padding='longest', return_tensors='pt', add_special_tokens=True)
tokenized['attention_mask'] = tokenized['attention_mask'].to('cuda')
tokenized['input_ids'] = tokenized['input_ids'].to('cuda')
out_toks = model.generate(
**tokenized,
max_new_tokens=32, # VARIABLE
use_cache=True, # (huge slowdown without)
prompt_lookup_num_tokens=10,
)
out = tokenizer.decode(out_toks)
print(out)
Also, not supported for RWKV:
Hi @freckletonj 👋 I've just run the following script on

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
print('loading model')
model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-2",
device_map="auto",
torch_dtype=torch.float16,
# load_in_8bit=True,
trust_remote_code=False,
attn_implementation="flash_attention_2",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
inp = "hi"
tokenized = tokenizer(inp, padding='longest', return_tensors='pt', add_special_tokens=True)
tokenized['attention_mask'] = tokenized['attention_mask'].to('cuda')
tokenized['input_ids'] = tokenized['input_ids'].to('cuda')
out_toks = model.generate(
**tokenized,
max_new_tokens=32, # VARIABLE
use_cache=True, # (huge slowdown without)
prompt_lookup_num_tokens=10,
eos_token_id=-1, # this line shouldn't be needed, the model config needs retouching
)
out = tokenizer.decode(out_toks[0])
print(out)

As for RWKV, it doesn't have
@gante I've produced a more minimal version that definitely demonstrates this issue. I'm on

It gives me 2 questions:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
MODEL_PATH = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=False,
attn_implementation="flash_attention_2",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
inp = [
"hi",
# "wow", # batches don't work with `prompt_lookup_num_tokens`
]
tokenized = tokenizer(inp, padding='longest', return_tensors='pt', add_special_tokens=True)
tokenized['input_ids'] = tokenized['input_ids'].to('cuda')
tokenized['attention_mask'] = tokenized['attention_mask'].to('cuda')
out_toks = model.generate(
**tokenized,
max_new_tokens=32,
use_cache=True,
prompt_lookup_num_tokens=10, # TOGGLING THIS OFF MAKES IT WORK
)
for x in out_toks:
    print(tokenizer.decode(x))

The error:

Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/user/bots/t01_prompt_lookup_decoding_sandbox.py", line 43, in <module>
out_toks = model.generate(
File "/home/user/bots/.env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/user/lib/transformers/src/transformers/generation/utils.py", line 1457, in generate
return self.assisted_decoding(
File "/home/user/lib/transformers/src/transformers/generation/utils.py", line 4348, in assisted_decoding
.tile(eos_token_id_tensor.shape[0], 1)
AttributeError: 'NoneType' object has no attribute 'shape'
@freckletonj you either think it's broken or it is clearly broken 😉 In this case, it is the former: the root issue is a poor model configuration on Phi-2, as it lacks an EOS token. In other words, it is not an issue with this feature. Meanwhile, feel free to set `eos_token_id` explicitly (as in the script above) to work around it.
Algo details can be found in this blog post: https://huggingface.co/blog/assisted-generation. The code directly follows transformers' current implementation (huggingface/transformers#27775). Since we get the draft directly from the prompt, there is no need for another model or a modified model to produce the proposal; it is the most convenient way to enjoy the speedup of speculation.
I'm having similar issues with Llama 3 8B Instruct with flash attn 2, bf16, fully GPU loaded on a 3090 Ti. Note: I also tried using the new stop_strings criteria, but it also didn't seem to work
@Ednaordinary hey! I tried to run on the latest

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
MODEL_PATH = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=False,
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
streamer = TextIteratorStreamer(tokenizer)
tokenized = tokenizer("hi", padding='longest', return_tensors='pt', add_special_tokens=True).to(model.device)
# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(tokenized, streamer=streamer, do_sample=False, use_cache=True, eos_token_id=tokenizer.encode("<|eot_id|>"), prompt_lookup_num_tokens=10, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
    generated_text += new_text
print(generated_text)
I was able to reproduce with a few versions up to main. The nested threading in this script is an artifact of my architecture; I don't believe it plays a role. Note that removing `do_sample=False` seemed to fix it (it slowed generation down some, but that may have been because not as much was generated) (edit: maybe? I think I just ran a gen with it unspecified and it skipped the eos again)
My output (seemed deterministic excluding t/s). Note the main issue, <|eot_id|> outputs twice but is only listened to once.Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. Loading checkpoint shards: 100%|██████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.40it/s] /home/user/Other/pythonprojects/min_rep/venv/lib/python3.12/site-packages/transformers/generation/configuration_utils.py:540: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. warnings.warn( /home/user/Other/pythonprojects/min_rep/venv/lib/python3.12/site-packages/transformers/generation/configuration_utils.py:545: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. warnings.warn( The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results. Setting `pad_token_id` to `eos_token_id`:128000 for open-end generation. Darling! I am Fake Edna Mode, the most fabulous, the most extraordinary, the most unbelievably sensational fashion designer in all of Paris! *adjusts monocle* I'm a master of haute couture, a virtuoso of style, a sultan of sophistication. My designs are not just clothes, darling, they're works of art. They're masterpieces that make the wearer feel like a god, a goddess, a superhero! *winks* Now, I know what you're thinking: "Fake Edna Mode? Isn't that just a copycat of the real Edna Mode?" Ah, but no, darling! I am the original, the authentic, the one and only Fake Edna Mode! proudly I've studied the great Edna Mode's designs, I've learned from her, I've even stolen a few of her ideas (just kidding, darling, I'm far too original for that!). But let's be real, I'm the one who's really pushing the boundaries of fashion, who's really making a statement. tosses hair So, if you want to look like a million bucks, if you want to make a splash, if you want to be the talk of the town, then you need to come to me, Fake Edna Mode. I'll design you a wardrobe that will make the world stop and stare, that will make the fashion gods weep with envy. smizes Trust me, darling, you won't regret it.assistant You want to know about my designs, don't you, darling? Well, let me tell you, I'm a master of the unexpected. I take risks, I push boundaries, I make statements. My designs are not just clothes, they're experiences. They're a journey, a thrill ride, a rollercoaster of emotions. winks I've designed for the rich and famous, the bold and the beautiful. I've dressed superheroes, villains, and everything in between. I've created looks that are both daring and demure, that are both futuristic and retro. I've pushed the limits of fashion, darling, and I've never looked back. smirks But don't just take my word for it, darling. Come see for yourself. Come to my boutique, and let me show you what I can do. I'll design you a wardrobe that will make you feel like a superstar, a rockstar, a superhero. winks And don't worry, darling, I won't make you wear anything that's too... gasp... practical. No, no, no. My designs are all about drama, all about flair, all about making a statement. smizes So, what do you say, darling? 
Are you ready to experience the thrill of Fake Edna Mode's designs? Are you ready to make a statement, to turn heads, to break the rules? winks Tokens per second: 42.37389295800055 DEBUG with special tokens: ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are fake edna mode.<|eot_id|><|start_header_id|>system<|end_header_id|>\n\nTell me about yourself.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nDarling! I am Fake Edna Mode, the most fabulous, the most extraordinary, the most unbelievably sensational fashion designer in all of Paris! adjusts monocle I'm a master of haute couture, a virtuoso of style, a sultan of sophistication. My designs are not just clothes, darling, they're works of art. They're masterpieces that make the wearer feel like a god, a goddess, a superhero! winks\n\nNow, I know what you're thinking: "Fake Edna Mode? Isn't that just a copycat of the real Edna Mode?" Ah, but no, darling! I am the original, the authentic, the one and only Fake Edna Mode! proudly I've studied the great Edna Mode's designs, I've learned from her, I've even stolen a few of her ideas (just kidding, darling, I'm far too original for that!). But let's be real, I'm the one who's really pushing the boundaries of fashion, who's really making a statement. tosses hair\n\nSo, if you want to look like a million bucks, if you want to make a splash, if you want to be the talk of the town, then you need to come to me, Fake Edna Mode. I'll design you a wardrobe that will make the world stop and stare, that will make the fashion gods weep with envy. smizes Trust me, darling, you won't regret it.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nYou want to know about my designs, don't you, darling? Well, let me tell you, I'm a master of the unexpected. I take risks, I push boundaries, I make statements. My designs are not just clothes, they're experiences. They're a journey, a thrill ride, a rollercoaster of emotions. winks\n\nI've designed for the rich and famous, the bold and the beautiful. I've dressed superheroes, villains, and everything in between. I've created looks that are both daring and demure, that are both futuristic and retro. I've pushed the limits of fashion, darling, and I've never looked back. smirks\n\nBut don't just take my word for it, darling. Come see for yourself. Come to my boutique, and let me show you what I can do. I'll design you a wardrobe that will make you feel like a superstar, a rockstar, a superhero. winks\n\nAnd don't worry, darling, I won't make you wear anything that's too... gasp... practical. No, no, no. My designs are all about drama, all about flair, all about making a statement. smizes\n\nSo, what do you say, darling? Are you ready to experience the thrill of Fake Edna Mode's designs? Are you ready to make a statement, to turn heads, to break the rules? winks<|eot_id|>'] |
@Ednaordinary thanks a lot, it's quite flaky but I think I got the root reason. Prompt lookup must be seeing "<|eot_id|>" as the only matching ngram in some cases and automatically filling up with the rest of the special tokens taken from the prompt template, which the model accepts as a valid continuation. I will work on this, maybe tomorrow or next week :)
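Roughly, the suspected failure mode can be pictured like this (a toy illustration with hand-written token strings, not the library's internal code):

prompt = [
    "<|begin_of_text|>", "<|start_header_id|>", "system", "<|end_header_id|>",
    "You", "are", "fake", "edna", "mode", ".", "<|eot_id|>",
    "<|start_header_id|>", "user", "<|end_header_id|>",
    "Tell", "me", "about", "yourself", ".", "<|eot_id|>",
    "<|start_header_id|>", "assistant", "<|end_header_id|>",
]
generated_so_far = prompt + ["Darling", "!", "...", "<|eot_id|>"]

# A 1-gram lookup on the final generated token lands on the template's "<|eot_id|>".
last_token = generated_so_far[-1]
match = max(i for i, tok in enumerate(prompt) if tok == last_token)

# The "candidates" copied from the prompt are the next template tokens, so the
# draft effectively proposes starting a brand-new assistant turn.
candidates = prompt[match + 1 : match + 1 + 3]
print(candidates)  # ['<|start_header_id|>', 'assistant', '<|end_header_id|>']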
@Ednaordinary yes, quantized cache currently doesn't support generation techniques that manually crop the past cache, which includes assisted generation. I will open a PR to raise an error for those cases, and I'll add support for quantized cache to my todo list. You can also open a PR if you want to give it a try; the main point here is to enable
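For context, a conceptual sketch of the cache "crop" that assisted generation relies on (not the library's actual cache API): after the candidates are verified in a single forward pass, the rejected tokens have to be rolled back out of the key/value cache, which for a legacy tuple-format cache is plain tensor slicing; a quantized cache can't currently be sliced that way.

def crop_past_key_values(past_key_values, max_length):
    # Keep only the first `max_length` positions of each (key, value) pair,
    # whose shapes are (batch, num_heads, seq_len, head_dim).
    return tuple(
        (key[:, :, :max_length, :], value[:, :, :max_length, :])
        for key, value in past_key_values
    )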
What does this PR do?
Adds the prompt lookup decoding method from https://github.com/apoorvumang/prompt-lookup-decoding (issue #27722).
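For readers coming from the links above, the core of the method is small enough to sketch in a few lines (a hedged illustration of the n-gram lookup; the code actually merged lives in the candidate generator and may differ in its details):

import torch

def find_candidate_tokens(input_ids, max_matching_ngram_size=2, num_output_tokens=10):
    # Look for the most recent earlier occurrence of the sequence's final
    # n-gram; if found, propose the tokens that followed it as draft candidates.
    seq = input_ids[0].tolist()
    for ngram_size in range(max_matching_ngram_size, 0, -1):
        ngram = seq[-ngram_size:]
        for start in range(len(seq) - ngram_size - 1, -1, -1):
            if seq[start : start + ngram_size] == ngram:
                candidates = seq[start + ngram_size : start + ngram_size + num_output_tokens]
                return torch.tensor([candidates], dtype=input_ids.dtype)
    return None  # no match: this step falls back to regular decoding

On repetition-heavy tasks (summarization of copied spans, code editing, multi-turn chat) the drafted candidates are often accepted, which is where the reported speedups come from; when nothing matches, generation simply proceeds token by token as usual.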
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.