-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Mistral&Mixtral]Add sliding window param to sdpa after torch 2.2.0 #29220
Conversation
I modified mixtral model as well, because of they have same sliding window structure here. qwen2 failed with these tests below: |
…com/https://github.com/ehuaa/transformers into add_sliding_window_for_sdpa
Hey @fxmarty @ArthurZucker , |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks alright but let's not add the unrelated change
I have deleted the unrelated changes and please review them again. The failed test are because of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A small nit but LGTM otherwise.
Let's just declare _is_torch_version_greater_or_equal_than_2_2_0
in the src/transformers/utils/import_utils.py
file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Almost good to go
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, let's fix the conflicts and we should be able to merge.
EDIT: let's add a small test in the integration tests!
I have resolved the conflict, and the test you mentioned above i think we can use the test you wrote before, https://github.com/huggingface/transformers/blob/main/tests/models/mistral/test_modeling_mistral.py#L534 |
When i tried to add a new test to check if sdpa work with sliding_window, the test below is failed. I just modified your original test_model_7b_logits with a longer input with size of 4097 to check sliding_window feature, it failes as below: |
@slow
|
…e#29264) * Add compatibility with mps device * fix * typo and style
Co-authored-by: Joao Gante <joao@huggingface.co>
* [i18n-zh] Translate fsdp.md into Chinese Signed-off-by: windsonsea <haifeng.yao@daocloud.io> * apply suggestions from Fan-Lin --------- Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
* [Whisper Tok] Update integration test * make style
* Fix yolos processing * Add back slow marker - protects for pycocotools in slow * Slow decorator goes above copied from header
enable subfolder
Thanks for this PR. |
* Fix deprecated arg issue * Trainer check too * Check for dict or dataclass * Simplify, make config always AcceleratorConfig * Upstream to Trainer
Probably yes. That why it's a bit critical. Mistral only had |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only thing missing for me is a test to make sure that sdpa
with sliding_window
gives the same results as flash_attention_2
with sliding_window
@ArthurZucker , Yes, so i test mistral-7b logits of sdpa with sliding_window and eager with sliding_window, and the results are not the same. |
That's less important than sdpa vs flash! |
Ok, I'll upload a new slow test of sdpa vs flash later. |
…com/https://github.com/ehuaa/transformers into add_sliding_window_for_sdpa
After I git push, there're something weird to the changed files above so i closed this pr and try to open a new pull request. @ArthurZucker |
alright! It think this fix is important, let's try to have this ready! Ping me on the next PR and link it with this one as well to get the full conv! |
#29407 @ArthurZucker Please review this pr you mentioned above. |
What does this PR do?
Add sliding window param as described in #28980, and the solution is check if torch version greater than or equal as torch 2.2.0 (release version) , and add sliding window to attention_mask if it's true.
Tests of sdpa and sliding window attention with torch 2.2.0 has been passed with
RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/models/mistral/test_modeling_mistral.py
in my local environment.
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @fxmarty