Skip to content

Commit

Permalink
[core] bugfix for capturing NotImplementedError of non-supported supp…
Browse files Browse the repository at this point in the history
…lementary data types. (#4297)

Signed-off-by: Xuesong Yang <xuesongyxs@gmail.com>

Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
  • Loading branch information
XuesongYang and titu1994 authored Jun 2, 2022
1 parent 0667415 commit 187e9a9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
17 changes: 10 additions & 7 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from nemo.collections.tts.torch.tts_data_types import (
DATA_STR2DATA_CLASS,
MAIN_DATA_TYPES,
VALID_SUPPLEMENTARY_DATA_TYPES,
AlignPriorMatrix,
Durations,
Energy,
Expand Down Expand Up @@ -253,15 +252,19 @@ def __init__(
Path(sup_data_path).mkdir(parents=True, exist_ok=True)
self.sup_data_path = sup_data_path

self.sup_data_types = (
[DATA_STR2DATA_CLASS[d_as_str] for d_as_str in sup_data_types] if sup_data_types is not None else []
)
self.sup_data_types = []
if sup_data_types is not None:
for d_as_str in sup_data_types:
try:
sup_data_type = DATA_STR2DATA_CLASS[d_as_str]
except KeyError:
raise NotImplementedError(f"Current implementation doesn't support {d_as_str} type.")

self.sup_data_types.append(sup_data_type)

self.sup_data_types_set = set(self.sup_data_types)

for data_type in self.sup_data_types:
if data_type not in VALID_SUPPLEMENTARY_DATA_TYPES:
raise NotImplementedError(f"Current implementation doesn't support {data_type} type.")

getattr(self, f"add_{data_type.name}")(**kwargs)

@staticmethod
Expand Down
23 changes: 23 additions & 0 deletions tests/collections/tts/test_torch_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,26 @@ def test_dataset(self, test_data_dir):

dataloader = torch.utils.data.DataLoader(dataset, 2, collate_fn=dataset._collate_fn)
data, _, _, _, _, _ = next(iter(dataloader))

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
@pytest.mark.torch_tts
def test_raise_exception_on_not_supported_sup_data_types(self, test_data_dir):
manifest_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/manifest.json')
sup_path = os.path.join(test_data_dir, 'tts/mini_ljspeech/sup')
with pytest.raises(NotImplementedError):
dataset = TTSDataset(
manifest_filepath=manifest_path,
sample_rate=22050,
sup_data_types=["not_supported_sup_data_type"],
sup_data_path=sup_path,
text_tokenizer=EnglishPhonemesTokenizer(
punct=True,
stresses=True,
chars=True,
space=' ',
apostrophe=True,
pad_with_space=True,
g2p=EnglishG2p(),
),
)

0 comments on commit 187e9a9

Please sign in to comment.