
Support batching for UsefulSensors Moonshine #35922

Merged
merged 15 commits on Jan 30, 2025

Conversation

njeffrie (Contributor) commented on Jan 28, 2025

Add attention masking to the Moonshine model.

Tested on Open ASR Leaderboard with batch_size=256.

  • WER (tiny): 12.53
  • WER (base): 9.89
  • RTFx (tiny): 2062
  • RTFx (base): 1634

Unblocks this Open ASR Leaderboard PR.
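For reference, a minimal sketch of the batched usage this change enables; the checkpoint id, processor arguments, and output keys are assumptions for illustration, not taken from this PR:

```python
# Hedged sketch of batched Moonshine inference with padded inputs.
# Checkpoint id and processor behavior are assumptions, not from this PR.
import numpy as np
import torch
from transformers import AutoProcessor, MoonshineForConditionalGeneration

checkpoint = "UsefulSensors/moonshine-tiny"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(checkpoint)
model = MoonshineForConditionalGeneration.from_pretrained(checkpoint)

# Two clips of different lengths (16 kHz mono); padding the shorter one is
# what makes the attention mask matter in the first place.
audios = [
    np.random.randn(16000).astype(np.float32),
    np.random.randn(24000).astype(np.float32),
]

inputs = processor(audios, sampling_rate=16000, padding=True, return_tensors="pt")

with torch.no_grad():
    generated_ids = model.generate(**inputs)

print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```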
eustlb (Contributor) left a comment:

Great initiative, thanks! 🤗 I had it in mind when integrating, but batch inference was already working pretty well with 0-padding when I benchmarked on Fleurs (though non-optimal).
A few changes to call the proper functions of the codebase + little docstring fixes, but otherwise LGTM!
There's just the padding to a multiple of 8 that I have not seen in Transformers yet (to the best of my knowledge) and that I would not merge without a proper benchmark signal.

Comment on lines 364 to 375

# Pad head size dimension to the next multiple of 8. Q, K and V always have equal head sizes.
pad_amount = 8 * ((query_states.shape[-1] + 7) // 8) - query_states.shape[-1]
if pad_amount > 0:
    # Ensure scaling is correct even with padding.
    if self.scaling is None:
        self.scaling = 1.0 / math.sqrt(query_states.shape[-1])

    query_states = torch.nn.functional.pad(query_states, (0, pad_amount))
    key_states = torch.nn.functional.pad(key_states, (0, pad_amount))
    value_states = torch.nn.functional.pad(value_states, (0, pad_amount))

Contributor:

Can you justify the expected speedups a bit further here? Have you run benchmarks?

Contributor (Author):

When I added an attention mask, I found I was only able to use batch sizes up to 32; otherwise I'd run out of memory. After some memory profiling, I found the culprit was the torch SDPA backend selection: the memory-efficient implementation with attention masking only supports head sizes that are a multiple of eight, so we were falling back to the torch C++ implementation.

Overall, this change allows the batch size to go from 32 to 256, with a corresponding ~4x increase in RTFx on the Open ASR Leaderboard.
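As a standalone illustration of the trick (plain PyTorch, not the PR's modeling code), zero-padding q/k/v up to a multiple of 8 leaves the attention output unchanged as long as the scale is pinned to the original head size; head_dim=36 here is just an arbitrary non-multiple-of-8 example:

```python
# Check that zero-padding the head dimension is a no-op when the original scale is kept.
import math
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 8, 16, 36  # 36: an arbitrary non-multiple-of-8 head size
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)
mask = torch.ones(batch, 1, seq, seq, dtype=torch.bool)  # all-visible padding mask

# Reference: unpadded SDPA, which defaults to scale = 1 / sqrt(head_dim).
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Pad the head dimension up to the next multiple of 8 with zeros.
pad = -head_dim % 8
qp, kp, vp = (F.pad(t, (0, pad)) for t in (q, k, v))

# Without an explicit scale, SDPA would silently switch to 1 / sqrt(head_dim + pad),
# which is what the explicit scaling in the snippet above avoids.
out = F.scaled_dot_product_attention(
    qp, kp, vp, attn_mask=mask, scale=1.0 / math.sqrt(head_dim)
)[..., :head_dim]  # drop the zero columns contributed by the padded values

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```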

Contributor:

Nice take (here is the doc for the curious)!

I would rather have gone with another HF repo with an updated architecture (in config.json) and weights padded with 0.0s where necessary, to avoid impacting dependencies; yet that would require modifying the modeling code anyway to handle the correct scaling.

Let's add a config parameter pad_head_dim_to_multiple_of (with an explanation in its docstring) that defaults to None (no effect) and that you would set to 8 in the model config.json (in their respective HF repos).

cc @ArthurZucker

Contributor (Author):

Added the config parameter, and I've got PRs ready for our two UsefulSensors Moonshine Hugging Face repos (1, 2) to update this parameter once we land this change.
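For completeness, a hedged sketch of opting into the new field from Python rather than via config.json; the field name comes from this PR, while the checkpoint id and loading pattern are illustrative:

```python
# Sketch: enable head-dim padding through the config field added in this PR.
# The checkpoint id is an assumption; None (the default) leaves behavior unchanged.
from transformers import MoonshineConfig, MoonshineForConditionalGeneration

config = MoonshineConfig.from_pretrained("UsefulSensors/moonshine-tiny")
config.pad_head_dim_to_multiple_of = 8  # pad q/k/v head size up to a multiple of 8

model = MoonshineForConditionalGeneration.from_pretrained(
    "UsefulSensors/moonshine-tiny", config=config
)
```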

eustlb (Contributor) left a comment:

Tiny nit but LGTM, thanks! This will very likely require a subsequent PR to update the expected logits for the CI runners; I'll take care of it.

@eustlb eustlb merged commit 693328f into huggingface:main Jan 30, 2025
8 checks passed
bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Feb 5, 2025
* Add support for attention masking in moonshine.

Tested against Open ASR Leaderboard with batch size 256.

* Update comments and ensure attention masks are passed everywhere.

Perform attention mask downsampling inside of moonshine forward call.

* Hide padding behind conditional. Fix encoder/decoder masking.

- Correctly pipe encoder attention mask into decoder
- Add correct scaling factor if one is not already provided.
- Fix formatting with ruff

* Add auto generated modeling_moonshine file.

* Update formatting in generated model file.

* Address review comments.

* Fix typo.

* Add `pad_head_dim_to_multiple_of` to moonshine config.

* Correct args order for MooonshineConfig.

* Update configuration moonshine too.

* Update src/transformers/models/moonshine/modular_moonshine.py

* Update src/transformers/models/moonshine/configuration_moonshine.py

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 16, 2025