diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 02f682d01118..8fabde20b927 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -298,7 +298,7 @@ def forward(
         else:
             present = None
 
-        # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
+        # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, kv_length] -> [batch_size*num_heads, q_length, kv_length]
         matmul_result = torch.baddbmm(
             input=alibi,
             batch1=query_layer,
@@ -307,10 +307,10 @@ def forward(
             alpha=self.inv_norm_factor,
         )
 
-        # change view to [batch_size, num_heads, q_length, k_length]
+        # change view to [batch_size, num_heads, q_length, kv_length]
         attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
 
-        # we cast attention scores to fp32, compute scaled softmax and cast back into initial dtype - [batch_size, num_heads, q_length, k_length]
+        # we cast attention scores to fp32, compute scaled softmax and cast back into initial dtype - [batch_size, num_heads, q_length, kv_length]
         input_dtype = attention_scores.dtype
         attention_scores = attention_scores.float()
         attn_weights = torch.masked_fill(
@@ -318,13 +318,13 @@ def forward(
         )
         attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
 
-        # [batch_size, num_heads, q_length, k_length]
+        # [batch_size, num_heads, q_length, kv_length]
         attention_probs = self.attention_dropout(attention_probs)
 
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        # change view [batch_size x num_heads, q_length, k_length]
+        # change view [batch_size x num_heads, q_length, kv_length]
         attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
 
         # matmul: [batch_size * num_heads, q_length, head_dim]
@@ -890,8 +890,14 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
         # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
         return tuple(
             (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length).index_select(0, beam_idx.to(layer_past[0].device)).view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim).index_select(0, beam_idx.to(layer_past[1].device)).view(batch_size_times_num_heads, seq_length, head_dim)
+                layer_past[0]
+                .view(batch_size, num_heads, head_dim, seq_length)
+                .index_select(0, beam_idx.to(layer_past[0].device))
+                .view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1]
+                .view(batch_size, num_heads, seq_length, head_dim)
+                .index_select(0, beam_idx.to(layer_past[1].device))
+                .view(batch_size_times_num_heads, seq_length, head_dim),
             )
             for layer_past in past
         )
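
The comment fixes in the attention hunks all replace `k_length` with `kv_length`, matching the variable the surrounding code actually uses: with a key/value cache, the key length differs from the query length. A minimal standalone sketch of the shape flow the corrected comments describe (toy dimensions and the `1/sqrt(head_dim)` scale are assumptions for illustration; this is not the actual `BloomAttention` module):

```python
import torch

# Toy dimensions, assumed purely for illustration.
batch_size, num_heads, head_dim = 2, 4, 8
q_length, kv_length = 3, 5  # with a cache, kv_length = past_length + q_length

# Layouts matching the comments in the diff: queries are
# [batch_size*num_heads, q_length, head_dim], keys are kept transposed as
# [batch_size*num_heads, head_dim, kv_length], and alibi carries the position
# biases broadcast to [batch_size*num_heads, q_length, kv_length].
query_layer = torch.randn(batch_size * num_heads, q_length, head_dim)
key_layer = torch.randn(batch_size * num_heads, head_dim, kv_length)
alibi = torch.zeros(batch_size * num_heads, q_length, kv_length)

# baddbmm computes beta * input + alpha * (batch1 @ batch2):
# [b*h, q_length, head_dim] @ [b*h, head_dim, kv_length] -> [b*h, q_length, kv_length]
matmul_result = torch.baddbmm(
    alibi,
    query_layer,
    key_layer,
    beta=1.0,
    alpha=1.0 / head_dim**0.5,  # stands in for self.inv_norm_factor
)
assert matmul_result.shape == (batch_size * num_heads, q_length, kv_length)

# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, num_heads, q_length, kv_length)
assert attention_scores.shape == (batch_size, num_heads, q_length, kv_length)
```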
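The `_reorder_cache` hunk is purely cosmetic, splitting the chained calls one per line, but the logic it reformats is worth spelling out: beam search reorders tensors along the batch axis, while BLOOM's cache fuses batch and heads into a single leading dimension, so the cache must be un-fused, reordered, and re-fused. A standalone sketch of that dance for the key tensor (toy dimensions assumed):

```python
import torch

# Toy dimensions, assumed purely for illustration.
batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5
batch_size_times_num_heads = batch_size * num_heads

# BLOOM caches keys with the fused layout [batch_size*num_heads, head_dim, seq_length].
past_key = torch.randn(batch_size_times_num_heads, head_dim, seq_length)
beam_idx = torch.tensor([1, 0])  # e.g. beam search swapped the two batch entries

reordered_key = (
    past_key
    .view(batch_size, num_heads, head_dim, seq_length)  # un-fuse batch and heads
    .index_select(0, beam_idx)  # reorder along the true batch axis
    .view(batch_size_times_num_heads, head_dim, seq_length)  # re-fuse for the cache
)

# Batch entry 0 of the reordered cache is what batch entry 1 used to be.
assert torch.equal(
    reordered_key.view(batch_size, num_heads, head_dim, seq_length)[0],
    past_key.view(batch_size, num_heads, head_dim, seq_length)[1],
)
```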