diff --git a/src/transformers/models/moonshine/configuration_moonshine.py b/src/transformers/models/moonshine/configuration_moonshine.py
index cabbe9179ba8..0ea6f149e430 100644
--- a/src/transformers/models/moonshine/configuration_moonshine.py
+++ b/src/transformers/models/moonshine/configuration_moonshine.py
@@ -64,6 +64,9 @@ class MoonshineConfig(PretrainedConfig):
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `decoder_num_attention_heads`.
+        pad_head_dim_to_multiple_of (`int`, *optional*):
+            Pad the head dimension in the encoder and decoder to the next multiple of this value. Necessary for using
+            certain optimized attention implementations.
         encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
             The non-linear activation function (function or string) in the encoder.
         decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@@ -164,6 +167,7 @@ def __init__(
         decoder_num_attention_heads=8,
         encoder_num_key_value_heads=None,
         decoder_num_key_value_heads=None,
+        pad_head_dim_to_multiple_of=None,
         encoder_hidden_act="gelu",
         decoder_hidden_act="silu",
         max_position_embeddings=512,
@@ -196,6 +200,8 @@ def __init__(
             decoder_num_key_value_heads = decoder_num_attention_heads
         self.decoder_num_key_value_heads = decoder_num_key_value_heads

+        self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
         self.encoder_hidden_act = encoder_hidden_act
         self.decoder_hidden_act = decoder_hidden_act
         self.max_position_embeddings = max_position_embeddings
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 03d2b8d00d05..96285262514b 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from typing import Callable, Optional, Tuple, Union

 import numpy as np
@@ -27,7 +28,11 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+)
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -270,6 +275,23 @@ def forward(
         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
+        # Pad head size dimension to the next specified multiple. Q, K and V always have equal head sizes.
+        head_dim_padding = 0
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            head_dim = query_states.shape[-1]
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
+            head_dim_padding = target_head_dim - head_dim
+        if head_dim_padding > 0:
+            # Ensure scaling is correct even with padding.
+            if self.scaling is None:
+                self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
+
+            query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -282,6 +304,10 @@ def forward(
             **kwargs,
         )

+        # Remove head size padding.
+        if head_dim_padding > 0:
+            attn_output = attn_output[:, :, :, :-head_dim_padding]
+
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -603,9 +629,11 @@ def forward(
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of
             type `torch.FloatTensor`.
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
@@ -632,6 +660,22 @@ def forward(
         hidden_states = nn.functional.gelu(self.conv3(hidden_states))
         hidden_states = hidden_states.permute(0, 2, 1)

+        # Downsample the attention mask to match the post-convolution sequence length.
+        if attention_mask is not None:
+            mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
+            downsample_stride = 64 * 3 * 2  # product of the conv strides
+            attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if (attention_mask == 0.0).any() else None
+
+            # When `output_attentions=True`, the sdpa implementation falls back to the eager implementation's forward.
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
         position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)

         # create position embeddings to be shared across the decoder layers
@@ -649,7 +693,7 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
-                    None,
+                    attention_mask,
                     position_ids,
                     None,
                     output_attentions,
@@ -660,6 +704,7 @@ def forward(
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
+                    attention_mask=attention_mask,
                     position_ids=position_ids,
                     output_attentions=output_attentions,
                     position_embeddings=position_embeddings,
@@ -810,6 +855,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
@@ -817,6 +863,11 @@ def forward(
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                 of the decoder.
+            encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -865,6 +916,26 @@ def forward(
         all_self_attns = () if output_attentions else None
         all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

+        # Downsample the encoder attention mask to match the post-convolution sequence length.
+        if encoder_attention_mask is not None:
+            mask_len = encoder_hidden_states.shape[-2]
+            downsample_stride = 64 * 3 * 2  # product of the conv strides
+            encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
+
+            # When `output_attentions=True`, the sdpa implementation falls back to the eager implementation's forward.
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -886,6 +957,7 @@ def forward(
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
@@ -1168,9 +1240,11 @@ def compute_num_masked_span(input_length):
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of
             type `torch.FloatTensor`.
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
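Both mask-downsampling hunks above rely on the same arithmetic: the encoder's three convolutions have strides 64, 3 and 2, so one output frame covers 64 * 3 * 2 = 384 input samples, and striding the per-sample mask by 384 turns it into a per-frame mask. A minimal standalone sketch of that step (the clip length and the feature length of 41 are illustrative stand-ins, not values taken from this patch):

```python
import torch

downsample_stride = 64 * 3 * 2  # product of the three encoder conv strides: 384

# Hypothetical batch: two 16000-sample clips, the second one half padding.
attention_mask = torch.ones(2, 16000, dtype=torch.long)
attention_mask[1, 8000:] = 0

# Stand-in for self._get_feat_extract_output_lengths(16000).
mask_len = 41

# Keep every 384th mask entry, then clip to the post-conv sequence length.
downsampled = attention_mask[..., ::downsample_stride][..., :mask_len]
print(downsampled.shape)     # torch.Size([2, 41])
print(downsampled[1].sum())  # tensor(21): frames whose first sample is unpadded
```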
@@ -1371,6 +1445,7 @@ def forward(
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_values,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
@@ -1387,6 +1462,7 @@ def forward(
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
+            encoder_attention_mask=attention_mask,
             encoder_hidden_states=encoder_outputs[0],
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
@@ -1517,6 +1593,7 @@ def forward(
         outputs = self.model(
             input_values,
+            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py
index 787395835232..a78b153725d5 100644
--- a/src/transformers/models/moonshine/modular_moonshine.py
+++ b/src/transformers/models/moonshine/modular_moonshine.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from typing import Callable, Optional, Tuple, Union

 import torch
@@ -21,6 +22,10 @@
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+)
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -91,6 +96,9 @@ class MoonshineConfig(PretrainedConfig):
             by meanpooling all the original heads within that group. For more details checkout [this
             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
             `decoder_num_attention_heads`.
+        pad_head_dim_to_multiple_of (`int`, *optional*):
+            Pad the head dimension in the encoder and decoder to the next multiple of this value. Necessary for using
+            certain optimized attention implementations.
         encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
             The non-linear activation function (function or string) in the encoder.
         decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
@@ -191,6 +199,7 @@ def __init__(
         decoder_num_attention_heads=8,
         encoder_num_key_value_heads=None,
         decoder_num_key_value_heads=None,
+        pad_head_dim_to_multiple_of=None,
         encoder_hidden_act="gelu",
         decoder_hidden_act="silu",
         max_position_embeddings=512,
@@ -223,6 +232,8 @@ def __init__(
             decoder_num_key_value_heads = decoder_num_attention_heads
         self.decoder_num_key_value_heads = decoder_num_key_value_heads

+        self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
         self.encoder_hidden_act = encoder_hidden_act
         self.decoder_hidden_act = decoder_hidden_act
         self.max_position_embeddings = max_position_embeddings
@@ -360,6 +371,23 @@ def forward(
         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

         is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
+        # Pad head size dimension to the next specified multiple. Q, K and V always have equal head sizes.
+        head_dim_padding = 0
+        if self.config.pad_head_dim_to_multiple_of is not None:
+            head_dim = query_states.shape[-1]
+            target_multiple = self.config.pad_head_dim_to_multiple_of
+            target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
+            head_dim_padding = target_head_dim - head_dim
+        if head_dim_padding > 0:
+            # Ensure scaling is correct even with padding.
+            if self.scaling is None:
+                self.scaling = 1.0 / math.sqrt(query_states.shape[-1])
+
+            query_states = torch.nn.functional.pad(query_states, (0, head_dim_padding))
+            key_states = torch.nn.functional.pad(key_states, (0, head_dim_padding))
+            value_states = torch.nn.functional.pad(value_states, (0, head_dim_padding))
+
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -372,6 +400,10 @@ def forward(
             **kwargs,
         )

+        # Remove head size padding.
+        if head_dim_padding > 0:
+            attn_output = attn_output[:, :, :, :-head_dim_padding]
+
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -593,9 +625,11 @@ def forward(
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of
             type `torch.FloatTensor`.
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
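The head-padding hunks in both files use the same round-up trick: the target head size is `ceil(head_dim / multiple) * multiple`, and the scaling factor is pinned to the original head size before zero columns are appended, so `Q @ K^T` and the resulting softmax are unchanged. A minimal sketch with made-up dimensions (36 and 8 are illustrative, not Moonshine's real values):

```python
import math
import torch
import torch.nn.functional as F

head_dim = 36                    # illustrative, not a real Moonshine head size
pad_head_dim_to_multiple_of = 8  # e.g. what some fused attention kernels require

# Round up: ceil(36 / 8) * 8 = 40, i.e. 4 columns of zero padding.
target_multiple = pad_head_dim_to_multiple_of
target_head_dim = target_multiple * ((head_dim + target_multiple - 1) // target_multiple)
head_dim_padding = target_head_dim - head_dim

q = torch.randn(1, 8, 10, head_dim)  # [bsz, num_heads, q_len, head_dim]
# Pin scaling to the *unpadded* head size; the zero columns contribute nothing
# to Q @ K^T, so attention scores match the unpadded computation.
scaling = 1.0 / math.sqrt(q.shape[-1])
q_padded = F.pad(q, (0, head_dim_padding))
print(q_padded.shape, round(scaling, 4))  # torch.Size([1, 8, 10, 40]) 0.1667
```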
@@ -622,6 +656,22 @@ def forward(
         hidden_states = nn.functional.gelu(self.conv3(hidden_states))
         hidden_states = hidden_states.permute(0, 2, 1)

+        # Downsample the attention mask to match the post-convolution sequence length.
+        if attention_mask is not None:
+            mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
+            downsample_stride = 64 * 3 * 2  # product of the conv strides
+            attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if (attention_mask == 0.0).any() else None
+
+            # When `output_attentions=True`, the sdpa implementation falls back to the eager implementation's forward.
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, hidden_states.dtype)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
         position_ids = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)

         # create position embeddings to be shared across the decoder layers
@@ -639,7 +689,7 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
-                    None,
+                    attention_mask,
                     position_ids,
                     None,
                     output_attentions,
@@ -650,6 +700,7 @@ def forward(
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
+                    attention_mask=attention_mask,
                     position_ids=position_ids,
                     output_attentions=output_attentions,
                     position_embeddings=position_embeddings,
@@ -698,6 +749,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
         Args:
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                 of the decoder.
+            encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -753,6 +810,26 @@ def forward(
         all_self_attns = () if output_attentions else None
         all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

+        # Downsample the encoder attention mask to match the post-convolution sequence length.
+        if encoder_attention_mask is not None:
+            mask_len = encoder_hidden_states.shape[-2]
+            downsample_stride = 64 * 3 * 2  # product of the conv strides
+            encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
+            if self.config._attn_implementation == "flash_attention_2":
+                encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
+
+            # When `output_attentions=True`, the sdpa implementation falls back to the eager implementation's forward.
+            elif self.config._attn_implementation == "sdpa" and not output_attentions:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
+                )
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -774,6 +851,7 @@ def forward(
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
@@ -816,9 +894,11 @@ def forward(
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding and conversion into a tensor of
             type `torch.FloatTensor`.
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
@@ -945,6 +1025,7 @@ def forward(
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_values,
+                attention_mask=attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
@@ -961,6 +1042,7 @@ def forward(
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
+            encoder_attention_mask=attention_mask,
             encoder_hidden_states=encoder_outputs[0],
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
@@ -1075,6 +1157,7 @@ def forward(
         outputs = self.model(
             input_values,
+            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
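With both files updated, a padded batch flows from the processor through the encoder and into the decoder's cross-attention without attending to padding. A hedged usage sketch; the checkpoint id and the assumption that the processor returns an `attention_mask` when asked to pad are mine, not something this patch states:

```python
import numpy as np
from transformers import AutoProcessor, MoonshineForConditionalGeneration

# Hypothetical checkpoint id, for illustration only.
model_id = "UsefulSensors/moonshine-tiny"
processor = AutoProcessor.from_pretrained(model_id)
model = MoonshineForConditionalGeneration.from_pretrained(model_id)

# Two clips of different lengths; padding=True pads the batch, and (assumed
# here) the returned attention_mask is what this change now threads through
# the encoder self-attention and the decoder cross-attention.
clips = [
    np.random.randn(16000).astype(np.float32),
    np.random.randn(8000).astype(np.float32),
]
inputs = processor(clips, sampling_rate=16000, padding=True, return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```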