diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index e35974fe3e06f..bbcae67524ef9 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -657,7 +657,7 @@ def _compute_out_global_to_all(
         global_attn_scores = global_attn_scores.view(batch_size * self.h, max_num_global_attn_indices, seq_len)
 
         # compute global attn probs
-        global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1, dtype=torch.float32)
+        global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1)
 
         global_attn_probs = self.dropout(global_attn_probs_float)
 
diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py
index da89077cb7313..c270e5c3a0f7b 100644
--- a/nemo/collections/asr/parts/utils/transcribe_utils.py
+++ b/nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -17,10 +17,8 @@
 import re
 from dataclasses import dataclass
 from pathlib import Path
-from tempfile import NamedTemporaryFile
 from typing import List, Optional, Tuple, Union
 
-import soundfile as sf
 import torch
 from omegaconf import DictConfig
 from tqdm.auto import tqdm
diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
index 2637e33ebd2aa..c4ee4b97a2a6a 100644
--- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
+++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
@@ -150,7 +150,7 @@ def test_relmha_adapter_init(self, n_head, proj_dim):
         relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)
 
         pad_mask, att_mask = get_mask(lengths)
-        relpos_enc.extend_pe(lengths.max(), device='cpu')
+        relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)
 
         with torch.no_grad():
             assert adapter.linear_out.weight.sum() == 0
@@ -171,7 +171,7 @@ def test_abspos_encoding_init(self):
 
         relpos_enc = adapter_modules.PositionalEncodingAdapter(d_model=50)
 
-        relpos_enc.extend_pe(lengths.max(), device='cpu')
+        relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)
 
         with torch.no_grad():
             out, pos_emb = relpos_enc(x)
@@ -187,7 +187,7 @@ def test_relpos_encoding_init(self):
 
         relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)
 
-        relpos_enc.extend_pe(lengths.max(), device='cpu')
+        relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)
 
         with torch.no_grad():
             out, pos_emb = relpos_enc(x)
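
Note on the call-signature change exercised by the test updates above: `extend_pe` on the positional-encoding adapters now takes an explicit `dtype` alongside `device`, presumably so the cached positional-encoding buffer can be built in the compute/autocast dtype instead of always defaulting to float32 (consistent with dropping the forced `dtype=torch.float32` from the softmax in the first hunk). A minimal sketch of the updated usage follows; the import path for `adapter_modules` is an assumption based on the test file and is not shown in this diff.

    import torch

    # Assumed import path; the diff only references the alias `adapter_modules`.
    from nemo.collections.asr.parts.submodules.adapters import (
        multi_head_attention_adapter_module as adapter_modules,
    )

    relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)

    # extend_pe now accepts dtype alongside device, mirroring the test updates;
    # a 0-dim tensor stands in for the `lengths.max()` value used in the tests.
    relpos_enc.extend_pe(torch.tensor(64), device='cpu', dtype=torch.float32)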