Radtts 1.13 (#5451)
* [TTS] Fixing RADTTS training - removing view buffer and fixing accuracy issue (#5358)
* [TTS] add CI test for RADTTS training recipe.

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
3 people authored and ericharper committed Dec 7, 2022
1 parent 0c9a919 commit 3a9616f
Showing 7 changed files with 56 additions and 49 deletions.
45 changes: 31 additions & 14 deletions nemo/collections/tts/modules/common.py
@@ -122,22 +122,30 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False):
        seq = nn.utils.rnn.pack_padded_sequence(
            context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
        )
-        return self.lstm_sequence(seq)
+        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
+            self.bilstm.flatten_parameters()
+        if hasattr(self.bilstm, 'forward'):
+            ret, _ = self.bilstm.forward(seq)
+        else:
+            ret, _ = self.bilstm.forward_1(seq)
+        return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

    def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
            self.bilstm.flatten_parameters()
-        ret, _ = self.bilstm(seq)
+        if hasattr(self.bilstm, 'forward'):
+            ret, _ = self.bilstm.forward(seq)
+        elif hasattr(self.bilstm, 'forward_1'):
+            ret, _ = self.bilstm.forward_1(seq)
        return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)

-    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
+    @torch.jit.export
+    def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
        context, lens_sorted, unsort_ids = sort_tensor(context, lens)
-        dtype = context.dtype
-        # this is only needed for Torchscript to run in Triton
-        # (https://github.com/pytorch/pytorch/issues/89241)
-        with torch.cuda.amp.autocast(enabled=False):
-            ret = self.lstm_tensor(context.to(dtype=torch.float32), lens_sorted, enforce_sorted=True)
-        return ret[0].to(dtype=dtype)[unsort_ids]
+        seq = nn.utils.rnn.pack_padded_sequence(
+            context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True
+        )
+        return self.lstm_sequence(seq)[0][unsort_ids]
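The sorting contract is the point of this refactor: pack_padded_sequence(..., enforce_sorted=True) requires lengths in decreasing order, which is what ONNX export requires, so the batch is sorted up front and un-sorted at the end. A minimal sketch of a sort_tensor-style helper, assuming NeMo's convention of returning the sorted tensor, the sorted lengths, and the inverse permutation:

import torch
from torch import Tensor
from typing import Tuple

def sort_tensor(context: Tensor, lens: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    # Sort the batch by decreasing length and keep the inverse permutation,
    # so results can be restored to the caller's original batch order.
    lens_sorted, ids_sorted = torch.sort(lens, descending=True)
    unsort_ids = torch.argsort(ids_sorted)
    return context[ids_sorted], lens_sorted, unsort_ids

Indexing the padded output with unsort_ids, as sort_and_lstm_tensor does above, undoes the sort before the result leaves the module.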


class ConvLSTMLinear(nn.Module):
@@ -152,8 +160,7 @@ def __init__(
        use_partial_padding=False,
        norm_fn=None,
    ):
-        super(ConvLSTMLinear, self).__init__()
-        self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1)
+        super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
        self.convolutions = nn.ModuleList()

        if n_layers > 0:
@@ -184,14 +191,24 @@ def __init__(
        if out_dim is not None:
            self.dense = nn.Linear(n_channels, out_dim)

-    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
+    def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
        mask = get_mask_from_lengths_and_val(lens, context)
        mask = mask.to(dtype=context.dtype).unsqueeze(1)
        for conv in self.convolutions:
            context = self.dropout(F.relu(conv(context, mask)))

        context = context.transpose(1, 2)
-        # Apply Bidirectional LSTM
-        context = self.bilstm(context, lens)
+        seq = torch.nn.utils.rnn.pack_padded_sequence(
+            context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
+        )
+        return seq
+
+    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
+        context, lens, unsort_ids = sort_tensor(context, lens)
+        seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True)
+        context, _ = self.lstm_sequence(seq)
+        context = context[unsort_ids]

        if self.dense is not None:
            context = self.dense(context).permute(0, 2, 1)
        return context
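The masked convolutions above depend on a mask built from the batch lengths. NeMo's get_mask_from_lengths_and_val derives it from lens and the tensor itself; a standalone sketch of the idea (name and signature here are illustrative, not NeMo's exact API):

import torch

def get_mask_from_lengths(lens: torch.Tensor, max_len: int) -> torch.Tensor:
    # (B, T) boolean mask: True where frame t lies inside that element's sequence
    ids = torch.arange(max_len, device=lens.device)
    return ids.unsqueeze(0) < lens.unsqueeze(1)

mask = get_mask_from_lengths(torch.tensor([3, 5]), 5)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

Zeroing padded positions before each convolution keeps results independent of batch padding, which is what makes the refactored forward safe to run on sorted, packed batches.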
8 changes: 5 additions & 3 deletions nemo/collections/tts/modules/radtts.py
@@ -345,7 +345,9 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg):
            context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)

        unfolded_out_lens = out_lens // self.n_group_size
-        context_lstm_padded_output = self.context_lstm(context_w_spkvec.transpose(1, 2), unfolded_out_lens)
+        context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor(
+            context_w_spkvec.transpose(1, 2), unfolded_out_lens
+        )
        context_w_spkvec = context_lstm_padded_output.transpose(1, 2)

        if not self.context_lstm_w_f0_and_energy:
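The transposes around the new call are a layout handoff: the convolutional front end produces (batch, channels, frames), while the batch-first LSTM consumes (batch, frames, channels). A toy illustration:

import torch

x = torch.randn(2, 512, 100)  # (B, C, T) as produced by the conv stack
x_t = x.transpose(1, 2)       # (B, T, C) as expected by the batch-first LSTM
assert x_t.transpose(1, 2).shape == x.shape  # transposing back restores (B, C, T)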
@@ -770,8 +772,8 @@ def input_example(self, max_batch=1, max_dim=256):
"""
par = next(self.parameters())
sz = (max_batch, max_dim)
inp = torch.randint(16, 32, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), device=par.device, dtype=torch.int)
inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
inputs = {
'text': inp,
2 changes: 1 addition & 1 deletion nemo/core/classes/exportable.py
@@ -128,7 +128,7 @@ def _export(
        # Set module mode
        with torch.onnx.select_model_mode_for_export(
            self, training
-        ), torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting():
+        ), torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True):

            if input_example is None:
                input_example = self.input_module.input_example()
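For reference, the context-manager stack that remains composes eval/train mode selection with inference-mode, no-grad execution; a self-contained sketch on a toy model (not NeMo's full _export):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL), \
        torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True):
    scripted = torch.jit.script(model)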
2 changes: 1 addition & 1 deletion nemo/utils/cast_utils.py
@@ -70,6 +70,6 @@ def __init__(self, mod):
        self.mod = mod

    def forward(self, x):
-        with avoid_float16_autocast_context():
+        with torch.cuda.amp.autocast(enabled=False):
            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
        return ret
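A usage sketch for this wrapper (assuming it is CastToFloat from nemo.utils.cast_utils; the toy module and CUDA guard are illustrative): inside an autocast region the wrapped module still computes in fp32, and the result is cast back to the caller's dtype.

import torch
import torch.nn as nn

from nemo.utils.cast_utils import CastToFloat  # class name assumed from context

softmax_fp32 = CastToFloat(nn.Softmax(dim=-1))
if torch.cuda.is_available():
    with torch.cuda.amp.autocast(enabled=True):
        x = torch.randn(2, 8, device='cuda', dtype=torch.float16)
        y = softmax_fp32(x)  # softmax computed in fp32, returned as float16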
40 changes: 15 additions & 25 deletions nemo/utils/export_utils.py
@@ -15,7 +15,7 @@
import os
from contextlib import nullcontext
from enum import Enum
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Dict, List, Optional, Type

import onnx
import torch
@@ -158,12 +158,8 @@ def verify_torchscript(model, output, input_examples, input_names, check_tolerance=0.01):
    for input_example in input_examples:
        input_list, input_dict = parse_input_example(input_example)
        output_example = model.forward(*input_list, **input_dict)
-        # We disable autocast here to make sure exported TS will run under Triton or other C++ env
-        with torch.cuda.amp.autocast(enabled=False):
-            ts_model = torch.jit.load(output)
-            all_good = all_good and run_ts_and_compare(
-                ts_model, input_list, input_dict, output_example, check_tolerance
-            )
+        all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance)
    status = "SUCCESS" if all_good else "FAIL"
    logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status)
    return all_good
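The round-trip being exercised here, reduced to a self-contained toy example (model and tolerances are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2).eval()
torch.jit.script(model).save("model.ts")  # export to TorchScript
ts_model = torch.jit.load("model.ts")     # reload, as verify_torchscript does

x = torch.randn(3, 4)
assert torch.allclose(ts_model(x), model(x), rtol=0.01, atol=0.01)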
@@ -205,15 +201,8 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01):

        if torch.is_tensor(expected):
            tout = out.to('cpu')
-            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n")
-            this_good = True
-            try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
-                    this_good = False
-            except Exception:  # there may be size mismatch and it may be OK
-                this_good = False
-            if not this_good:
-                logging.info(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
+            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
+            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
                all_good = False
    return all_good
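As a reminder of the comparison semantics: torch.allclose(actual, expected, rtol, atol) passes when |actual - expected| <= atol + rtol * |expected| holds elementwise, e.g.:

import torch

expected = torch.tensor([1.000, 2.000])
actual = torch.tensor([1.005, 1.996])
# |1.005 - 1.000| = 0.005 <= 0.01 + 0.01 * 1.000 = 0.02, likewise for the second element
print(torch.allclose(actual, expected, rtol=0.01, atol=0.01))  # True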

@@ -227,14 +216,9 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):

        if torch.is_tensor(expected):
            tout = torch.from_numpy(out)
-            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n")
-            this_good = True
-            try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
-                    this_good = False
-            except Exception:  # there may be size mismatch and it may be OK
-                this_good = False
-            if not this_good:
+            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
+            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                all_good = False
                logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
-                all_good = False
    return all_good
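For context, a minimal sketch of producing the out arrays this function compares (file name, input name, and shapes are assumptions, not NeMo's):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
ort_outs = sess.run(None, {"input": np.zeros((1, 4), dtype=np.float32)})
# each entry of ort_outs is a numpy array, converted with torch.from_numpy above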
@@ -433,7 +417,8 @@ def replace_modules(


def script_module(m: nn.Module):
-    return torch.jit.script(m)
+    m1 = torch.jit.script(m)
+    return m1


default_replacements = {
@@ -443,6 +428,11 @@ def script_module(m: nn.Module):
"MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
}

+script_replacements = {
+    "BiLSTM": script_module,
+    "ConvLSTMLinear": script_module,
+}
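The new script_replacements table maps class names to script_module, so BiLSTM and ConvLSTMLinear instances are compiled with TorchScript during export preprocessing. A simplified sketch of how such a name-to-transform table can be applied (NeMo's replace_modules is more general than this):

import torch.nn as nn

def apply_replacements(model: nn.Module, table) -> nn.Module:
    for name, child in model.named_children():
        fn = table.get(type(child).__name__)
        if fn is not None:
            setattr(model, name, fn(child))   # e.g. swap in torch.jit.script(child)
        else:
            apply_replacements(child, table)  # recurse into unreplaced children
    return model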


def replace_for_export(model: nn.Module) -> nn.Module:
"""
3 changes: 2 additions & 1 deletion scripts/export.py
@@ -143,10 +143,11 @@ def nemo_export(argv):
    if check_trace and len(in_args) > 0:
        input_example = model.input_module.input_example(**in_args)
        check_trace = [input_example]
-        for key, arg in in_args:
+        for key, arg in in_args.items():
            in_args[key] = (arg + 1) // 2
        input_example2 = model.input_module.input_example(**in_args)
        check_trace.append(input_example2)
+        logging.info(f"Using additional check args: {in_args}")

    _, descriptions = model.export(
        out,
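The one-line fix above matters because iterating a dict directly yields only its keys, so the old loop tried to unpack each key string instead of a key/value pair; .items() supplies the pairs. The intent of the loop in isolation (key names are illustrative):

in_args = {"max_batch": 4, "max_dim": 256}
for key, arg in in_args.items():
    in_args[key] = (arg + 1) // 2
# in_args == {"max_batch": 2, "max_dim": 128}: a second, smaller example input
# used to check the exported model's trace at a different shape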
5 changes: 1 addition & 4 deletions tests/collections/tts/test_tts_exportables.py
@@ -15,7 +15,6 @@
import tempfile

import pytest
-import torch
from omegaconf import OmegaConf

from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
@@ -74,12 +73,10 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model):
            filename = os.path.join(tmpdir, 'hfg.pt')
            model.export(output=filename, verbose=True, check_trace=True)

-    @pytest.mark.pleasefixme
    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    def test_RadTTSModel_export_to_torchscript(self, radtts_model):
        model = radtts_model.cuda()
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'rad.ts')
-            with torch.cuda.amp.autocast(enabled=True):
-                model.export(output=filename, verbose=True, check_trace=True)
+            model.export(output=filename, verbose=True, check_trace=True)
