Generate: New Cache abstraction and Attention Sinks support #26681
Conversation
Left some first nits, thanks for kick-starting this effort!
@patrickvonplaten @gante
Regarding #26681 (comment) I would have to experiment. There might be options, but they'll probably be slower than necessary.
Sounds great!
Addressed the various comments. Beyond that, I also made
This is all under the assumption that the PR is heading in the right direction 😄 As a heads up, the sink cache does not work yet. I still need to do some experiments to see if I can store rotated keys and back-rotate + forward-rotate them when the cached keys are requested, rather than storing non-rotated keys. That is what I'll work on next. That leaves me with an additional question: each architecture requires slightly different key rotations. I'd like to implement this to be sufficiently adaptable, e.g. allowing architecture-specific functionality in
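As context for the back-rotation idea above: because RoPE rotations compose additively, a key that was cached after being rotated at one position can be re-rotated to a new position with a single extra rotation by the position delta. A minimal sketch of that idea (illustrative only; `rerotate_keys` and its arguments are hypothetical names, not the PR's code):

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: maps the pair (x1, x2) to (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rerotate_keys(keys, inv_freq, old_positions, new_positions):
    # `keys` were cached *after* RoPE was applied at `old_positions`.
    # Re-rotate them so they look as if RoPE had been applied at `new_positions`.
    # The delta may be negative when tokens shift toward the start of the window.
    delta = (new_positions - old_positions).float()   # (seq_len,)
    angles = torch.outer(delta, inv_freq)              # (seq_len, head_dim // 2)
    angles = torch.cat((angles, angles), dim=-1)       # (seq_len, head_dim)
    cos = angles.cos().to(keys.dtype)
    sin = angles.sin().to(keys.dtype)
    return keys * cos + rotate_half(keys) * sin        # broadcasts over (batch, heads)
```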
Hey @tomaarsen 👋
After some shower thoughts, I've come across an alternative plan for the model-specific cache modification problem -- one that is easier to implement and would result in more readable code. Instead of the base
Agree here! I think it would be better if we only have a few selected cache classes that work for all models. Thus, I think we can:
I just noticed that Joao's message has been edited, and I missed Patrick's message, so my response doesn't make much sense anymore - I deleted it. I also prefer the
@gante @tomaarsen This is a really good abstraction of the kv_cache to enable windowed context while staying compatible with the legacy kv_cache. But for the memory footprint, `torch.cat` is still needed when updating the cache, and reordering the cache with `index_select` is also still there for beam search.
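For reference, a minimal sketch of the two operations mentioned here, roughly as they appear in the legacy cache path (function names are illustrative, not the PR's code):

```python
import torch

def legacy_update(cached_keys, cached_values, new_keys, new_values):
    # Per-step update: concatenating reallocates the cache tensors every step.
    cached_keys = torch.cat([cached_keys, new_keys], dim=-2)        # (batch, heads, seq, head_dim)
    cached_values = torch.cat([cached_values, new_values], dim=-2)
    return cached_keys, cached_values

def reorder_cache(cached_keys, cached_values, beam_idx):
    # Beam search: keep only the cache rows of the surviving beams.
    return (cached_keys.index_select(0, beam_idx),
            cached_values.index_select(0, beam_idx))
```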
I removed the commits regarding changing the order of the caches based on @gante's recommendation. He is working on more changes for this PR here: tomaarsen#1
@liangan1 both are issues in the near-future roadmap, with this cache abstraction being a requirement! 🙌 First we will work on pre-allocated buffers (a subclass of
Wow, cool. Can you share more details about your roadmap, e.g. the pre-allocated buffers? In fact, we have also implemented an indirect-access kv_cache in the Intel Extension for PyTorch optimizations for LLMs, which can be used for both greedy and beam search, and we would be pleased to contribute code once we know your design for the pre-allocated buffers.
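For readers following along, the pre-allocated-buffer idea being discussed can be sketched as follows (class and method names here are hypothetical, not the eventual transformers implementation):

```python
import torch

class PreallocatedCache:
    """Sketch: allocate full-length key/value buffers up front and write new
    states in place, avoiding the per-step `torch.cat` reallocation."""

    def __init__(self, batch_size, num_heads, max_seq_len, head_dim,
                 dtype=torch.float32, device="cpu"):
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.keys = torch.zeros(shape, dtype=dtype, device=device)
        self.values = torch.zeros(shape, dtype=dtype, device=device)
        self.seen_tokens = 0

    def update(self, key_states, value_states):
        # Copy the new states into the pre-allocated slots in place.
        new_len = key_states.shape[-2]
        end = self.seen_tokens + new_len
        self.keys[:, :, self.seen_tokens:end] = key_states
        self.values[:, :, self.seen_tokens:end] = value_states
        self.seen_tokens = end
        return self.keys[:, :, :end], self.values[:, :, :end]
```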
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " | ||
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " | ||
"when creating this class." |
this warning needs some TLC - the current "will to errors" doesn't parse. Thank you!
did you mean "will lead to errors" perhaps? i.e. missing "lead"
You're right! I'll take care of it. Thanks @stas00
@tomaarsen FYI adding this fix to the next cache PR (adding support to encoder-decoder models), no need to open a PR :)
@stas00 thanks for flagging! 🤗
```python
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
```
@tomaarsen When does this case happen?
@tomaarsen Great work! It looks like the current implementation cannot handle the case input_length > window_length, at least for some models (Mistral).
Suppose the input is of torch.Size([1, 15759]). It will raise
Not sure it was tested for Mistral! Can you open a separate issue with a full reproducer? 🤗
@fayejf You might want to look at this paper, too - handling extremely long inputs in a streaming fashion: Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention.
@ArthurZucker It seems that even with Llama2, passing SinkCache to `generate` runs into issues. I'm using transformers 4.39.3, and the Llama2 model was loaded using the following code:
The SinkCache was passed to generate as in @fayejf's script. I am not sure if this is the correct way to use SinkCache:
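The commenter's exact script is not reproduced above; the usage pattern being described looks roughly like the sketch below (the model id, window length, and sink token count are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Tell me a very long story.", return_tensors="pt").to(model.device)

# SinkCache keeps the first `num_sink_tokens` tokens plus a sliding window of recent tokens.
cache = SinkCache(window_length=256, num_sink_tokens=4)
outputs = model.generate(**inputs, past_key_values=cache, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```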
The code caused an error.
@MoonRide303 Thanks for sharing! Very interesting paper!
@ArthurZucker I was on 4.38.2 and haven't had a chance to test it with the latest release. But it looks like @explanare has seen similar issues. I also tried with a Llama model and saw the legacy cache error as well. I'm not sure if I'm using SinkCache correctly; I didn't find any docs for it, unfortunately. :(
I got the same error with Llama2 using 4.38.2, following the code here. Also tried
cc @gante as well, as we broke it
This might be fixed by #30476. Only tuples should be going down that path.
@fayejf @ys-zong @ArthurZucker
Apologies for this temporary issue, we had to break a few eggs (sink cache) to make an omelet (torch.compile support) :D Edit: merged, let me know if you run into issues!
Hi @tomaarsen, I just came across this line https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L979 in your PR and wonder why the condition for cache growing is
Closes #26553
Hello!
What does this PR do?
I had a few hours on Saturday to work up a draft version of the updated KV caching mechanism as discussed in #26553. Ideally, this should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party library or in transformers directly.
The implementation doesn't work well yet, as the VRAM usage quickly shoots up after generating even just 8 tokens. This is probably a bug that I haven't had time to track down yet. There are a few other comments I have on specific sections of code, so I'll write some comments below.
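For orientation, the shape of the abstraction this PR introduces can be sketched roughly as below (simplified; class names here are illustrative, and the actual classes in `src/transformers/cache_utils.py` take additional arguments and do more bookkeeping):

```python
from typing import List, Tuple

import torch

class Cache:
    """Each attention layer hands its new key/value states to the cache and
    gets back the full key/value states to attend over."""

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

class GrowingCache(Cache):
    """Simplified dynamic cache: one growing tensor per layer."""

    def __init__(self):
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(self, key_states, value_states, layer_idx):
        if len(self.key_cache) <= layer_idx:
            # First time this layer is seen: start its cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Subsequent steps: append along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```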
Goal for this draft
The intention for this draft is to continue discussion about whether this is moving in the right direction, and to determine the scope (e.g. do we want to include this updated `Cache` for all architectures that use KV caching?).

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@patrickvonplaten
@gante
@LysandreJik
@Guangxuan-Xiao