Skip to content

Commit

Permalink
[TTS] remove redundant lines and declare global variables and capture exception of non-supported windows. (#4320)
Browse files Browse the repository at this point in the history

Signed-off-by: Xuesong Yang <xuesongyxs@gmail.com>
  • Loading branch information
XuesongYang authored Jun 3, 2022
1 parent a26a891 commit 968ee12
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
29 changes: 19 additions & 10 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@
from nemo.utils import logging


# Small constant added to the summed squared components before the square
# root in get_spec (presumably for numerical stability -- confirm).
EPSILON = 1e-9
# Maps each accepted window name to the torch function that builds it.
# 'none' maps to None (no window function). TTSDataset.__init__ looks up
# self.window here and raises NotImplementedError for any other name.
WINDOW_FN_SUPPORTED = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'none': None,
}


class TTSDataset(Dataset):
def __init__(
self,
Expand Down Expand Up @@ -230,13 +240,13 @@ def __init__(
dtype=torch.float,
).unsqueeze(0)

window_fn = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'none': None,
}.get(self.window, None)
try:
window_fn = WINDOW_FN_SUPPORTED[self.window]
except KeyError:
raise NotImplementedError(
f"Current implementation doesn't support {self.window} window. "
f"Please choose one from {list(WINDOW_FN_SUPPORTED.keys())}."
)

self.stft = lambda x: torch.stft(
input=x,
Expand Down Expand Up @@ -380,7 +390,7 @@ def get_spec(self, audio):
spec = self.stft(audio)
if spec.dtype in [torch.cfloat, torch.cdouble]:
spec = torch.view_as_real(spec)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
spec = torch.sqrt(spec.pow(2).sum(-1) + EPSILON)
return spec

def get_log_mel(self, audio):
Expand Down Expand Up @@ -436,7 +446,6 @@ def __getitem__(self, index):
# Load alignment prior matrix if needed
align_prior_matrix = None
if AlignPriorMatrix in self.sup_data_types_set:
align_prior_matrix = None
if self.use_beta_binomial_interpolator:
mel_len = self.get_log_mel(audio).shape[2]
align_prior_matrix = torch.from_numpy(self.beta_binomial_interpolator(mel_len, text_length.item()))
Expand Down Expand Up @@ -472,7 +481,7 @@ def __getitem__(self, index):

if self.pitch_mean is not None and self.pitch_std is not None and self.pitch_norm:
pitch -= self.pitch_mean
pitch[pitch == -self.pitch_mean] = 0.0 # Zero out values that were perviously zero
pitch[pitch == -self.pitch_mean] = 0.0 # Zero out values that were previously zero
pitch /= self.pitch_std

pitch_length = torch.tensor(len(pitch)).long()
Expand Down
24 changes: 24 additions & 0 deletions tests/collections/tts/test_torch_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,27 @@ def test_raise_exception_on_not_supported_sup_data_types(self, test_data_dir):
g2p=EnglishG2p(),
),
)

# Checks that building a TTSDataset with a window name that is not a key of
# WINDOW_FN_SUPPORTED raises NotImplementedError.
@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
@pytest.mark.torch_tts
def test_raise_exception_on_not_supported_window(self, test_data_dir):
# Manifest and supplementary-data paths under the shared test data dir.
manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json')
sup_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/sup')
# The constructor itself is expected to raise; `dataset` is never used.
with pytest.raises(NotImplementedError):
dataset = TTSDataset(
manifest_filepath=manifest_path,
sample_rate=22050,
sup_data_types=["pitch"],
sup_data_path=sup_path,
# Deliberately invalid window name to trigger the error path.
window="not_supported_window",
text_tokenizer=EnglishPhonemesTokenizer(
punct=True,
stresses=True,
chars=True,
space=' ',
apostrophe=True,
pad_with_space=True,
g2p=EnglishG2p(),
),
)

0 comments on commit 968ee12

Please sign in to comment.