Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Append pretrained FastPitch & SpectrogamEnhancer pair to available models #7012

Merged
merged 2 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,20 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
)
list_of_models.append(model)

# en, multi speaker, LibriTTS, 16000 Hz
# stft 25ms 10ms matching ASR params
# for use during Enhlish ASR training/adaptation
model = PretrainedModelInfo(
pretrained_model_name="tts_en_fastpitch_for_asr_finetuning",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_fastpitch_for_asr_finetuning.nemo",
description="This model is trained on LibriSpeech, train-960 subset."
" STFT parameters follow those commonly used in ASR: 25 ms window, 10 ms hop."
" This model is supposed to be used with its companion SpetrogramEnhancer for "
" ASR fine-tuning. Usage for regular TTS tasks is not advised.",
class_=cls,
)
list_of_models.append(model)

return list_of_models

# Methods for model exportability
Expand Down
20 changes: 18 additions & 2 deletions nemo/collections/tts/models/spectrogram_enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
HingeLoss,
)
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor, to_device_recursive
from nemo.core import Exportable, ModelPT, typecheck
from nemo.core import Exportable, ModelPT, PretrainedModelInfo, typecheck
from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType
from nemo.core.neural_types.elements import BoolType
from nemo.utils import logging
Expand Down Expand Up @@ -277,7 +277,23 @@ def setup_validation_data(self, val_data_config):

@classmethod
def list_available_models(cls):
return []
list_of_models = []

# en, multi speaker, LibriTTS, 16000 Hz
# stft 25ms 10ms matching ASR params
# for use during Enhlish ASR training/adaptation
model = PretrainedModelInfo(
pretrained_model_name="tts_en_spectrogram_enhancer_for_asr_finetuning",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_spectrogram_enhancer_for_asr_finetuning.nemo",
description="This model is trained to add details to synthetic spectrograms."
" It was trained on pairs of real-synthesized spectrograms generated by FastPitch."
" STFT parameters follow ASR with 25 ms window and 10 ms hop."
" It is supposed to be used in conjunction with that model for ASR training/adaptation.",
class_=cls,
)
list_of_models.append(model)

return list_of_models

def log_illustration(self, target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths):
if self.global_rank != 0:
Expand Down