Support batching for UsefulSensors Moonshine #35922
Conversation
Tested against Open ASR Leaderboard with batch size 256.
Perform attention mask downsampling inside of moonshine forward call.
- Correctly pipe encoder attention mask into decoder
- Add correct scaling factor if one is not already provided
- Fix formatting with ruff
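For context on the "attention mask downsampling" item above, here is a hypothetical sketch (not the actual modeling code) of how a sample-level padding mask can be reduced to the encoder's frame rate by chaining each conv layer's output-length formula. The `(kernel, stride)` pairs are placeholders, not Moonshine's real conv stem:

```python
import torch

def downsample_attention_mask(mask: torch.Tensor, convs=((3, 2), (3, 2))) -> torch.Tensor:
    """Shrink a sample-level padding mask (batch, samples) to the encoder's
    frame rate by applying each conv layer's output-length formula."""
    lengths = mask.sum(dim=-1)
    for kernel, stride in convs:
        # Output length of a conv with no padding and dilation 1.
        lengths = (lengths - kernel) // stride + 1
    lengths = lengths.clamp(min=0)
    out = torch.zeros(mask.shape[0], int(lengths.max().item()), dtype=mask.dtype)
    for i, length in enumerate(lengths.tolist()):
        out[i, :length] = 1
    return out

# Example: two sequences, the second padded to the length of the first.
mask = torch.ones(2, 32, dtype=torch.long)
mask[1, 20:] = 0
print(downsample_attention_mask(mask))
```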
Great initiative, thanks! 🤗 I had it in mind when integrating, but batch inference was working pretty well with 0-padding when I benchmarked on Fleurs (though non-optimal).
A few changes to call the proper functions of the codebase plus some small docstring fixes, but otherwise LGTM!
There's just the padding to a multiple of 8 that I have not seen in Transformers yet (to the best of my knowledge) and that I would not merge without a proper benchmark signal.
```python
# Pad head size dimension to next multiple of 8. Q, K and V always have equal head sizes.
pad_amount = 8 * ((query_states.shape[-1] + 7) // 8) - query_states.shape[-1]
if pad_amount > 0:
    # Ensure scaling is correct even with padding.
    if self.scaling is None:
        self.scaling = 1.0 / math.sqrt(query_states.shape[-1])

    query_states = torch.nn.functional.pad(query_states, (0, pad_amount))
    key_states = torch.nn.functional.pad(key_states, (0, pad_amount))
    value_states = torch.nn.functional.pad(value_states, (0, pad_amount))
```
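As a side note (not part of the diff), a minimal self-contained check, assuming PyTorch ≥ 2.1 so that `F.scaled_dot_product_attention` accepts a `scale` argument, that zero-padding the head dimension while keeping the original 1/sqrt(head_dim) scale leaves the attention output unchanged; all shapes are illustrative:

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only (not Moonshine's actual dims).
batch, heads, seq, head_dim = 2, 8, 16, 36  # 36 is not a multiple of 8

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)
mask = torch.ones(batch, 1, seq, seq, dtype=torch.bool)

# Reference: unpadded SDPA with the default 1/sqrt(head_dim) scaling.
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Pad the head dimension up to the next multiple of 8 with zeros.
pad = 8 * ((head_dim + 7) // 8) - head_dim
qp, kp, vp = (F.pad(t, (0, pad)) for t in (q, k, v))

# Zero columns do not change Q @ K^T, but the scale must stay 1/sqrt(original dim),
# which is exactly why the diff sets self.scaling explicitly before padding.
out = F.scaled_dot_product_attention(
    qp, kp, vp, attn_mask=mask, scale=1.0 / math.sqrt(head_dim)
)[..., :head_dim]  # drop the padded output channels

torch.testing.assert_close(out, ref)
```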
Can you justify the expected speedups a bit further here? Have you run benchmarks?
When I added an attention mask, I found I was only able to use batch sizes up to 32; otherwise I'd run out of memory. After some memory profiling, I found the culprit was the torch SDPA backend selection: the memory-efficient implementation only supports head sizes that are a multiple of eight when an attention mask is used, so we were falling back to the C++ math implementation.
Overall, this change allows going from batch size 32 to 256, with a corresponding ~4x increase in RTFx on the Open ASR Leaderboard.
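For anyone who wants to reproduce the fallback described above, a hedged probe sketch, assuming PyTorch ≥ 2.3's `torch.nn.attention.sdpa_kernel` API and a CUDA device; shapes and head sizes are illustrative:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def mem_efficient_accepts(head_dim: int, device: str = "cuda") -> bool:
    """Check whether the memory-efficient SDPA kernel will take a masked
    attention call at this head size (illustrative shapes, fp16)."""
    q = torch.randn(1, 8, 64, head_dim, device=device, dtype=torch.float16)
    mask = torch.ones(1, 1, 64, 64, dtype=torch.bool, device=device)
    try:
        # Restrict SDPA to the memory-efficient backend only; if that kernel
        # cannot handle these inputs, PyTorch raises a "No available kernel" error.
        with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
            F.scaled_dot_product_attention(q, q, q, attn_mask=mask)
        return True
    except RuntimeError:
        return False

if torch.cuda.is_available():
    for d in (36, 40):  # a non-multiple-of-8 head size vs. a padded one
        print(d, mem_efficient_accepts(d))
```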
Nice take (here is the doc for the curious)!
I would rather have gone with another HF repo with an updated architecture (in `config.json`) and updated weights with 0.0s where necessary to avoid impacting dependencies, yet this would require modifying the modeling code anyway to handle correct scaling.
Let's add a config parameter (with explanation in the docstring) `pad_head_dim_to_multiple_of` that defaults to `None` (no effect) and that you would set to 8 in the model `config.json` (in their respective HF repos).
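A rough sketch of how the attention code could key the padding off such a config value; only `pad_head_dim_to_multiple_of` comes from this thread, the helper name and signature are illustrative:

```python
import math
import torch.nn.functional as F

def maybe_pad_heads(q, k, v, scaling, pad_to):
    """Pad the last (head) dim of q/k/v up to a multiple of `pad_to`
    (None disables padding) and return the scaling to pass to SDPA."""
    if pad_to is None:
        return q, k, v, scaling
    head_dim = q.shape[-1]
    pad = -head_dim % pad_to
    if pad == 0:
        return q, k, v, scaling
    # Scaling must stay tied to the original head size, not the padded one.
    if scaling is None:
        scaling = 1.0 / math.sqrt(head_dim)
    return F.pad(q, (0, pad)), F.pad(k, (0, pad)), F.pad(v, (0, pad)), scaling
```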
Tiny nit, but LGTM, thanks! It will very likely require a subsequent PR to update the expected logits for the CI runners; I'll take care of it.
* Add support for attention masking in moonshine. Tested against Open ASR Leaderboard with batch size 256.
* Update comments and ensure attention masks are passed everywhere. Perform attention mask downsampling inside of moonshine forward call.
* Hide padding behind conditional. Fix encoder/decoder masking.
  - Correctly pipe encoder attention mask into decoder
  - Add correct scaling factor if one is not already provided.
  - Fix formatting with ruff
* Add auto generated modeling_moonshine file.
* Update formatting in generated model file.
* Address review comments.
* Fix typo.
* Add `pad_head_dim_to_multiple_of` to moonshine config.
* Correct args order for MooonshineConfig.
* Update configuration moonshine too.
* Update src/transformers/models/moonshine/modular_moonshine.py
* Update src/transformers/models/moonshine/configuration_moonshine.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
Add attention masking to the moonshine model.
Tested on Open ASR Leaderboard with batch_size=256.
Unblocks this OpenASR Leaderboard PR
@eustlb
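For reference, a hedged usage sketch of batched transcription through the ASR pipeline, assuming a transformers release that includes this PR; the model id, file paths, and batch size are illustrative:

```python
from transformers import pipeline

# The pipeline pads each batch and forwards the attention mask, which is what
# this PR teaches the Moonshine encoder/decoder to handle correctly.
asr = pipeline("automatic-speech-recognition", model="UsefulSensors/moonshine-tiny")

audio_files = ["clip_0.wav", "clip_1.wav", "clip_2.wav"]  # placeholder paths
print(asr(audio_files, batch_size=8))
```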