Skip to content

Commit

Permalink
restored flatten_parameters
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Nov 10, 2022
1 parent da91270 commit d48393c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nemo/collections/tts/modules/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals
seq = nn.utils.rnn.pack_padded_sequence(
context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
)
# if not torch.jit.is_scripting():
# self.bilstm.flatten_parameters()
if not torch.jit.is_scripting():
self.bilstm.flatten_parameters()
ret, _ = self.bilstm(seq)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

@torch.jit.export
def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
    """Run the bidirectional LSTM over an already-packed batch.

    Args:
        seq: batch packed via ``nn.utils.rnn.pack_padded_sequence``.

    Returns:
        A ``(padded_output, lengths)`` tuple as produced by
        ``nn.utils.rnn.pad_packed_sequence`` with ``batch_first=True``.
    """
    # flatten_parameters() compacts LSTM weights for cuDNN; it is not
    # supported under TorchScript, hence the scripting guard.
    if not torch.jit.is_scripting():
        self.bilstm.flatten_parameters()
    output, _ = self.bilstm(seq)
    return nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

Expand Down

0 comments on commit d48393c

Please sign in to comment.