diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py
index 63c28f12a4a7..0765d0499bda 100644
--- a/nemo/collections/tts/modules/common.py
+++ b/nemo/collections/tts/modules/common.py
@@ -122,22 +122,30 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals
         seq = nn.utils.rnn.pack_padded_sequence(
             context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
         )
-        return self.lstm_sequence(seq)
+        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
+            self.bilstm.flatten_parameters()
+        if hasattr(self.bilstm, 'forward'):
+            ret, _ = self.bilstm.forward(seq)
+        else:
+            ret, _ = self.bilstm.forward_1(seq)
+        return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)
 
     def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
         if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
             self.bilstm.flatten_parameters()
-        ret, _ = self.bilstm(seq)
+        if hasattr(self.bilstm, 'forward'):
+            ret, _ = self.bilstm.forward(seq)
+        elif hasattr(self.bilstm, 'forward_1'):
+            ret, _ = self.bilstm.forward_1(seq)
         return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)
 
-    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
+    @torch.jit.export
+    def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
         context, lens_sorted, unsort_ids = sort_tensor(context, lens)
-        dtype = context.dtype
-        # this is only needed for Torchscript to run in Triton
-        # (https://github.com/pytorch/pytorch/issues/89241)
-        with torch.cuda.amp.autocast(enabled=False):
-            ret = self.lstm_tensor(context.to(dtype=torch.float32), lens_sorted, enforce_sorted=True)
-        return ret[0].to(dtype=dtype)[unsort_ids]
+        seq = nn.utils.rnn.pack_padded_sequence(
+            context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True
+        )
+        return self.lstm_sequence(seq)[0][unsort_ids]
 
 
 class ConvLSTMLinear(nn.Module):
@@ -152,8 +160,7 @@ def __init__(
         use_partial_padding=False,
         norm_fn=None,
     ):
-        super(ConvLSTMLinear, self).__init__()
-        self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1)
+        super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
         self.convolutions = nn.ModuleList()
 
         if n_layers > 0:
@@ -184,14 +191,24 @@ def __init__(
         if out_dim is not None:
             self.dense = nn.Linear(n_channels, out_dim)
 
-    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
+    def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
         mask = get_mask_from_lengths_and_val(lens, context)
         mask = mask.to(dtype=context.dtype).unsqueeze(1)
         for conv in self.convolutions:
             context = self.dropout(F.relu(conv(context, mask)))
+
         context = context.transpose(1, 2)
-        # Apply Bidirectional LSTM
-        context = self.bilstm(context, lens)
+        seq = torch.nn.utils.rnn.pack_padded_sequence(
+            context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
+        )
+        return seq
+
+    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
+        context, lens, unsort_ids = sort_tensor(context, lens)
+        seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True)
+        context, _ = self.lstm_sequence(seq)
+        context = context[unsort_ids]
+
         if self.dense is not None:
             context = self.dense(context).permute(0, 2, 1)
 
         return context
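Note: `pack_padded_sequence(..., enforce_sorted=True)` is what keeps this path TorchScript/ONNX friendly, but it requires the batch ordered by decreasing length, hence the sort/unsort bracketing in the new `sort_and_lstm_tensor` and `ConvLSTMLinear.forward`. The `sort_tensor` helper lives elsewhere in `common.py` and is not shown in this patch; the following is only a minimal sketch of the behavior the new code assumes, not the NeMo implementation:

```python
# Hedged sketch of the assumed sort_tensor() contract: sort a batch by
# decreasing length (as enforce_sorted=True requires) and return the
# permutation that restores the original batch order afterwards.
from typing import Tuple

import torch
from torch import Tensor


def sort_tensor(context: Tensor, lens: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    lens_sorted, ids_sorted = torch.sort(lens, descending=True)
    unsort_ids = torch.argsort(ids_sorted)  # inverse permutation
    return context[ids_sorted], lens_sorted, unsort_ids
```

With this contract, `sorted_output[unsort_ids]` recovers the caller's original batch order, which is exactly how the diff indexes the LSTM output.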
diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py
index 9f360a4e5a33..dca0f0ede62c 100644
--- a/nemo/collections/tts/modules/radtts.py
+++ b/nemo/collections/tts/modules/radtts.py
@@ -345,7 +345,9 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg):
                 context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
 
         unfolded_out_lens = out_lens // self.n_group_size
-        context_lstm_padded_output = self.context_lstm(context_w_spkvec.transpose(1, 2), unfolded_out_lens)
+        context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor(
+            context_w_spkvec.transpose(1, 2), unfolded_out_lens
+        )
         context_w_spkvec = context_lstm_padded_output.transpose(1, 2)
 
         if not self.context_lstm_w_f0_and_energy:
@@ -770,8 +772,8 @@ def input_example(self, max_batch=1, max_dim=256):
         """
         par = next(self.parameters())
         sz = (max_batch, max_dim)
-        inp = torch.randint(16, 32, sz, device=par.device, dtype=torch.int64)
-        lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), device=par.device, dtype=torch.int)
+        inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
+        lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
         speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
         inputs = {
             'text': inp,
diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py
index 50266dab3dbe..b3f0b2fdd642 100644
--- a/nemo/core/classes/exportable.py
+++ b/nemo/core/classes/exportable.py
@@ -128,7 +128,7 @@ def _export(
             # Set module mode
             with torch.onnx.select_model_mode_for_export(
                 self, training
-            ), torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting():
+            ), torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True):
 
                 if input_example is None:
                     input_example = self.input_module.input_example()
diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py
index 9eb064936ea5..f973a4719e24 100644
--- a/nemo/utils/cast_utils.py
+++ b/nemo/utils/cast_utils.py
@@ -70,6 +70,6 @@ def __init__(self, mod):
         self.mod = mod
 
     def forward(self, x):
-        with avoid_float16_autocast_context():
+        with torch.cuda.amp.autocast(enabled=False):
             ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
             return ret
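The `cast_utils.py` hunk swaps the helper context manager for a plain `torch.cuda.amp.autocast(enabled=False)` around a forced-float32 forward. The wrapper class name is not visible in the hunk, so the self-contained sketch below uses a hypothetical `FloatCastWrapper` to show the pattern being edited; it needs a CUDA device to run:

```python
# Hedged sketch of the float32-forcing wrapper pattern from cast_utils.py.
# "FloatCastWrapper" is a hypothetical name; only forward()/__init__ are shown
# in the hunk.
import torch
import torch.nn as nn


class FloatCastWrapper(nn.Module):
    def __init__(self, mod: nn.Module):
        super().__init__()
        self.mod = mod

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Disable autocast so the wrapped module really computes in float32,
        # then cast the result back to the caller's dtype.
        with torch.cuda.amp.autocast(enabled=False):
            return self.mod(x.to(torch.float32)).to(x.dtype)


wrapped = FloatCastWrapper(nn.LayerNorm(256)).cuda()
x = torch.randn(8, 256, device='cuda', dtype=torch.float16)
with torch.cuda.amp.autocast():
    y = wrapped(x)  # LayerNorm runs in float32 regardless of AMP
assert y.dtype == torch.float16
```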
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index c7a45649daa2..0fbe2999bffe 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -15,7 +15,7 @@
 import os
 from contextlib import nullcontext
 from enum import Enum
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Dict, List, Optional, Type
 
 import onnx
 import torch
@@ -158,12 +158,8 @@ def verify_torchscript(model, output, input_examples, input_names, check_toleran
     for input_example in input_examples:
         input_list, input_dict = parse_input_example(input_example)
         output_example = model.forward(*input_list, **input_dict)
-        # We disable autocast here to make sure exported TS will run under Triton or other C++ env
-        with torch.cuda.amp.autocast(enabled=False):
-            ts_model = torch.jit.load(output)
-            all_good = all_good and run_ts_and_compare(
-                ts_model, input_list, input_dict, output_example, check_tolerance
-            )
+
+        all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance)
     status = "SUCCESS" if all_good else "FAIL"
     logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status)
     return all_good
@@ -205,15 +201,8 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
 
         if torch.is_tensor(expected):
             tout = out.to('cpu')
-            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n")
-            this_good = True
-            try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
-                    this_good = False
-            except Exception:  # there may ne size mismatch and it may be OK
-                this_good = False
-            if not this_good:
-                logging.info(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
+            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
+            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
                 all_good = False
 
     return all_good
@@ -227,14 +216,9 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
         if torch.is_tensor(expected):
             tout = torch.from_numpy(out)
-            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n")
-            this_good = True
-            try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
-                    this_good = False
-            except Exception:  # there may ne size mismatch and it may be OK
-                this_good = False
-            if not this_good:
+            logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
+            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                all_good = False
                 logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
                 all_good = False
 
     return all_good
@@ -433,7 +417,8 @@ def replace_modules(
 
 
 def script_module(m: nn.Module):
-    return torch.jit.script(m)
+    m1 = torch.jit.script(m)
+    return m1
 
 
 default_replacements = {
@@ -443,6 +428,11 @@ def script_module(m: nn.Module):
     "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
 }
 
+script_replacements = {
+    "BiLSTM": script_module,
+    "ConvLSTMLinear": script_module,
+}
+
 
 def replace_for_export(model: nn.Module) -> nn.Module:
     """
diff --git a/scripts/export.py b/scripts/export.py
index b3d6317e936c..2e100e446e72 100644
--- a/scripts/export.py
+++ b/scripts/export.py
@@ -143,10 +143,11 @@ def nemo_export(argv):
         if check_trace and len(in_args) > 0:
             input_example = model.input_module.input_example(**in_args)
             check_trace = [input_example]
-            for key, arg in in_args:
+            for key, arg in in_args.items():
                 in_args[key] = (arg + 1) // 2
             input_example2 = model.input_module.input_example(**in_args)
             check_trace.append(input_example2)
+            logging.info(f"Using additional check args: {in_args}")
 
         _, descriptions = model.export(
             out,
diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py
index d7684de732e5..e3e496373271 100644
--- a/tests/collections/tts/test_tts_exportables.py
+++ b/tests/collections/tts/test_tts_exportables.py
@@ -15,7 +15,6 @@
 import tempfile
 
 import pytest
-import torch
 from omegaconf import OmegaConf
 
 from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
@@ -74,12 +73,10 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model):
             filename = os.path.join(tmpdir, 'hfg.pt')
             model.export(output=filename, verbose=True, check_trace=True)
 
-    @pytest.mark.pleasefixme
     @pytest.mark.run_only_on('GPU')
     @pytest.mark.unit
     def test_RadTTSModel_export_to_torchscript(self, radtts_model):
         model = radtts_model.cuda()
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'rad.ts')
-            with torch.cuda.amp.autocast(enabled=True):
-                model.export(output=filename, verbose=True, check_trace=True)
+            model.export(output=filename, verbose=True, check_trace=True)
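With the `pleasefixme` marker and the autocast wrapper removed, the re-enabled test reduces to a plain full-precision export. A standalone sketch of the same flow, assuming a local RadTTS checkpoint (the path below is hypothetical, not from this patch):

```python
# Hedged sketch: exporting a RadTTS model to TorchScript the way the updated
# test does. "radtts.nemo" is a placeholder for a real checkpoint path.
import os
import tempfile

from nemo.collections.tts.models import RadTTSModel

model = RadTTSModel.restore_from("radtts.nemo").cuda()  # hypothetical path
with tempfile.TemporaryDirectory() as tmpdir:
    filename = os.path.join(tmpdir, "rad.ts")
    # No autocast context around export any more: the model is exported and
    # trace-checked in full precision.
    model.export(output=filename, verbose=True, check_trace=True)
```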