Support batching for UsefulSensors Moonshine #35922

Merged · 15 commits · Jan 30, 2025

Changes from all commits

6 changes: 6 additions & 0 deletions src/transformers/models/moonshine/configuration_moonshine.py
@@ -64,6 +64,9 @@ class MoonshineConfig(PretrainedConfig):
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, it will default to
`decoder_num_attention_heads`.
pad_head_dim_to_multiple_of (`int`, *optional*):
Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
optimized attention implementations.
encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder.
decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@@ -164,6 +167,7 @@ def __init__(
decoder_num_attention_heads=8,
encoder_num_key_value_heads=None,
decoder_num_key_value_heads=None,
pad_head_dim_to_multiple_of=None,
encoder_hidden_act="gelu",
decoder_hidden_act="silu",
max_position_embeddings=512,
@@ -196,6 +200,8 @@ def __init__(
decoder_num_key_value_heads = decoder_num_attention_heads
self.decoder_num_key_value_heads = decoder_num_key_value_heads

self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of

self.encoder_hidden_act = encoder_hidden_act
self.decoder_hidden_act = decoder_hidden_act
self.max_position_embeddings = max_position_embeddings
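For intuition, the rounding this new option applies can be sketched in a few lines (a minimal sketch; the head sizes below are illustrative, not taken from a real checkpoint):

```python
# Minimal sketch of the rounding behind pad_head_dim_to_multiple_of.
def padded_head_dim(head_dim: int, multiple: int) -> int:
    # Round head_dim up to the next multiple of `multiple` (ceiling division).
    return multiple * ((head_dim + multiple - 1) // multiple)

assert padded_head_dim(36, 8) == 40  # padded: some kernels want multiples of 8
assert padded_head_dim(64, 8) == 64  # already aligned: no padding added
```
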
93 changes: 85 additions & 8 deletions src/transformers/models/moonshine/modeling_moonshine.py
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, Optional, Tuple, Union

import numpy as np
@@ -27,7 +28,11 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...modeling_attn_mask_utils import (
+     AttentionMaskConverter,
+     _prepare_4d_attention_mask,
+     _prepare_4d_attention_mask_for_sdpa,
+ )
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
@@ -270,6 +275,23 @@ def forward(
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False

# Pad the head dimension up to the next specified multiple. Q, K and V always share the same head size.
head_dim_padding = 0
if self.config.pad_head_dim_to_multiple_of is not None:
head_dim = query_states.shape[-1]
target_multiple = self.config.pad_head_dim_to_multiple_of
target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
head_dim_padding = target_head_dim - head_dim
if head_dim_padding > 0:
# Ensure scaling is correct even with padding.
if self.scaling is None:
self.scaling = 1.0 / math.sqrt(query_states.shape[-1])

query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))

attn_output, attn_weights = attention_interface(
self,
query_states,
@@ -282,6 +304,10 @@
**kwargs,
)

# Remove head size padding.
if head_dim_padding > 0:
attn_output = attn_output[:, :, :, :-head_dim_padding]

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
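A quick sanity check on the padding logic above (illustrative only, not part of the diff): zero-padding the head dimension leaves the attention logits unchanged, as long as the softmax scale keeps using the original head size, which is exactly what the `self.scaling` fix above guards:

```python
import math

import torch
import torch.nn.functional as F

# Zero columns contribute nothing to Q.K^T dot products, so the logits match
# exactly, provided the scale uses the original head_dim (36 here), not the
# padded one (40).
q = torch.randn(1, 2, 5, 36)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 5, 36)
pad = 40 - 36
q_pad, k_pad = F.pad(q, (0, pad)), F.pad(k, (0, pad))
scale = 1.0 / math.sqrt(36)  # unpadded head size
torch.testing.assert_close(
    (q @ k.transpose(-1, -2)) * scale,
    (q_pad @ k_pad.transpose(-1, -2)) * scale,
)
```
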
@@ -603,9 +629,11 @@ def forward(
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
- attention_mask (`torch.Tensor`)`, *optional*):
-     Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-     but it is not used.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+     Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+     - 1 for tokens that are **not masked**,
+     - 0 for tokens that are **masked**.
+     [What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@@ -632,6 +660,22 @@
hidden_states = nn.functional.gelu(self.conv3(hidden_states))
hidden_states = hidden_states.permute(0, 2, 1)

# attention mask downsampling
if attention_mask is not None:
mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
downsample_stride = 64 * 3 * 2 # conv strides
attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
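# FlashAttention-2 consumes the 2-D padding mask directly; drop it entirely when no sample is padded.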
attention_mask = attention_mask if (attention_mask == 0.0).any() else None

# When output_attentions=True, the SDPA attention path falls back to the eager implementation, which needs a full 4-D mask
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)

# create position embeddings to be shared across the decoder layers
@@ -649,7 +693,7 @@
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
- None,
+ attention_mask,
position_ids,
None,
output_attentions,
@@ -660,6 +704,7 @@
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
position_embeddings=position_embeddings,
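Stand-alone, the mask downsampling above does roughly the following (the `mask_len` value here is a stand-in for what `_get_feat_extract_output_lengths` would return):

```python
import torch

# The encoder stem applies three convolutions with strides 64, 3 and 2, so
# each encoder frame covers 64 * 3 * 2 = 384 raw audio samples. The padding
# mask is thinned out at the same stride, then truncated to the exact conv
# output length.
downsample_stride = 64 * 3 * 2  # = 384
attention_mask = torch.ones(2, 16000, dtype=torch.long)  # 1 s of 16 kHz audio, batch of 2
mask_len = 40  # stand-in for self._get_feat_extract_output_lengths(16000)
downsampled = attention_mask[..., ::downsample_stride][..., :mask_len]
print(downsampled.shape)  # torch.Size([2, 40])
```
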
@@ -810,13 +855,19 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -865,6 +916,26 @@ def forward(
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

# attention mask downsampling
if encoder_attention_mask is not None:
mask_len = encoder_hidden_states.shape[-2]
downsample_stride = 64 * 3 * 2 # conv strides
encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
if self.config._attn_implementation == "flash_attention_2":
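# FlashAttention-2 consumes the 2-D padding mask directly; drop it entirely when no sample is padded.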
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None

# When output_attentions=True, the SDPA attention path falls back to the eager implementation, which needs a full 4-D mask
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
)

for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
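For reference, the 4-D expansion that `_prepare_4d_attention_mask` performs amounts to roughly the following (a simplified re-implementation for illustration, not the library code):

```python
import torch

def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # [bsz, src_len] 0/1 mask -> [bsz, 1, tgt_len, src_len] additive bias:
    # 0.0 where attention is allowed, dtype-min where it is masked.
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 0]])  # last source position is padding
print(expand_mask(mask, torch.float32, tgt_len=2))
# two target rows, each with dtype-min at the masked source position
```
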
@@ -886,6 +957,7 @@
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
@@ -1168,9 +1240,11 @@ def compute_num_masked_span(input_length):
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
- attention_mask (`torch.Tensor`)`, *optional*):
-     Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-     but it is not used.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+     Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+     - 1 for tokens that are **not masked**,
+     - 0 for tokens that are **masked**.
+     [What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
@@ -1371,6 +1445,7 @@ def forward(
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1387,6 +1462,7 @@
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_attention_mask=attention_mask,
encoder_hidden_states=encoder_outputs[0],
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
@@ -1517,6 +1593,7 @@ def forward(

outputs = self.model(
input_values,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
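End to end, the batching this PR enables would be exercised roughly like this (a hedged sketch: the checkpoint name and processor kwargs are assumptions, not taken from the PR):

```python
import numpy as np
from transformers import AutoProcessor, MoonshineForConditionalGeneration

processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")

# Two clips of different lengths; padding the batch makes the processor emit
# an attention_mask so the encoder can ignore the padded samples.
clips = [np.zeros(16000, dtype=np.float32), np.zeros(24000, dtype=np.float32)]
inputs = processor(clips, sampling_rate=16000, padding=True, return_tensors="pt")
generated_ids = model.generate(**inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```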