diff --git a/nemo/collections/asr/modules/wav2vec_modules.py b/nemo/collections/asr/modules/wav2vec_modules.py index e82d5a665a92..d1f5b090d4e1 100644 --- a/nemo/collections/asr/modules/wav2vec_modules.py +++ b/nemo/collections/asr/modules/wav2vec_modules.py @@ -32,8 +32,6 @@ from nemo.core.classes.module import NeuralModule from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class TransposeLast(torch.nn.Module): """ @@ -341,7 +339,7 @@ def apply_transformer(self, x, padding_mask=None): def create_padding_mask(self, length): # Broadcast to vectorize creating the padding mask max_len = max(length) - padding_mask = torch.arange(max_len, device=DEVICE) + padding_mask = torch.arange(max_len, device=length.device) # Switch to binary for transformer, 1 for valid tokens, 0 for padding padding_mask = (padding_mask.expand(len(length), max_len) < length.unsqueeze(1)).type(torch.uint8)