diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index a0253524419e..6a866a617f35 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -377,12 +377,6 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
 
         scores += d_mask
 
-        attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-        attn = self.dropout(attn)
-        # (batch, head, time, 2w + 1)
-
-        out = self.sliding_chunks_matmul_pv(attn, v, w).reshape(n_batch, -1, self.h * self.d_k)
-
         if self.global_tokens > 0:
 
             # create q, k, v for global attn
@@ -426,21 +420,34 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
             ).transpose(1, 2)
 
-            global_key_attn = torch.softmax(global_key_attn, dim=-1).masked_fill(mask, 0.0)
-            global_key_attn = self.dropout(global_key_attn)
+            # concat to local_attn_probs
+            # (batch, time, head, max_num_global_attn_indices + 2*w)
+            scores = torch.cat((global_key_attn, scores), dim=-1)
 
-            # compute outputs for global attention from all tokens to global
-            # (batch, time, head x head_dim)
-            out_all_to_global = self._compute_out_all_to_global(
-                value=global_v,
-                attn_probs=global_key_attn,
+            # free memory
+            del global_key_attn
+
+        attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+        p_attn = self.dropout(attn)
+        # (batch, head, time, 2w + 1)
+
+        if self.global_tokens > 0:
+            # compute sum of global and local attn
+            out = self._compute_attn_output_with_global_indices(
+                value=v,
+                attn_probs=p_attn,
                 max_num_global_attn_indices=max_num_global_attn_indices,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+                w=w,
             )
+        else:
+            # compute local attn only
+            out = self.sliding_chunks_matmul_pv(p_attn, v, w)
+
+        out = out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]
 
-            # compute outputs for global attention from global tokens to all
-            # (batch, max_num_global_attn_indices, head x head_dim)
+        if self.global_tokens > 0:
             out_global_to_all = self._compute_out_global_to_all(
                 query=global_q,
                 key=global_k,
@@ -452,11 +459,11 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
                 is_index_masked=mask,
             )
 
-            out += out_all_to_global
+            # overwrite values with global attention
+            out[is_index_global_attn_nonzero] = out_global_to_all
 
-            out[is_index_global_attn_nonzero] += out_global_to_all
+        ret = self.linear_out(out)
 
-        ret = self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T])
         if cache is None:
             return ret
         else:
@@ -544,16 +551,17 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
 
         return attn_probs_from_global_key
 
-    def _compute_out_all_to_global(
+    def _compute_attn_output_with_global_indices(
         self,
         value: torch.Tensor,
         attn_probs: torch.Tensor,
         max_num_global_attn_indices: int,
         is_index_global_attn_nonzero: tuple,
         is_local_index_global_attn_nonzero: tuple,
+        w: int,
     ) -> torch.Tensor:
         """
-        Compute the attention output of all tokens attending to global.
+        Compute the attention output with global indices.
 
         Args:
             value (torch.Tensor): (batch, head, time, head_dim) The value vectors for global attention.
@@ -561,7 +569,7 @@ def _compute_out_all_to_global(
             max_num_global_attn_indices (int): Maximum number of global attention indices in the batch.
             is_index_global_attn_nonzero (tuple): Indices of global attention (non-zero elements).
             is_local_index_global_attn_nonzero (tuple): Non-padding values within global attention indices.
-
+            w (int): Local context size
         Returns:
             torch.Tensor: (batch, time, head x head_dim) The attention output of all tokens attending to global.
         """
@@ -573,12 +581,22 @@ def _compute_out_all_to_global(
         value_vectors_only_global = value.new_zeros(batch_size, max_num_global_attn_indices, self.h, self.d_k)
         value_vectors_only_global[is_local_index_global_attn_nonzero] = value[is_index_global_attn_nonzero]
 
+        # cut local attn probs to global only
+        attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
         # compute attn output only global
-        out_all_to_global = torch.matmul(attn_probs, value_vectors_only_global.transpose(1, 2)).transpose(1, 2)
+        attn_output_only_global = torch.matmul(
+            attn_probs_only_global.clone(), value_vectors_only_global.transpose(1, 2).clone()
+        ).transpose(1, 2)
+
+        # reshape attn probs
+        attn_probs_without_global = attn_probs.narrow(
+            -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
+        ).contiguous()
 
-        out_all_to_global = out_all_to_global.reshape(batch_size, time, -1)
+        # compute attn output with global
+        attn_output_without_global = self.sliding_chunks_matmul_pv(attn_probs_without_global, value.transpose(1, 2), w)
 
-        return out_all_to_global
+        return attn_output_only_global + attn_output_without_global
 
     def _compute_out_global_to_all(
         self,
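
Note: the sketch below is not NeMo code; it is a minimal illustration of the pattern the patched forward path follows, with simplified shapes and with padding masks and dropout omitted. Global-key scores are concatenated in front of the local sliding-window scores, a single softmax produces one probability tensor, and narrow() splits it back into a global part (a dense matmul against the global value vectors) and a local part, which the real code feeds to sliding_chunks_matmul_pv; a naive loop stands in for that here. All function and variable names are illustrative assumptions.

import torch


def naive_local_pv(probs_local, v, w):
    """Naive O(time * (2w + 1)) stand-in for sliding_chunks_matmul_pv.

    probs_local: (batch, time, head, 2w + 1) local attention probabilities
    v:           (batch, time, head, d_k)    value vectors
    returns:     (batch, time, head, d_k)
    """
    batch, time, head, d_k = v.shape
    out = v.new_zeros(batch, time, head, d_k)
    for t in range(time):
        for j, s in enumerate(range(t - w, t + w + 1)):
            if 0 <= s < time:
                # weight of query position t attending to key position s
                out[:, t] += probs_local[:, t, :, j].unsqueeze(-1) * v[:, s]
    return out


def combined_attn_sketch(local_scores, global_scores, v, global_v, w):
    """One softmax over [global | local] score columns, then split the probs.

    local_scores:  (batch, time, head, 2w + 1) sliding-window scores
    global_scores: (batch, time, head, G)      scores against the G global keys
    v:             (batch, time, head, d_k)    value vectors
    global_v:      (batch, G, head, d_k)       value vectors of the global tokens
    """
    g = global_scores.size(-1)

    # single softmax over the concatenated score columns
    probs = torch.softmax(torch.cat((global_scores, local_scores), dim=-1), dim=-1)

    # first G columns weight the global values, the rest feed the local matmul
    probs_global = probs.narrow(-1, 0, g)
    probs_local = probs.narrow(-1, g, probs.size(-1) - g)

    # (batch, head, time, G) @ (batch, head, G, d_k) -> (batch, time, head, d_k)
    out_global = torch.matmul(probs_global.transpose(1, 2), global_v.transpose(1, 2)).transpose(1, 2)
    out_local = naive_local_pv(probs_local, v, w)

    return out_global + out_local


batch, time, head, d_k, w, g = 2, 16, 4, 8, 3, 2
out = combined_attn_sketch(
    torch.randn(batch, time, head, 2 * w + 1),
    torch.randn(batch, time, head, g),
    torch.randn(batch, time, head, d_k),
    torch.randn(batch, g, head, d_k),
    w,
)
print(out.shape)  # torch.Size([2, 16, 4, 8])

The loop-based naive_local_pv is for readability only; the real sliding_chunks_matmul_pv computes the same banded probability-times-value product with chunked matrix multiplications.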