Fast Conformer global token fix #7085

Merged · 16 commits · Jul 21, 2023
nemo/collections/asr/parts/submodules/multi_head_attention.py (66 changes: 42 additions & 24 deletions)
@@ -377,12 +377,6 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):

scores += d_mask

attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
attn = self.dropout(attn)
# (batch, head, time, 2w + 1)

out = self.sliding_chunks_matmul_pv(attn, v, w).reshape(n_batch, -1, self.h * self.d_k)

if self.global_tokens > 0:

# create q, k, v for global attn
@@ -426,21 +420,34 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
).transpose(1, 2)

global_key_attn = torch.softmax(global_key_attn, dim=-1).masked_fill(mask, 0.0)
global_key_attn = self.dropout(global_key_attn)
# concat to local_attn_probs
# (batch, time, head, max_num_global_attn_indices + 2*w)
scores = torch.cat((global_key_attn, scores), dim=-1)

# compute outputs for global attention from all tokens to global
# (batch, time, head x head_dim)
out_all_to_global = self._compute_out_all_to_global(
value=global_v,
attn_probs=global_key_attn,
# free memory
del global_key_attn

attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
p_attn = self.dropout(attn)
# (batch, head, time, 2w + 1)

if self.global_tokens > 0:
# compute sum of global and local attn
out = self._compute_attn_output_with_global_indices(
value=v,
attn_probs=p_attn,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
w=w,
)
else:
# compute local attn only
out = self.sliding_chunks_matmul_pv(p_attn, v, w)

out = out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]

# compute outputs for global attention from global tokens to all
# (batch, max_num_global_attn_indices, head x head_dim)
if self.global_tokens > 0:
out_global_to_all = self._compute_out_global_to_all(
query=global_q,
key=global_k,
@@ -452,11 +459,11 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
is_index_masked=mask,
)

out += out_all_to_global
# overwrite values with global attention
out[is_index_global_attn_nonzero] = out_global_to_all

out[is_index_global_attn_nonzero] += out_global_to_all
ret = self.linear_out(out)

ret = self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T])
if cache is None:
return ret
else:
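
For context on the hunk above: the fix concatenates the scores against the global keys in front of the local sliding-window scores, applies one softmax over the combined axis, and only then splits the result into a global part and a local part. Below is a minimal toy sketch of that normalization pattern (single head, arbitrary sizes, no masking or dropout; an illustration under those assumptions, not the NeMo implementation itself):

```python
import torch

batch, time, num_global, w = 2, 8, 1, 3

local_scores = torch.randn(batch, time, 2 * w + 1)    # scores over the +/-w local window
global_scores = torch.randn(batch, time, num_global)  # scores against the global keys

# One softmax over [global | local], so each query's probabilities sum to 1
# across both kinds of keys.
scores = torch.cat((global_scores, local_scores), dim=-1)
probs = torch.softmax(scores, dim=-1)

# Split back afterwards: the first num_global columns attend to the global keys,
# the remaining 2w + 1 columns attend to the local window.
probs_global = probs.narrow(-1, 0, num_global)
probs_local = probs.narrow(-1, num_global, 2 * w + 1)

assert torch.allclose(probs_global.sum(-1) + probs_local.sum(-1), torch.ones(batch, time))
```

Normalizing both parts jointly keeps each query's attention weights summing to 1 across local and global keys, which the two separate softmax calls removed by this diff did not guarantee.
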
@@ -544,24 +551,25 @@ def _compute_global_key_attn(

return attn_probs_from_global_key

def _compute_out_all_to_global(
def _compute_attn_output_with_global_indices(
self,
value: torch.Tensor,
attn_probs: torch.Tensor,
max_num_global_attn_indices: int,
is_index_global_attn_nonzero: tuple,
is_local_index_global_attn_nonzero: tuple,
w: int,
) -> torch.Tensor:
"""
Compute the attention output of all tokens attending to global.
Compute the attention output with global indices.

Args:
value (torch.Tensor): (batch, head, time, head_dim) The value vectors for global attention.
attn_probs (torch.Tensor): (batch, time, head, 2w) The attention probabilities.
max_num_global_attn_indices (int): Maximum number of global attention indices in the batch.
is_index_global_attn_nonzero (tuple): Indices of global attention (non-zero elements).
is_local_index_global_attn_nonzero (tuple): Non-padding values within global attention indices.

w (int): Local context size
Returns:
torch.Tensor: (batch, time, head x head_dim) The attention output of all tokens attending to global.
"""
@@ -573,12 +581,22 @@ def _compute_out_all_to_global(
value_vectors_only_global = value.new_zeros(batch_size, max_num_global_attn_indices, self.h, self.d_k)
value_vectors_only_global[is_local_index_global_attn_nonzero] = value[is_index_global_attn_nonzero]

# cut local attn probs to global only
attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
# compute attn output only global
out_all_to_global = torch.matmul(attn_probs, value_vectors_only_global.transpose(1, 2)).transpose(1, 2)
attn_output_only_global = torch.matmul(
attn_probs_only_global.clone(), value_vectors_only_global.transpose(1, 2).clone()
).transpose(1, 2)

# reshape attn probs
attn_probs_without_global = attn_probs.narrow(
-1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
).contiguous()

out_all_to_global = out_all_to_global.reshape(batch_size, time, -1)
# compute attn output with global
attn_output_without_global = self.sliding_chunks_matmul_pv(attn_probs_without_global, value.transpose(1, 2), w)

return out_all_to_global
return attn_output_only_global + attn_output_without_global
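
To make the combine step in _compute_attn_output_with_global_indices easier to follow, here is a hedged single-head sketch with hypothetical sizes: the banded dense matmul stands in for the memory-efficient sliding_chunks_matmul_pv used in the real code, and the global tokens are assumed to sit at the leading positions.

```python
import torch

batch, time, d_k, num_global, w = 2, 8, 4, 1, 2

# Joint probabilities over [global | all positions]; the local part is kept
# dense here purely for readability.
probs = torch.softmax(torch.randn(batch, time, num_global + time), dim=-1)
value = torch.randn(batch, time, d_k)
value_global = value[:, :num_global]  # values at the (assumed leading) global positions

# Global contribution: a dense matmul against the few global values.
probs_global = probs.narrow(-1, 0, num_global)
out_global = torch.matmul(probs_global, value_global)

# Local contribution: approximated by a full matmul masked to a +/-w band;
# the real code uses the sliding-window product instead.
probs_local = probs.narrow(-1, num_global, time)
band = (torch.arange(time)[:, None] - torch.arange(time)[None, :]).abs() <= w
out_local = torch.matmul(probs_local * band, value)

out = out_global + out_local  # (batch, time, d_k)
```

The method narrows the joint probabilities the same way: the first max_num_global_attn_indices columns are multiplied with the gathered global values, the remaining columns go through the sliding-window product, and the two contributions are summed.
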

def _compute_out_global_to_all(
self,