diff --git a/examples/tts/conf/fastpitch_align_44100_adapter.yaml b/examples/tts/conf/fastpitch_align_44100_adapter.yaml index b2957b057d28..3c41cf3e55e5 100644 --- a/examples/tts/conf/fastpitch_align_44100_adapter.yaml +++ b/examples/tts/conf/fastpitch_align_44100_adapter.yaml @@ -32,6 +32,9 @@ phoneme_dict_path: "scripts/tts_dataset_files/cmudict-0.7b_nv22.10" heteronyms_path: "scripts/tts_dataset_files/heteronyms-052722" model: + unfreeze_aligner: false + unfreeze_duration_predictor: false + unfreeze_pitch_predictor: false learn_alignment: true bin_loss_warmup_epochs: 100 diff --git a/examples/tts/fastpitch_finetune_adapters.py b/examples/tts/fastpitch_finetune_adapters.py index 396552b0f4fd..1361d63fb4cf 100644 --- a/examples/tts/fastpitch_finetune_adapters.py +++ b/examples/tts/fastpitch_finetune_adapters.py @@ -107,6 +107,18 @@ def main(cfg): if adapter_global_cfg is not None: add_global_adapter_cfg(model, adapter_global_cfg) + if cfg.model.get("unfreeze_aligner", False): + for name, param in model.fastpitch.aligner.named_parameters(): + param.requires_grad = True + + if cfg.model.get("unfreeze_duration_predictor", False): + for name, param in model.fastpitch.duration_predictor.named_parameters(): + param.requires_grad = True + + if cfg.model.get("unfreeze_pitch_predictor", False): + for name, param in model.fastpitch.pitch_predictor.named_parameters(): + param.requires_grad = True + # Add adapters model.add_adapter(name=adapter_name, cfg=cfg.model.adapter) assert model.is_adapter_available() diff --git a/nemo/collections/tts/losses/aligner_loss.py b/nemo/collections/tts/losses/aligner_loss.py index 1a666d750521..792125a25edb 100644 --- a/nemo/collections/tts/losses/aligner_loss.py +++ b/nemo/collections/tts/losses/aligner_loss.py @@ -22,11 +22,12 @@ class ForwardSumLoss(Loss): - def __init__(self, blank_logprob=-1): + def __init__(self, blank_logprob=-1, loss_scale=1.0): super().__init__() self.log_softmax = torch.nn.LogSoftmax(dim=-1) self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) self.blank_logprob = blank_logprob + self.loss_scale = loss_scale @property def input_types(self): @@ -67,13 +68,15 @@ def forward(self, attn_logprob, in_lens, out_lens): # Evaluate CTC loss cost = self.ctc_loss(attn_logprob, target_seqs, input_lengths=query_lens, target_lengths=key_lens) + cost *= self.loss_scale return cost class BinLoss(Loss): - def __init__(self): + def __init__(self, loss_scale=1.0): super().__init__() + self.loss_scale = loss_scale @property def input_types(self): @@ -91,4 +94,6 @@ def output_types(self): @typecheck() def forward(self, hard_attention, soft_attention): log_sum = torch.log(torch.clamp(soft_attention[hard_attention == 1], min=1e-12)).sum() - return -log_sum / hard_attention.sum() + loss = -log_sum / hard_attention.sum() + loss *= self.loss_scale + return loss diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 3939c9453911..ee2b3dce6ea3 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -138,9 +138,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.aligner = None if self.learn_alignment: + aligner_loss_scale = cfg.aligner_loss_scale if "aligner_loss_scale" in cfg else 1.0 self.aligner = instantiate(self._cfg.alignment_module) - self.forward_sum_loss_fn = ForwardSumLoss() - self.bin_loss_fn = BinLoss() + self.forward_sum_loss_fn = ForwardSumLoss(loss_scale=aligner_loss_scale) + self.bin_loss_fn = BinLoss(loss_scale=aligner_loss_scale) self.preprocessor = instantiate(self._cfg.preprocessor) input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs) diff --git a/tutorials/tts/FastPitch_Adapter_Finetuning.ipynb b/tutorials/tts/FastPitch_Adapter_Finetuning.ipynb index fa1b1bdc90c8..13fcccec9352 100644 --- a/tutorials/tts/FastPitch_Adapter_Finetuning.ipynb +++ b/tutorials/tts/FastPitch_Adapter_Finetuning.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "ea49c0e5", + "id": "d3f30c7d", "metadata": {}, "source": [ "# FastPitch Adapter Finetuning\n", @@ -23,7 +23,7 @@ }, { "cell_type": "markdown", - "id": "37259555", + "id": "858f8989", "metadata": {}, "source": [ "# License\n", @@ -46,7 +46,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d61cbea5", + "id": "26f80dbb", "metadata": {}, "outputs": [], "source": [ @@ -73,7 +73,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fef9aba9", + "id": "1e9a80f7", "metadata": {}, "outputs": [], "source": [ @@ -83,7 +83,7 @@ { "cell_type": "code", "execution_count": null, - "id": "49bc38ab", + "id": "34248bfd", "metadata": {}, "outputs": [], "source": [ @@ -95,7 +95,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9459f9dc", + "id": "c7defa68", "metadata": {}, "outputs": [], "source": [ @@ -113,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eb26f54d", + "id": "ddbd82d7", "metadata": {}, "outputs": [], "source": [ @@ -131,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12b28329", + "id": "711e7af7", "metadata": {}, "outputs": [], "source": [ @@ -149,7 +149,7 @@ }, { "cell_type": "markdown", - "id": "30996769", + "id": "2348739b", "metadata": {}, "source": [ "# 1. Fine-tune FastPitch on adaptation data" @@ -157,7 +157,7 @@ }, { "cell_type": "markdown", - "id": "2f5f5945", + "id": "c082c1f7", "metadata": {}, "source": [ "## a. Data Preparation\n", @@ -167,7 +167,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8047f988", + "id": "f8814bf6", "metadata": {}, "outputs": [], "source": [ @@ -177,7 +177,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b8242769", + "id": "420e0fc9", "metadata": {}, "outputs": [], "source": [ @@ -188,7 +188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79cf8539", + "id": "223778c2", "metadata": {}, "outputs": [], "source": [ @@ -198,7 +198,7 @@ }, { "cell_type": "markdown", - "id": "35c3b97b", + "id": "ba84d822", "metadata": {}, "source": [ "## b. Preprocessing" @@ -206,7 +206,7 @@ }, { "cell_type": "markdown", - "id": "ba3a7c3a", + "id": "4e704114", "metadata": {}, "source": [ "### Add absolute file path in manifest\n", @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8bc485b5", + "id": "b3320461", "metadata": {}, "outputs": [], "source": [ @@ -226,7 +226,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f9cb8ef5", + "id": "44442b19", "metadata": {}, "outputs": [], "source": [ @@ -241,7 +241,7 @@ }, { "cell_type": "markdown", - "id": "f92054d5", + "id": "8f03570a", "metadata": {}, "source": [ "### Extract Supplementary Data\n", @@ -252,7 +252,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0adc618b", + "id": "c9c59467", "metadata": {}, "outputs": [], "source": [ @@ -267,7 +267,7 @@ }, { "cell_type": "markdown", - "id": "96dd5fe1", + "id": "23453289", "metadata": {}, "source": [ "After running the above command line, you will observe a new folder NeMoTTS_sup_data/pitch and printouts of pitch statistics like below. Specify these values to the FastPitch training configurations. We will be there in the following section.\n", @@ -280,7 +280,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23703c76", + "id": "91752fa7", "metadata": {}, "outputs": [], "source": [ @@ -295,7 +295,7 @@ }, { "cell_type": "markdown", - "id": "7c70e5db", + "id": "1dadd87d", "metadata": {}, "source": [ "## c. Model Setting\n", @@ -305,7 +305,7 @@ { "cell_type": "code", "execution_count": null, - "id": "439f2f82", + "id": "31a77a2b", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +318,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30f865cb", + "id": "3f679835", "metadata": {}, "outputs": [], "source": [ @@ -350,7 +350,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e92910b5", + "id": "c8fbf558", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "7f03219f", + "id": "80fb14c1", "metadata": {}, "source": [ "### Precompute Speaker Embedding\n", @@ -370,7 +370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2a35241", + "id": "f371b868", "metadata": {}, "outputs": [], "source": [ @@ -405,7 +405,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5fa1b309", + "id": "ebb9dc03", "metadata": {}, "outputs": [], "source": [ @@ -417,7 +417,7 @@ }, { "cell_type": "markdown", - "id": "3b77e95f", + "id": "30c8908f", "metadata": {}, "source": [ "## d. Training" @@ -426,7 +426,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9e8c3740", + "id": "b5f7ffba", "metadata": {}, "outputs": [], "source": [ @@ -440,7 +440,7 @@ }, { "cell_type": "markdown", - "id": "19bb6d8b", + "id": "02233f03", "metadata": {}, "source": [ "### Important notes\n", @@ -451,13 +451,16 @@ "* Other optional arguments based on your preference:\n", " * batch_size\n", " * exp_manager\n", - " * trainer" + " * trainer\n", + " * model.unfreeze_aligner=true\n", + " * model.unfreeze_duration_predictor=true\n", + " * model.unfreeze_pitch_predictor=true" ] }, { "cell_type": "code", "execution_count": null, - "id": "8c8cbea2", + "id": "270b061a", "metadata": {}, "outputs": [], "source": [ @@ -495,7 +498,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe5c7b2f", + "id": "690acf5b", "metadata": {}, "outputs": [], "source": [ @@ -510,7 +513,7 @@ }, { "cell_type": "markdown", - "id": "75856d0e", + "id": "ef0d84d0", "metadata": {}, "source": [ "# 3. Fine-tune HiFiGAN on adaptation data" @@ -518,7 +521,7 @@ }, { "cell_type": "markdown", - "id": "3444698f", + "id": "91076fba", "metadata": {}, "source": [ "## a. Dataset Preparation\n", @@ -528,7 +531,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bb2fd64d", + "id": "0475ec9e", "metadata": {}, "outputs": [], "source": [ @@ -554,7 +557,7 @@ { "cell_type": "code", "execution_count": null, - "id": "da69cb66", + "id": "26b8a35f", "metadata": {}, "outputs": [], "source": [ @@ -564,7 +567,7 @@ }, { "cell_type": "markdown", - "id": "fa2cbb02", + "id": "6e043e23", "metadata": {}, "source": [ "## b. Training" @@ -573,7 +576,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ffdce5d5", + "id": "4ddb023b", "metadata": {}, "outputs": [], "source": [ @@ -601,7 +604,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9e6376cf", + "id": "cdff9449", "metadata": {}, "outputs": [], "source": [ @@ -613,7 +616,7 @@ }, { "cell_type": "markdown", - "id": "e5076e51", + "id": "21618f41", "metadata": {}, "source": [ "# 4. Inference" @@ -622,7 +625,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52358549", + "id": "b5b82404", "metadata": {}, "outputs": [], "source": [ @@ -633,7 +636,7 @@ }, { "cell_type": "markdown", - "id": "9e96ee13", + "id": "8de3ed76", "metadata": {}, "source": [ "## a. Load Model" @@ -642,7 +645,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2cb5d524", + "id": "0908431a", "metadata": {}, "outputs": [], "source": [ @@ -652,7 +655,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32dbd30c", + "id": "c4290675", "metadata": {}, "outputs": [], "source": [ @@ -668,7 +671,7 @@ { "cell_type": "code", "execution_count": null, - "id": "74a7ad03", + "id": "62ec6894", "metadata": {}, "outputs": [], "source": [ @@ -678,7 +681,7 @@ }, { "cell_type": "markdown", - "id": "4f882975", + "id": "beb87927", "metadata": {}, "source": [ "## b. Output Audio" @@ -687,7 +690,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2178a8ef", + "id": "ec561fe1", "metadata": {}, "outputs": [], "source": [ @@ -720,7 +723,7 @@ { "cell_type": "code", "execution_count": null, - "id": "766154e3", + "id": "5426c70a", "metadata": {}, "outputs": [], "source": [ @@ -743,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dfa71ca6", + "id": "534fff32", "metadata": {}, "outputs": [], "source": [ @@ -775,7 +778,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51d9d176", + "id": "7a9f185e", "metadata": {}, "outputs": [], "source": [ @@ -786,7 +789,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6180a7d2", + "id": "903c37ce", "metadata": {}, "outputs": [], "source": [ @@ -797,7 +800,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b33263b", + "id": "b8c1daac", "metadata": {}, "outputs": [], "source": []