
Commit ddbe33e

Nit
thomasw21 committed Jul 28, 2022
1 parent 2677a28 commit ddbe33e
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -298,7 +298,7 @@ def forward(
         else:
             present = None
 
-        # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
+        # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, kv_length] -> [batch_size*num_heads, q_length, kv_length]
         matmul_result = torch.baddbmm(
             input=alibi,
             batch1=query_layer,
@@ -307,24 +307,24 @@ def forward(
             alpha=self.inv_norm_factor,
         )
 
-        # change view to [batch_size, num_heads, q_length, k_length]
+        # change view to [batch_size, num_heads, q_length, kv_length]
         attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
 
-        # we cast attention scores to fp32, compute scaled softmax and cast back into initial dtype - [batch_size, num_heads, q_length, k_length]
+        # we cast attention scores to fp32, compute scaled softmax and cast back into initial dtype - [batch_size, num_heads, q_length, kv_length]
         input_dtype = attention_scores.dtype
         attention_scores = attention_scores.float()
         attn_weights = torch.masked_fill(
             attention_scores * self.layer_number, attention_mask, torch.finfo(torch.float32).min
         )
         attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
 
-        # [batch_size, num_heads, q_length, k_length]
+        # [batch_size, num_heads, q_length, kv_length]
         attention_probs = self.attention_dropout(attention_probs)
 
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        # change view [batch_size x num_heads, q_length, k_length]
+        # change view [batch_size x num_heads, q_length, kv_length]
         attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
 
         # matmul: [batch_size * num_heads, q_length, head_dim]
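
For context on the shape comments touched in the hunk above, here is a minimal, self-contained sketch of the flow they describe: the alibi bias is fused into the scaled query/key batched matmul via torch.baddbmm, the result is viewed as [batch_size, num_heads, q_length, kv_length], and the masked softmax runs in fp32 before casting back. The tensor sizes, the toy causal mask, and the omission of the layer_number scaling are illustrative assumptions, not the library code.

# Illustrative sketch only; sizes and the toy mask are assumptions.
import torch
from torch import nn

batch_size, num_heads, head_dim = 2, 4, 16
q_length = kv_length = 5
inv_norm_factor = 1.0 / head_dim**0.5

# Shapes chosen so the batched matmul is well formed:
#   query: [batch_size * num_heads, q_length, head_dim]
#   key:   [batch_size * num_heads, head_dim, kv_length]
#   alibi: [batch_size * num_heads, q_length, kv_length] (positional bias)
query_layer = torch.randn(batch_size * num_heads, q_length, head_dim)
key_layer = torch.randn(batch_size * num_heads, head_dim, kv_length)
alibi = torch.randn(batch_size * num_heads, q_length, kv_length)

# baddbmm computes beta * alibi + alpha * (query @ key) in one fused op
matmul_result = torch.baddbmm(alibi, query_layer, key_layer, beta=1.0, alpha=inv_norm_factor)

# [batch_size * num_heads, q_length, kv_length] -> [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, num_heads, q_length, kv_length)

# boolean mask: True marks positions that must not be attended to
attention_mask = torch.triu(torch.ones(q_length, kv_length, dtype=torch.bool), diagonal=1)

# softmax in fp32 for numerical stability, then cast back to the input dtype
input_dtype = attention_scores.dtype
attn_weights = attention_scores.float().masked_fill(attention_mask, torch.finfo(torch.float32).min)
attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

print(attention_probs.shape)  # torch.Size([2, 4, 5, 5])

Passing alibi as the first argument to baddbmm folds the bias addition into the same kernel as the scaled matmul, which is why it is not added to the scores in a separate step.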
@@ -890,8 +890,14 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
         # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
         return tuple(
             (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length).index_select(0, beam_idx.to(layer_past[0].device)).view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim).index_select(0, beam_idx.to(layer_past[1].device)).view(batch_size_times_num_heads, seq_length, head_dim)
+                layer_past[0]
+                .view(batch_size, num_heads, head_dim, seq_length)
+                .index_select(0, beam_idx.to(layer_past[0].device))
+                .view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1]
+                .view(batch_size, num_heads, seq_length, head_dim)
+                .index_select(0, beam_idx.to(layer_past[1].device))
+                .view(batch_size_times_num_heads, seq_length, head_dim),
             )
             for layer_past in past
         )
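
Below is a small runnable sketch of what the reformatted _reorder_cache body does during beam search: each layer's cached keys and values are viewed back to a per-batch layout, gathered along the batch dimension with beam_idx, and flattened back to batch_size * num_heads rows. The sizes and the one-layer fake cache are illustrative assumptions.

# Illustrative sketch only; sizes and the fake one-layer cache are assumptions.
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 3
batch_size_times_num_heads = batch_size * num_heads

# BLOOM-style cache layout:
#   keys:   [batch_size * num_heads, head_dim, seq_length]
#   values: [batch_size * num_heads, seq_length, head_dim]
key_cache = torch.randn(batch_size_times_num_heads, head_dim, seq_length)
value_cache = torch.randn(batch_size_times_num_heads, seq_length, head_dim)
past = ((key_cache, value_cache),)  # a single layer for the example

# beam_idx[i] = which previous batch entry beam slot i should continue from
beam_idx = torch.tensor([1, 0])

reordered_past = tuple(
    (
        layer_past[0]
        .view(batch_size, num_heads, head_dim, seq_length)
        .index_select(0, beam_idx.to(layer_past[0].device))
        .view(batch_size_times_num_heads, head_dim, seq_length),
        layer_past[1]
        .view(batch_size, num_heads, seq_length, head_dim)
        .index_select(0, beam_idx.to(layer_past[1].device))
        .view(batch_size_times_num_heads, seq_length, head_dim),
    )
    for layer_past in past
)

# batch slot 0 of the reordered keys now holds what used to be batch slot 1
reordered_keys = reordered_past[0][0]
assert torch.equal(reordered_keys[:num_heads], key_cache[num_heads : 2 * num_heads])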

0 comments on commit ddbe33e
