Skip to content

Commit

Permalink
Fix errors and typos.
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
  • Loading branch information
galv committed May 29, 2024
1 parent bf017ac commit 0f5385d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def _compute_out_global_to_all(
global_attn_scores = global_attn_scores.view(batch_size * self.h, max_num_global_attn_indices, seq_len)

# compute global attn probs
global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1, dtype=torch.float32)
global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1)

global_attn_probs = self.dropout(global_attn_probs_float)

Expand Down
2 changes: 0 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import re
from dataclasses import dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import List, Optional, Tuple, Union

import soundfile as sf
import torch
from omegaconf import DictConfig
from tqdm.auto import tqdm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_relmha_adapter_init(self, n_head, proj_dim):
relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)

pad_mask, att_mask = get_mask(lengths)
relpos_enc.extend_pe(lengths.max(), device='cpu')
relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)

with torch.no_grad():
assert adapter.linear_out.weight.sum() == 0
Expand All @@ -171,7 +171,7 @@ def test_abspos_encoding_init(self):

relpos_enc = adapter_modules.PositionalEncodingAdapter(d_model=50)

relpos_enc.extend_pe(lengths.max(), device='cpu')
relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)

with torch.no_grad():
out, pos_emb = relpos_enc(x)
Expand All @@ -187,7 +187,7 @@ def test_relpos_encoding_init(self):

relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)

relpos_enc.extend_pe(lengths.max(), device='cpu')
relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)

with torch.no_grad():
out, pos_emb = relpos_enc(x)
Expand Down

0 comments on commit 0f5385d

Please sign in to comment.