Skip to content

Commit

Permalink
update fastpitch to add export controls (NVIDIA#4509)
Browse files Browse the repository at this point in the history
* update fastpitch to add export controls

Signed-off-by: Jason <jasoli@nvidia.com>

* final touchups

Signed-off-by: Jason <jasoli@nvidia.com>

* more final touchups

Signed-off-by: Jason <jasoli@nvidia.com>
Signed-off-by: David Mosallanezhad <dmosallanezh@nvidia.com>
  • Loading branch information
blisc authored and Davood-M committed Aug 9, 2022
1 parent d3ffc92 commit 44bdb5f
Showing 1 changed file with 85 additions and 16 deletions.
101 changes: 85 additions & 16 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
cfg.n_mel_channels,
)
self._input_types = self._output_types = None
self.export_config = {"enable_volume": False, "enable_ragged_batches": False}

def _get_default_text_tokenizer_conf(self):
text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
Expand Down Expand Up @@ -515,22 +516,26 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
# Export hook: runs the parent class hook, then fills in self._input_types and
# self._output_types, which export() requires.
# NOTE(review): this listing looks like a rendered diff with removed and added lines
# interleaved (keys such as "text", "pitch" and "spect" appear twice in the literals
# below). If evaluated as written, the last occurrence of a duplicated key wins in a
# Python dict literal -- confirm against the real file before relying on either version.
def _prepare_for_export(self, **kwargs):
super()._prepare_for_export(**kwargs)

# Input axes depend on the export config: a single 'T' axis when ragged batches are
# enabled, otherwise ('B', 'T').
# NOTE(review): ('T') is a parenthesized str, not a 1-tuple (same for ('B') below);
# presumably NeuralType accepts a bare single-character string as one axis -- TODO confirm.
tensor_shape = ('T') if self.export_config["enable_ragged_batches"] else ('B', 'T')

# Define input_types and output_types as required by export()
self._input_types = {
"text": NeuralType(('B', 'T_text'), TokenIndex()),
"pitch": NeuralType(('B', 'T_text'), RegressionValuesType()),
"pace": NeuralType(('B', 'T_text'), optional=True),
"volume": NeuralType(('B', 'T_text')),
"speaker": NeuralType(('B'), Index()),
"text": NeuralType(tensor_shape, TokenIndex()),
"pitch": NeuralType(tensor_shape, RegressionValuesType()),
"pace": NeuralType(tensor_shape),
"speaker": NeuralType(('B'), Index(), optional=True),
"volume": NeuralType(tensor_shape, optional=True),
"batch_lengths": NeuralType(('B'), optional=True),
}
self._output_types = {
"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
"spect": NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
"num_frames": NeuralType(('B'), TokenDurationType()),
"durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
"log_durs_predicted": NeuralType(('B', 'T_text'), TokenLogDurationType()),
"pitch_predicted": NeuralType(('B', 'T_text'), RegressionValuesType()),
"volume_aligned": NeuralType(('B', 'T_spec'), RegressionValuesType()),
"durs_predicted": NeuralType(('B', 'T'), TokenDurationType()),
"log_durs_predicted": NeuralType(('B', 'T'), TokenLogDurationType()),
"pitch_predicted": NeuralType(('B', 'T'), RegressionValuesType()),
}
# The "volume_aligned" output is assigned here only when volume export is enabled.
if self.export_config["enable_volume"]:
self._output_types["volume_aligned"] = NeuralType(('B', 'T'), RegressionValuesType())

def _export_teardown(self):
self._input_types = self._output_types = None
Expand All @@ -541,6 +546,10 @@ def disabled_deployment_input_names(self):
disabled_inputs = set()
if self.fastpitch.speaker_emb is None:
disabled_inputs.add("speaker")
if not self.export_config["enable_ragged_batches"]:
disabled_inputs.add("batch_lengths")
if not self.export_config["enable_volume"]:
disabled_inputs.add("volume")
return disabled_inputs

@property
Expand All @@ -558,15 +567,35 @@ def input_example(self, max_batch=1, max_dim=44):
A tuple of input examples.
"""
par = next(self.fastpitch.parameters())
sz = (max_batch, max_dim)
sz = (max_batch * max_dim) if self.export_config["enable_ragged_batches"] else (max_batch, max_dim)
inp = torch.randint(
0, self.fastpitch.encoder.word_emb.num_embeddings, sz, device=par.device, dtype=torch.int64
)
pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
pace = torch.clamp((torch.randn(sz, device=par.device, dtype=torch.float32) + 1) * 0.1, min=0.01)
volume = torch.clamp((torch.randn(sz, device=par.device, dtype=torch.float32) + 1) * 0.1, min=0.01)

inputs = {'text': inp, 'pitch': pitch, 'pace': pace, 'volume': volume}
pace = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)

inputs = {'text': inp, 'pitch': pitch, 'pace': pace}

if self.export_config["enable_volume"]:
volume = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)
inputs['volume'] = volume
if self.export_config["enable_ragged_batches"]:
batch_lengths = torch.zeros((max_batch + 1), device=par.device, dtype=torch.int32)
left_over_size = sz
batch_lengths[0] = 0
for i in range(1, max_batch):
length = torch.randint(1, left_over_size - (max_batch - i), (1,), device=par.device)
batch_lengths[i] = length + batch_lengths[i - 1]
left_over_size -= length.detach().cpu().numpy()[0]
batch_lengths[-1] = left_over_size + batch_lengths[-2]

sum = 0
index = 1
while index < len(batch_lengths):
sum += batch_lengths[index] - batch_lengths[index - 1]
index += 1
assert sum == sz, f"sum: {sum}, sz: {sz}, lengths:{batch_lengths}"
inputs['batch_lengths'] = batch_lengths

if self.fastpitch.speaker_emb is not None:
inputs['speaker'] = torch.randint(
Expand All @@ -575,5 +604,45 @@ def input_example(self, max_batch=1, max_dim=44):

return (inputs,)

def forward_for_export(self, text, pitch, pace, volume, speaker=None):
def forward_for_export(self, text, pitch, pace, volume=None, batch_lengths=None, speaker=None):
    """Run FastPitch inference with the inputs used by the exported model.

    When ``export_config["enable_ragged_batches"]`` is set, ``text``, ``pitch``,
    ``pace`` (and ``volume``, if given) are flat concatenations of all sequences and
    ``batch_lengths`` holds their cumulative boundaries; they are re-packed into padded
    2-D batches with ``create_batch`` before inference. Otherwise the inputs are passed
    through unchanged.

    Args:
        text: Token indices.
        pitch: Per-token pitch values.
        pace: Per-token pace values.
        volume: Optional per-token volume values.
        batch_lengths: Cumulative sequence boundaries; required in ragged-batch mode.
        speaker: Optional speaker indices.

    Returns:
        Whatever ``self.fastpitch.infer`` returns.

    Raises:
        ValueError: If ragged batches are enabled but ``batch_lengths`` is None.
    """
    if self.export_config["enable_ragged_batches"]:
        if batch_lengths is None:
            raise ValueError("batch_lengths must be provided when enable_ragged_batches is True")
        # create_batch's positional order is (text, pitch, pace, batch_lengths); padding_idx
        # and volume come after it, so both are passed by keyword to avoid mis-binding.
        text, pitch, pace, volume_tensor = create_batch(
            text, pitch, pace, batch_lengths, padding_idx=self.fastpitch.encoder.padding_idx, volume=volume
        )
        # create_batch always returns a volume tensor (all ones when no volume was given);
        # only forward it when the caller actually supplied volume.
        if volume is not None:
            volume = volume_tensor
    return self.fastpitch.infer(text=text, pitch=pitch, pace=pace, volume=volume, speaker=speaker)


@torch.jit.script
def create_batch(
    text: torch.Tensor,
    pitch: torch.Tensor,
    pace: torch.Tensor,
    batch_lengths: torch.Tensor,
    padding_idx: int = -1,
    volume: Optional[torch.Tensor] = None,
):
    """Re-pack flat (ragged) inputs into padded 2-D batches.

    ``batch_lengths`` is read as cumulative boundaries: row ``i`` of every output is
    filled from ``[batch_lengths[i], batch_lengths[i + 1])`` of the matching flat input.
    Each output has ``batch_lengths.shape[0] - 1`` rows and as many columns as the
    longest sequence.

    Padding values: ``padding_idx`` for text, 0 for pitch, 1.0 for pace and volume.
    Text is returned as int64 and the others as float32, all on ``text``'s device.
    When ``volume`` is None the returned volume tensor is all ones.

    Returns:
        Tuple ``(texts, pitches, paces, volumes)``.
    """
    offsets = batch_lengths.to(torch.int64)
    num_seqs = offsets.shape[0] - 1
    # Width of the padded batch = length of the longest sequence.
    widest = torch.max(offsets[1:] - offsets[:-1])
    dev = text.device

    texts = torch.zeros(num_seqs, widest, dtype=torch.int64, device=dev) + padding_idx
    pitches = torch.zeros(num_seqs, widest, dtype=torch.float32, device=dev)
    paces = torch.zeros(num_seqs, widest, dtype=torch.float32, device=dev) + 1.0
    volumes = torch.zeros(num_seqs, widest, dtype=torch.float32, device=dev) + 1.0

    for row in range(num_seqs):
        start = offsets[row]
        stop = offsets[row + 1]
        span = stop - start

        # Copy this sequence into the left part of its row; the rest keeps the padding.
        texts[row, :span] = text[start:stop]
        pitches[row, :span] = pitch[start:stop]
        paces[row, :span] = pace[start:stop]
        if volume is not None:
            volumes[row, :span] = volume[start:stop]

    return texts, pitches, paces, volumes

0 comments on commit 44bdb5f

Please sign in to comment.