Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: treacker <emshabalin@yandex.ru>
  • Loading branch information
treacker committed May 11, 2022
1 parent 9df43dd commit cfa290f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 35 deletions.
49 changes: 19 additions & 30 deletions nemo/collections/tts/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@

from nemo.utils import logging

HAVE_WANDB = True
try:
import wandb
except ModuleNotFoundError:
HAVE_WANDB = False

try:
from pytorch_lightning.utilities import rank_zero_only
except ModuleNotFoundError:
Expand All @@ -69,12 +75,6 @@ def wrapped_fn(*args, **kwargs):
)
exit(1)

HAVE_WANDB = True
try:
import wandb
except ModuleNotFoundError:
HAVE_WANDB = False

class OperationMode(Enum):
"""Training or Inference (Evaluation) mode"""

Expand Down Expand Up @@ -328,24 +328,21 @@ def tacotron2_log_to_wandb_func(
alignments = []
specs = []
gates = []
alignments += [wandb.Image(
plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), caption=f"{tag}_alignment",
)]
alignments += [
wandb.Image(plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), caption=f"{tag}_alignment",)
]
alignments += [
wandb.Image(plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), caption=f"{tag}_mel_target",),
wandb.Image(plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()), caption=f"{tag}_mel_predicted",)
]
gates += [
wandb.Image(
plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()),
caption=f"{tag}_mel_target",
),
wandb.Image(
plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()),
caption=f"{tag}_mel_predicted",
plot_gate_outputs_to_numpy(
gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),
),
caption=f"{tag}_gate",
)
]

gates += [wandb.Image(
plot_gate_outputs_to_numpy(gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),),
caption=f"{tag}_gate",
)]

swriter.log({"specs": specs, "alignments": alignments, "gates" : gates})

Expand All @@ -363,16 +360,8 @@ def tacotron2_log_to_wandb_func(
audio_true = griffin_lim(magnitude.T ** griffin_lim_power)

audios += [
wandb.Audio(
audio_true / max(np.abs(audio_true)),
caption=f"{tag}_wav_target",
sample_rate=sr,
),
wandb.Audio(
audio_pred / max(np.abs(audio_pred)),
caption=f"{tag}_wav_predicted",
sample_rate=sr,
),
wandb.Audio(audio_true / max(np.abs(audio_true)),caption=f"{tag}_wav_target",sample_rate=sr,),
wandb.Audio(audio_pred / max(np.abs(audio_pred)),caption=f"{tag}_wav_predicted",sample_rate=sr,),
]

swriter.log({"audios": audios})
Expand Down
16 changes: 11 additions & 5 deletions nemo/collections/tts/models/tacotron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from torch import nn

from nemo.collections.common.parts.preprocessing import parsers
from nemo.collections.tts.helpers.helpers import get_mask_from_lengths, tacotron2_log_to_tb_func, tacotron2_log_to_wandb_func
from nemo.collections.tts.helpers.helpers import (
get_mask_from_lengths,
tacotron2_log_to_tb_func,
tacotron2_log_to_wandb_func
)
from nemo.collections.tts.losses.tacotron2loss import Tacotron2Loss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.core.classes.common import PretrainedModelInfo, typecheck
Expand Down Expand Up @@ -61,12 +65,11 @@ class Tacotron2Model(SpectrogramGenerator):
"""Tacotron 2 Model that is used to generate mel spectrograms from text"""

def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):

# Convert to Hydra 1.0 compatible DictConfig
cfg = model_utils.convert_model_config_to_dict_config(cfg)
cfg = model_utils.maybe_update_config_version(cfg)


# setup normalizer
# setup normalizer
self.normalizer = None
self.text_normalizer_call = None
self.text_normalizer_call_kwargs = {}
Expand Down Expand Up @@ -117,6 +120,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
def parser(self):
if self._parser is not None:
return self._parser

ds_class_name = self._cfg.train_ds.dataset._target_.split(".")[-1]
if ds_class_name == "TTSDataset":
self._parser = None
Expand All @@ -134,6 +138,7 @@ def parser(self):
self.parser = self.vocab.encode
else:
raise ValueError("Wanted to setup parser, but model does not have necessary parameters")

return self._parser

def parse(self, text: str, normalize=True) -> torch.Tensor:
Expand Down Expand Up @@ -348,7 +353,8 @@ def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, na
logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

dataset = instantiate(
cfg.dataset, text_normalizer=self.normalizer,
cfg.dataset,
text_normalizer=self.normalizer,
text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
text_tokenizer=self.tokenizer,
)
Expand Down

0 comments on commit cfa290f

Please sign in to comment.