-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
F.scaled_dot_product_attention support #26572
F.scaled_dot_product_attention support #26572
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for starting over the preliminary support, I really like the design of it!
I see there are a bunch of deduplicated code from the vanilla attention, my question would be why not add everything inside the forward pass of the vanilla attention? The only problem is that it might make the attention code quite hard to read.
On the other hand, in case we go for a standalone LlamaSDPAAttention
module, it might make the modeling file of all models more bloated (FA-2, SDPA, ..). @ydshieh suggested offline that we could offload those modules in a new file to make the modeling file cleaner and nearly untouched.
I would personally advocate to add the SDPA support directly inside xxxAttention
as the changes relative to it is only ~20 LoC, it would be surprising for users to see that the xxxAttention
modules has suddenly changed to xxxSDPAAttention
by just upgrading transformers with no other intervention.
I would like to hear opinions from others @LysandreJik @ArthurZucker @patrickvonplaten on this matter, and I will be happy to help you extend this PR on other archs and adding relevant tests
As #26792 was merged will get back to it this week, targeting next to next transformers release. |
7bb4857
to
dd646c1
Compare
…xmarty/transformers into torch-sdpa-preliminary-support
It is ready. Here is a summary of the relevant CI.
|
if use_flash_attention_2: | ||
config = cls._check_and_enable_flash_attn_2(config, torch_dtype) | ||
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. | ||
config._attn_implementation = kwargs.pop("attn_implementation", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This overrides existing _attn_implementation
value inside config. And sets it to None when attn_implementation
is not passed in kwargs ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is intended, users should not use _attn_implementation
. Is there a case where you have no choice but to use it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @fxmarty thanks for reply. For context it is the same background as this PR #28823 where I tried to unblock export in our benchmark pipeline.
I guess we misunderstood the error message, and tried to pass attn_implementation="eager"
to config constructor instead of from_config
call.
Regarding your comment though I'm not sure if that is the right behavior. attn_implementation
is indeed documented in PretrainedConfig
, and it is not respected if called in this way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@BowenBao you are right, it is an issue in the documentation. This should not be exposed in the config.
from transformers import AutoModelForCausalLM, AutoConfig, LlamaForCausalLM
cfg = AutoConfig.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")
cfg._attn_implementation = "eager"
model = LlamaForCausalLM(cfg)
works. It is true that there is no API exposed to the user for initializing with XxxForCausalLM(cfg)
and selecting the attention implementation, apart from using this private attribute.
Any of:
model = AutoModel.from_config(cfg, attn_implementation="eager")
model = LlamaModel.from_pretrained("xxx", attn_implementation="eager")
work.
As per title, this PR proposes to support natively
torch.nn.functional.scaled_dot_product_attention
in transformers. I propose to enable SDPA by default iftorch>=2.1.1
(released 15 Nov. 2023), for the reasons written in the PR. The support could then be extended using https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py.The introduced
_unmask_unattended
is a workaround for pytorch/pytorch#110213.It behaves as follow:
If attention_mask is
and expanded_mask is (e.g. here left-padding case)
then the modified expanded_mask will be
Modifying as such the attention mask is fine given that we modify it only for pad tokens on the
-2
dimension. Softmax is computed on the-1
dimension, and thus there is no change for the relevant non-padding tokens.