From d48393c1c55b0f88c09ca03c29e95e32a0bbb3d1 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 10 Nov 2022 13:56:47 -0800 Subject: [PATCH] restored flatten_parameters Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 091623327ace..0613a5343dc0 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - # if not torch.jit.is_scripting(): - # self.bilstm.flatten_parameters() + if not torch.jit.is_scripting(): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - # if not torch.jit.is_scripting(): - # self.bilstm.flatten_parameters() + if not torch.jit.is_scripting(): + self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)