Skip to content

Commit

Permalink
[TTS] bug fix - sample rate was being ignored in vocoder dataset (#4518)
Browse files Browse the repository at this point in the history
* Bug fix: the sample rate was being ignored in the vocoder dataset when not loading precomputed mel spectrograms
* Handled n_segments when the target sampling rate differs from the file's original sampling rate
* Added a case for n_segments of 0 and a warning when n_segments is greater than the file length

Signed-off-by: Paarth Neekhara <paarth.n@gmail.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Jocelyn <jocelynh@nvidia.com>
  • Loading branch information
3 people authored Oct 13, 2022
1 parent 441e97f commit 4463a9f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
26 changes: 22 additions & 4 deletions nemo/collections/asr/parts/preprocessing/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# SOFTWARE.
# This file contains code artifacts adapted from https://github.com/ryanleary/patter

import math
import os
import random

Expand Down Expand Up @@ -261,23 +262,40 @@ def segment_from_file(
Note that audio_file can be either the file path, or a file-like object.
"""
is_segmented = False
try:
with sf.SoundFile(audio_file, 'r') as f:
sample_rate = f.samplerate
if 0 < n_segments < len(f):
max_audio_start = len(f) - n_segments
if target_sr is not None:
n_segments_at_original_sr = math.ceil(n_segments * sample_rate / target_sr)
else:
n_segments_at_original_sr = n_segments

if 0 < n_segments_at_original_sr < len(f):
max_audio_start = len(f) - n_segments_at_original_sr
audio_start = random.randint(0, max_audio_start)
f.seek(audio_start)
samples = f.read(n_segments, dtype='float32')
samples = f.read(n_segments_at_original_sr, dtype='float32')
is_segmented = True
elif n_segments_at_original_sr >= len(f):
logging.warning(
f"Number of segments is greater than the length of the audio file {audio_file}. This may lead to shape mismatch errors."
)
samples = f.read(dtype='float32')
else:
samples = f.read(dtype='float32')
except RuntimeError as e:
logging.error(f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`.")

return cls(
features = cls(
samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr, channel_selector=channel_selector
)

if is_segmented:
features._samples = features._samples[:n_segments]

return features

# Accessor for the segment's audio samples: returns `self._samples.copy()`,
# i.e. a copy of the stored samples rather than the stored object itself.
# NOTE(review): `_samples` is presumably a NumPy array (segment_from_file
# slices it with `[:n_segments]`) — confirm against the class definition.
@property
def samples(self):
return self._samples.copy()
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@ def __getitem__(self, index):
if not self.load_precomputed_mel:
features = AudioSegment.segment_from_file(
sample["audio_filepath"],
target_sr=self.sample_rate,
n_segments=self.n_segments if self.n_segments is not None else -1,
trim=self.trim,
)
Expand Down

0 comments on commit 4463a9f

Please sign in to comment.