Skip to content

Commit

Permalink
Support batching for UsefulSensors Moonshine (#35922)
Browse files Browse the repository at this point in the history
* Add support for attention masking in moonshine.

Tested against Open ASR Leaderboard with batch size 256.

* Update comments and ensure attention masks are passed everywhere.

Perform attention mask downsampling inside of moonshine forward call.

* Hide padding behind conditional. Fix encoder/decoder masking.

- Correctly pipe encoder attention mask into decoder
- Add correct scaling factor if one is not already provided.
- Fix formatting with ruff

* Add auto generated modeling_moonshine file.

* Update formatting in generated model file.

* Address review comments.

* Fix typo.

* Add `pad_head_dim_to_multiple_of` to moonshine config.

* Correct args order for MooonshineConfig.

* Update configuration moonshine too.

* Update src/transformers/models/moonshine/modular_moonshine.py

* Update src/transformers/models/moonshine/configuration_moonshine.py

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
  • Loading branch information
njeffrie and eustlb authored Jan 30, 2025
1 parent 5757681 commit 693328f
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 15 deletions.
6 changes: 6 additions & 0 deletions src/transformers/models/moonshine/configuration_moonshine.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class MoonshineConfig(PretrainedConfig):
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`decoder_num_attention_heads`.
pad_head_dim_to_multiple_of (`int`, *optional*):
Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
optimized attention implementations.
encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder.
decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
Expand Down Expand Up @@ -164,6 +167,7 @@ def __init__(
decoder_num_attention_heads=8,
encoder_num_key_value_heads=None,
decoder_num_key_value_heads=None,
pad_head_dim_to_multiple_of=None,
encoder_hidden_act="gelu",
decoder_hidden_act="silu",
max_position_embeddings=512,
Expand Down Expand Up @@ -196,6 +200,8 @@ def __init__(
decoder_num_key_value_heads = decoder_num_attention_heads
self.decoder_num_key_value_heads = decoder_num_key_value_heads

self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of

self.encoder_hidden_act = encoder_hidden_act
self.decoder_hidden_act = decoder_hidden_act
self.max_position_embeddings = max_position_embeddings
Expand Down
93 changes: 85 additions & 8 deletions src/transformers/models/moonshine/modeling_moonshine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, Optional, Tuple, Union

import numpy as np
Expand All @@ -27,7 +28,11 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
Expand Down Expand Up @@ -270,6 +275,23 @@ def forward(
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False

# Pad head size dimension to next specified multiple. Q K and V always have equal head sizes.
head_dim_padding = 0
if self.config.pad_head_dim_to_multiple_of is not None:
head_dim = query_states.shape[-1]
target_multiple = self.config.pad_head_dim_to_multiple_of
target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
head_dim_padding = target_head_dim - head_dim
if head_dim_padding > 0:
# Ensure scaling is correct even with padding.
if self.scaling is None:
self.scaling = 1.0 / math.sqrt(query_states.shape[-1])

query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))

attn_output, attn_weights = attention_interface(
self,
query_states,
Expand All @@ -282,6 +304,10 @@ def forward(
**kwargs,
)

# Remove head size padding.
if head_dim_padding > 0:
attn_output = attn_output[:, :, :, :-head_dim_padding]

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
Expand Down Expand Up @@ -603,9 +629,11 @@ def forward(
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor`)`, *optional*):
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
but it is not used.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Expand All @@ -632,6 +660,22 @@ def forward(
hidden_states = nn.functional.gelu(self.conv3(hidden_states))
hidden_states = hidden_states.permute(0, 2, 1)

# attention mask downsampling
if attention_mask is not None:
mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
downsample_stride = 64 * 3 * 2 # conv strides
attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask == 0.0).any() else None

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)

# create position embeddings to be shared across the decoder layers
Expand All @@ -649,7 +693,7 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
None,
attention_mask,
position_ids,
None,
output_attentions,
Expand All @@ -660,6 +704,7 @@ def forward(
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
Expand Down Expand Up @@ -810,13 +855,19 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -865,6 +916,26 @@ def forward(
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

# attention mask downsampling
if encoder_attention_mask is not None:
mask_len = encoder_hidden_states.shape[-2]
downsample_stride = 64 * 3 * 2 # conv strides
encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)

for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand All @@ -886,6 +957,7 @@ def forward(
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
Expand Down Expand Up @@ -1168,9 +1240,11 @@ def compute_num_masked_span(input_length):
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
attention_mask (`torch.Tensor`)`, *optional*):
Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
but it is not used.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Expand Down Expand Up @@ -1371,6 +1445,7 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
Expand All @@ -1387,6 +1462,7 @@ def forward(
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_attention_mask=attention_mask,
encoder_hidden_states=encoder_outputs[0],
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
Expand Down Expand Up @@ -1517,6 +1593,7 @@ def forward(

outputs = self.model(
input_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
Expand Down
Loading

0 comments on commit 693328f

Please sign in to comment.