diff --git a/nemo/collections/tts/torch/data.py b/nemo/collections/tts/torch/data.py index 248fd355fbbb..a929c0b5a8e2 100644 --- a/nemo/collections/tts/torch/data.py +++ b/nemo/collections/tts/torch/data.py @@ -52,6 +52,16 @@ from nemo.utils import logging +EPSILON = 1e-9 +WINDOW_FN_SUPPORTED = { + 'hann': torch.hann_window, + 'hamming': torch.hamming_window, + 'blackman': torch.blackman_window, + 'bartlett': torch.bartlett_window, + 'none': None, +} + + class TTSDataset(Dataset): def __init__( self, @@ -230,13 +240,13 @@ def __init__( dtype=torch.float, ).unsqueeze(0) - window_fn = { - 'hann': torch.hann_window, - 'hamming': torch.hamming_window, - 'blackman': torch.blackman_window, - 'bartlett': torch.bartlett_window, - 'none': None, - }.get(self.window, None) + try: + window_fn = WINDOW_FN_SUPPORTED[self.window] + except KeyError: + raise NotImplementedError( + f"Current implementation doesn't support {self.window} window. " + f"Please choose one from {list(WINDOW_FN_SUPPORTED.keys())}." + ) self.stft = lambda x: torch.stft( input=x, @@ -380,7 +390,7 @@ def get_spec(self, audio): spec = self.stft(audio) if spec.dtype in [torch.cfloat, torch.cdouble]: spec = torch.view_as_real(spec) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9) + spec = torch.sqrt(spec.pow(2).sum(-1) + EPSILON) return spec def get_log_mel(self, audio): @@ -436,7 +446,6 @@ def __getitem__(self, index): # Load alignment prior matrix if needed align_prior_matrix = None if AlignPriorMatrix in self.sup_data_types_set: - align_prior_matrix = None if self.use_beta_binomial_interpolator: mel_len = self.get_log_mel(audio).shape[2] align_prior_matrix = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_length.item())) @@ -472,7 +481,7 @@ def __getitem__(self, index): if self.pitch_mean is not None and self.pitch_std is not None and self.pitch_norm: pitch -= self.pitch_mean - pitch[pitch == -self.pitch_mean] = 0.0 # Zero out values that were perviously zero + pitch[pitch == -self.pitch_mean] = 0.0 # Zero out values that were previously zero pitch /= self.pitch_std pitch_length = torch.tensor(len(pitch)).long() diff --git a/tests/collections/tts/test_torch_tts.py b/tests/collections/tts/test_torch_tts.py index e4e3270eadce..62876f1d4d59 100644 --- a/tests/collections/tts/test_torch_tts.py +++ b/tests/collections/tts/test_torch_tts.py @@ -71,3 +71,27 @@ def test_raise_exception_on_not_supported_sup_data_types(self, test_data_dir): g2p=EnglishG2p(), ), ) + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @pytest.mark.torch_tts + def test_raise_exception_on_not_supported_window(self, test_data_dir): + manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json') + sup_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/sup') + with pytest.raises(NotImplementedError): + dataset = TTSDataset( + manifest_filepath=manifest_path, + sample_rate=22050, + sup_data_types=["pitch"], + sup_data_path=sup_path, + window="not_supported_window", + text_tokenizer=EnglishPhonemesTokenizer( + punct=True, + stresses=True, + chars=True, + space=' ', + apostrophe=True, + pad_with_space=True, + g2p=EnglishG2p(), + ), + )