diff --git a/nemo/collections/tts/data/__init__.py b/nemo/collections/tts/data/__init__.py new file mode 100644 index 000000000000..a1cf281f0908 --- /dev/null +++ b/nemo/collections/tts/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/tts/data/audio_trimming.py b/nemo/collections/tts/data/audio_trimming.py new file mode 100644 index 000000000000..2cd831cc0724 --- /dev/null +++ b/nemo/collections/tts/data/audio_trimming.py @@ -0,0 +1,310 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Tuple + +import librosa +import numpy as np +import torch + +from nemo.collections.asr.models import EncDecClassificationModel +from nemo.collections.tts.data.data_utils import normalize_volume +from nemo.utils import logging + + +class AudioTrimmer(ABC): + """Interface for silence trimming implementations + """ + + @abstractmethod + def trim_audio(self, audio: np.array, sample_rate: int, audio_id: str) -> Tuple[np.array, int, int]: + """Trim starting and trailing silence from the input audio. + Args: + audio: Numpy array containing audio samples. Float [-1.0, 1.0] format. + sample_rate: Sample rate of input audio. + audio_id: String identifier (eg. file name) used for logging. + + Returns numpy array with trimmed audio, and integer sample indices representing the start and end + of speech within the original audio array. + """ + raise NotImplementedError + + +class EnergyAudioTrimmer(AudioTrimmer): + def __init__( + self, + db_threshold: int = 50, + ref_amplitude: float = 1.0, + speech_frame_threshold: int = 1, + trim_win_length: int = 2048, + trim_hop_length: int = 512, + pad_seconds: float = 0.1, + volume_norm: bool = True, + ) -> None: + """Energy/power based silence trimming using Librosa backend. + Args: + db_threshold: Audio frames at least db_threshold decibels below ref_amplitude will be + considered silence. + ref_amplitude: Amplitude threshold for classifying speech versus silence. + speech_frame_threshold: Start and end of speech will be detected where there are at least + speech_frame_threshold consecutive audio frames classified as speech. Setting this value higher + is more robust to false-positives (silence detected as speech), but setting it too high may result + in very short speech segments being cut out from the audio. 
+            trim_win_length: Length of audio frames to use when doing speech detection. This does not need to match
+                the win_length used in any other part of the code or model.
+            trim_hop_length: Stride of audio frames to use when doing speech detection. This does not need to match
+                the hop_length used in any other part of the code or model.
+            pad_seconds: Audio duration in seconds to keep before and after each speech segment.
+                Set this to at least 0.1 to avoid cutting off any speech audio, with larger values
+                being safer but increasing the average silence duration left afterwards.
+            volume_norm: Whether to normalize the volume of audio before doing speech detection.
+        """
+        assert db_threshold >= 0
+        assert ref_amplitude >= 0
+        assert speech_frame_threshold > 0
+        assert trim_win_length > 0
+        assert trim_hop_length > 0
+
+        self.db_threshold = db_threshold
+        self.ref_amplitude = ref_amplitude
+        self.speech_frame_threshold = speech_frame_threshold
+        self.trim_win_length = trim_win_length
+        self.trim_hop_length = trim_hop_length
+        self.pad_seconds = pad_seconds
+        self.volume_norm = volume_norm
+
+    def trim_audio(self, audio: np.array, sample_rate: int, audio_id: str = "") -> Tuple[np.array, int, int]:
+        if self.volume_norm:
+            # Normalize volume so we have a fixed scale relative to the reference amplitude
+            audio = normalize_volume(audio=audio, volume_level=1.0)
+
+        speech_frames = librosa.effects._signal_to_frame_nonsilent(
+            audio,
+            ref=self.ref_amplitude,
+            frame_length=self.trim_win_length,
+            hop_length=self.trim_hop_length,
+            top_db=self.db_threshold,
+        )
+
+        start_frame, end_frame = get_start_and_end_of_speech_frames(
+            is_speech=speech_frames, speech_frame_threshold=self.speech_frame_threshold, audio_id=audio_id,
+        )
+
+        start_sample = librosa.core.frames_to_samples(start_frame, hop_length=self.trim_hop_length)
+        end_sample = librosa.core.frames_to_samples(end_frame, hop_length=self.trim_hop_length)
+
+        start_sample, end_sample = pad_sample_indices(
+            start_sample=start_sample,
+            end_sample=end_sample,
+            max_sample=audio.shape[0],
+            sample_rate=sample_rate,
+            pad_seconds=self.pad_seconds,
+        )
+
+        trimmed_audio = audio[start_sample:end_sample]
+
+        return trimmed_audio, start_sample, end_sample
+
+
+class VadAudioTrimmer(AudioTrimmer):
+    def __init__(
+        self,
+        model_name: str = "vad_multilingual_marblenet",
+        vad_sample_rate: int = 16000,
+        vad_threshold: float = 0.5,
+        device: str = "cpu",
+        speech_frame_threshold: int = 1,
+        trim_win_length: int = 4096,
+        trim_hop_length: int = 1024,
+        pad_seconds: float = 0.1,
+        volume_norm: bool = True,
+    ) -> None:
+        """Voice activity detection (VAD) based silence trimming.
+
+        Args:
+            model_name: NeMo VAD model to load. Valid configurations can be found with
+                EncDecClassificationModel.list_available_models()
+            vad_sample_rate: Sample rate used for pretrained VAD model.
+            vad_threshold: Softmax probability [0, 1] of VAD output, above which audio frames will be classified
+                as speech.
+            device: Device "cpu" or "cuda" to use for running the VAD model.
+            speech_frame_threshold: Start and end of speech will be detected where there are at least
+                speech_frame_threshold consecutive audio frames classified as speech.
+            trim_win_length: Length of audio frames to use when doing speech detection. This does not need to match
+                the win_length used in any other part of the code or model.
+            trim_hop_length: Stride of audio frames to use when doing speech detection. This does not need to match
+                the hop_length used in any other part of the code or model.
+            pad_seconds: Audio duration in seconds to keep before and after each speech segment.
+                Set this to at least 0.1 to avoid cutting off any speech audio, with larger values
+                being safer but increasing the average silence duration left afterwards.
+            volume_norm: Whether to normalize the volume of audio before doing speech detection.
+        """
+        assert vad_sample_rate > 0
+        assert vad_threshold >= 0
+        assert speech_frame_threshold > 0
+        assert trim_win_length > 0
+        assert trim_hop_length > 0
+
+        self.device = device
+        self.vad_model = EncDecClassificationModel.from_pretrained(model_name=model_name).eval().to(self.device)
+        self.vad_sample_rate = vad_sample_rate
+        self.vad_threshold = vad_threshold
+
+        self.speech_frame_threshold = speech_frame_threshold
+        self.trim_win_length = trim_win_length
+        self.trim_hop_length = trim_hop_length
+        # Window shift needed in order to center frames
+        self.trim_shift = self.trim_win_length // 2
+
+        self.pad_seconds = pad_seconds
+        self.volume_norm = volume_norm
+
+    def _detect_speech(self, audio: np.array) -> np.array:
+        # [num_frames, win_length]
+        audio_frames = librosa.util.frame(
+            audio, frame_length=self.trim_win_length, hop_length=self.trim_hop_length
+        ).transpose()
+        audio_frame_lengths = audio_frames.shape[0] * [self.trim_win_length]
+
+        # [num_frames, win_length]
+        audio_signal = torch.tensor(audio_frames, dtype=torch.float32, device=self.device)
+        # [num_frames]
+        audio_signal_len = torch.tensor(audio_frame_lengths, dtype=torch.int32, device=self.device)
+        # VAD outputs 2 values for each audio frame with logits indicating the likelihood that
+        # each frame is non-speech or speech, respectively.
+        # [num_frames, 2]
+        log_probs = self.vad_model(input_signal=audio_signal, input_signal_length=audio_signal_len)
+        probs = torch.softmax(log_probs, dim=-1)
+        probs = probs.detach().cpu().numpy()
+        # [num_frames]
+        speech_probs = probs[:, 1]
+        speech_frames = speech_probs >= self.vad_threshold
+
+        return speech_frames
+
+    def _scale_sample_indices(self, start_sample: int, end_sample: int, sample_rate: int) -> Tuple[int, int]:
+        sample_rate_ratio = sample_rate / self.vad_sample_rate
+        start_sample = int(sample_rate_ratio * start_sample)
+        end_sample = int(sample_rate_ratio * end_sample)
+        return start_sample, end_sample
+
+    def trim_audio(self, audio: np.array, sample_rate: int, audio_id: str = "") -> Tuple[np.array, int, int]:
+        if sample_rate == self.vad_sample_rate:
+            vad_audio = audio
+        else:
+            # Resample audio to match sample rate of VAD model
+            vad_audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=self.vad_sample_rate)
+
+        if self.volume_norm:
+            # Normalize volume so we have a fixed scale relative to the reference amplitude
+            vad_audio = normalize_volume(audio=vad_audio, volume_level=1.0)
+
+        speech_frames = self._detect_speech(audio=vad_audio)
+
+        start_frame, end_frame = get_start_and_end_of_speech_frames(
+            is_speech=speech_frames, speech_frame_threshold=self.speech_frame_threshold, audio_id=audio_id,
+        )
+
+        if start_frame == 0:
+            start_sample = 0
+        else:
+            start_sample = librosa.core.frames_to_samples(start_frame, hop_length=self.trim_hop_length)
+            start_sample += self.trim_shift
+
+        # Avoid trimming off the end because VAD model is not trained to classify partial end frames.
+ if end_frame == speech_frames.shape[0]: + end_sample = vad_audio.shape[0] + else: + end_sample = librosa.core.frames_to_samples(end_frame, hop_length=self.trim_hop_length) + end_sample += self.trim_shift + + if sample_rate != self.vad_sample_rate: + # Convert sample indices back to input sample rate + start_sample, end_sample = self._scale_sample_indices( + start_sample=start_sample, end_sample=end_sample, sample_rate=sample_rate + ) + + start_sample, end_sample = pad_sample_indices( + start_sample=start_sample, + end_sample=end_sample, + max_sample=audio.shape[0], + sample_rate=sample_rate, + pad_seconds=self.pad_seconds, + ) + + trimmed_audio = audio[start_sample:end_sample] + + return trimmed_audio, start_sample, end_sample + + +def get_start_and_end_of_speech_frames( + is_speech: np.array, speech_frame_threshold: int, audio_id: str = "" +) -> Tuple[int, int]: + """Finds the speech frames corresponding to the start and end of speech for an utterance. + Args: + is_speech: [num_frames] boolean array with true entries labeling speech frames. + speech_frame_threshold: The number of consecutive speech frames required to classify the speech boundaries. + audio_id: String identifier (eg. file name) used for logging. + + Returns integers representing the frame indices of the start (inclusive) and end (exclusive) of speech. + """ + num_frames = is_speech.shape[0] + + # Iterate forwards over the utterance until we find the first speech_frame_threshold consecutive speech frames. + start_frame = None + for i in range(0, num_frames - speech_frame_threshold + 1): + high_i = i + speech_frame_threshold + if all(is_speech[i:high_i]): + start_frame = i + break + + # Iterate backwards over the utterance until we find the last speech_frame_threshold consecutive speech frames. + end_frame = None + for i in range(num_frames, speech_frame_threshold - 1, -1): + low_i = i - speech_frame_threshold + if all(is_speech[low_i:i]): + end_frame = i + break + + if start_frame is None: + logging.warning(f"Could not find start of speech for '{audio_id}'") + start_frame = 0 + + if end_frame is None: + logging.warning(f"Could not find end of speech for '{audio_id}'") + end_frame = num_frames + + return start_frame, end_frame + + +def pad_sample_indices( + start_sample: int, end_sample: int, max_sample: int, sample_rate: int, pad_seconds: float +) -> Tuple[int, int]: + """Shift the input sample indices by pad_seconds in front and back within [0, max_sample] + Args: + start_sample: Start sample index + end_sample: End sample index + max_sample: Maximum sample index + sample_rate: Sample rate of audio + pad_seconds: Amount to pad/shift the indices by. + + Returns the sample indices after padding by the input amount. + """ + pad_samples = int(pad_seconds * sample_rate) + start_sample = start_sample - pad_samples + end_sample = end_sample + pad_samples + + start_sample = max(0, start_sample) + end_sample = min(max_sample, end_sample) + + return start_sample, end_sample diff --git a/nemo/collections/tts/data/data_utils.py b/nemo/collections/tts/data/data_utils.py new file mode 100644 index 000000000000..d002e089e312 --- /dev/null +++ b/nemo/collections/tts/data/data_utils.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path +from typing import List + +import numpy as np + + +def read_manifest(manifest_path: Path) -> List[dict]: + """Read manifest file at the given path and convert it to a list of dictionary entries. + """ + with open(manifest_path, "r", encoding="utf-8") as manifest_f: + entries = [json.loads(line) for line in manifest_f] + return entries + + +def write_manifest(manifest_path: Path, entries: List[dict]) -> None: + """Convert input entries to JSON format and write them as a manifest at the given path. + """ + output_lines = [f"{json.dumps(entry, ensure_ascii=False)}\n" for entry in entries] + with open(manifest_path, "w", encoding="utf-8") as output_f: + output_f.writelines(output_lines) + + +def normalize_volume(audio: np.array, volume_level: float) -> np.array: + """Apply peak normalization to the input audio. + """ + if not (0.0 <= volume_level <= 1.0): + raise ValueError(f"Volume must be in range [0.0, 1.0], received {volume_level}") + + max_sample = np.max(np.abs(audio)) + if max_sample == 0: + return audio + + return volume_level * (audio / np.max(np.abs(audio))) diff --git a/scripts/dataset_processing/tts/audio_processing/config/preprocessing.yaml b/scripts/dataset_processing/tts/audio_processing/config/preprocessing.yaml new file mode 100644 index 000000000000..0392023d67b4 --- /dev/null +++ b/scripts/dataset_processing/tts/audio_processing/config/preprocessing.yaml @@ -0,0 +1,19 @@ +name: "preprocessing" + +data_base_dir: ??? 
+
+defaults:
+  - trim: energy
+
+config:
+  _target_: scripts.dataset_processing.tts.audio_processing.preprocess_audio.AudioPreprocessingConfig
+  input_manifest: ${data_base_dir}/manifest.json
+  output_manifest: ${data_base_dir}/manifest_processed.json
+  output_dir: ${data_base_dir}/audio_processed
+  num_workers: -1
+  max_entries: 0
+  output_sample_rate: 0
+  volume_level: 0.95
+  min_duration: 0.5
+  max_duration: 10.0
+  filter_file: ${data_base_dir}/filtered_utts.json
\ No newline at end of file
diff --git a/scripts/dataset_processing/tts/audio_processing/config/trim/energy.yaml b/scripts/dataset_processing/tts/audio_processing/config/trim/energy.yaml
new file mode 100644
index 000000000000..9ae633dd2037
--- /dev/null
+++ b/scripts/dataset_processing/tts/audio_processing/config/trim/energy.yaml
@@ -0,0 +1,7 @@
+_target_: nemo.collections.tts.data.audio_trimming.EnergyAudioTrimmer
+
+db_threshold: 50.0
+speech_frame_threshold: 3
+trim_win_length: 4096
+trim_hop_length: 1024
+pad_seconds: 0.2
\ No newline at end of file
diff --git a/scripts/dataset_processing/tts/audio_processing/config/trim/vad.yaml b/scripts/dataset_processing/tts/audio_processing/config/trim/vad.yaml
new file mode 100644
index 000000000000..3f91fd26044c
--- /dev/null
+++ b/scripts/dataset_processing/tts/audio_processing/config/trim/vad.yaml
@@ -0,0 +1,10 @@
+_target_: nemo.collections.tts.data.audio_trimming.VadAudioTrimmer
+
+model_name: "vad_multilingual_marblenet"
+vad_sample_rate: 16000
+vad_threshold: 0.5
+device: "cpu"
+speech_frame_threshold: 3
+trim_win_length: 4096
+trim_hop_length: 1024
+pad_seconds: 0.2
\ No newline at end of file
diff --git a/scripts/dataset_processing/tts/audio_processing/preprocess_audio.py b/scripts/dataset_processing/tts/audio_processing/preprocess_audio.py
new file mode 100644
index 000000000000..128d311e04c0
--- /dev/null
+++ b/scripts/dataset_processing/tts/audio_processing/preprocess_audio.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is used to preprocess audio before TTS model training.
+
+It can be configured to do several processing steps such as silence trimming, volume normalization,
+and duration filtering.
+
+These can be done separately through multiple executions of the script, or all at once to avoid saving
+too many copies of the same audio.
+
+Most of these can also be done by the TTS data loader at training time, but doing them ahead of time
+lets us implement more complex processing, validate the correctness of the output, and save on compute time.
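A minimal sketch of driving the new trimmer API directly, outside of the Hydra script (not part of the patch; the wav path is a placeholder, and the parameter values simply mirror config/trim/energy.yaml):

import librosa

from nemo.collections.tts.data.audio_trimming import EnergyAudioTrimmer

# Load audio at its native sample rate; the trimmers expect float samples in [-1.0, 1.0].
audio, sample_rate = librosa.load("/home/data/audio/sample.wav", sr=None)

trimmer = EnergyAudioTrimmer(
    db_threshold=50, speech_frame_threshold=3, trim_win_length=4096, trim_hop_length=1024, pad_seconds=0.2
)
# Returns the trimmed audio plus the start/end sample indices of speech in the original array.
trimmed_audio, start_sample, end_sample = trimmer.trim_audio(
    audio=audio, sample_rate=sample_rate, audio_id="sample.wav"
)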
+
+$ HYDRA_FULL_ERROR=1 python /scripts/dataset_processing/tts/audio_processing/preprocess_audio.py \
+    --config-path=/scripts/dataset_processing/tts/audio_processing/config \
+    --config-name=preprocessing.yaml \
+    data_base_dir="/home/data" \
+    config.num_workers=1
+"""
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Tuple
+
+import librosa
+import soundfile as sf
+from hydra.utils import instantiate
+from joblib import Parallel, delayed
+from tqdm import tqdm
+
+from nemo.collections.tts.data.audio_trimming import AudioTrimmer
+from nemo.collections.tts.data.data_utils import normalize_volume, read_manifest, write_manifest
+from nemo.collections.tts.torch.helpers import get_base_dir
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+
+
+@dataclass
+class AudioPreprocessingConfig:
+    # Input training manifest.
+    input_manifest: Path
+    # New training manifest after processing audio.
+    output_manifest: Path
+    # Directory to save processed audio to.
+    output_dir: Path
+    # Number of threads to use. -1 will use all available CPUs.
+    num_workers: int = -1
+    # If provided, maximum number of entries in the manifest to process.
+    max_entries: int = 0
+    # If provided, rate to resample the audio to.
+    output_sample_rate: int = 0
+    # If provided, peak volume to normalize audio to.
+    volume_level: float = 0.0
+    # If provided, filter out utterances shorter than min_duration.
+    min_duration: float = 0.0
+    # If provided, filter out utterances longer than max_duration.
+    max_duration: float = float("inf")
+    # If provided, output filter_file will contain list of utterances filtered out.
+    filter_file: Path = None
+
+
+def _process_entry(
+    entry: dict,
+    base_dir: Path,
+    output_dir: Path,
+    audio_trimmer: AudioTrimmer,
+    output_sample_rate: int,
+    volume_level: float,
+) -> Tuple[dict, float, float]:
+    audio_filepath = Path(entry["audio_filepath"])
+    rel_audio_path = audio_filepath.relative_to(base_dir)
+    input_path = os.path.join(base_dir, rel_audio_path)
+    output_path = os.path.join(output_dir, rel_audio_path)
+
+    audio, sample_rate = librosa.load(input_path, sr=None)
+
+    if audio_trimmer is not None:
+        audio_id = str(audio_filepath)
+        audio, start_i, end_i = audio_trimmer.trim_audio(audio=audio, sample_rate=sample_rate, audio_id=audio_id)
+
+    if output_sample_rate:
+        audio = librosa.resample(y=audio, orig_sr=sample_rate, target_sr=output_sample_rate)
+        sample_rate = output_sample_rate
+
+    if volume_level:
+        audio = normalize_volume(audio, volume_level=volume_level)
+
+    sf.write(file=output_path, data=audio, samplerate=sample_rate)
+
+    original_duration = librosa.get_duration(filename=str(audio_filepath))
+    output_duration = librosa.get_duration(filename=str(output_path))
+
+    entry["audio_filepath"] = output_path
+    entry["duration"] = output_duration
+
+    return entry, original_duration, output_duration
+
+
+@hydra_runner(config_path='config', config_name='preprocessing')
+def main(cfg):
+    config = instantiate(cfg.config)
+    logging.info(f"Running audio preprocessing with config: {config}")
+
+    input_manifest_path = Path(config.input_manifest)
+    output_manifest_path = Path(config.output_manifest)
+    output_dir = Path(config.output_dir)
+    num_workers = config.num_workers
+    max_entries = config.max_entries
+    output_sample_rate = config.output_sample_rate
+    volume_level = config.volume_level
+    min_duration = config.min_duration
+    max_duration = config.max_duration
+    filter_file = Path(config.filter_file)
+
+    if cfg.trim:
+
audio_trimmer = instantiate(cfg.trim) + else: + audio_trimmer = None + + output_dir.mkdir(exist_ok=True, parents=True) + + entries = read_manifest(input_manifest_path) + if max_entries: + entries = entries[:max_entries] + + audio_paths = [entry["audio_filepath"] for entry in entries] + base_dir = get_base_dir(audio_paths) + + # 'threading' backend is required when parallelizing torch models. + job_outputs = Parallel(n_jobs=num_workers, backend='threading')( + delayed(_process_entry)( + entry=entry, + base_dir=base_dir, + output_dir=output_dir, + audio_trimmer=audio_trimmer, + output_sample_rate=output_sample_rate, + volume_level=volume_level, + ) + for entry in tqdm(entries) + ) + + output_entries = [] + filtered_entries = [] + original_durations = 0.0 + output_durations = 0.0 + for output_entry, original_duration, output_duration in job_outputs: + + if not min_duration <= output_duration <= max_duration: + if output_duration != original_duration: + output_entry["original_duration"] = original_duration + filtered_entries.append(output_entry) + continue + + original_durations += original_duration + output_durations += output_duration + output_entries.append(output_entry) + + write_manifest(manifest_path=output_manifest_path, entries=output_entries) + if filter_file: + write_manifest(manifest_path=filter_file, entries=filtered_entries) + + logging.info(f"Duration of original audio: {original_durations / 3600} hours") + logging.info(f"Duration of processed audio: {output_durations / 3600} hours") + + +if __name__ == "__main__": + main() diff --git a/tests/collections/tts/data/test_audio_trimming.py b/tests/collections/tts/data/test_audio_trimming.py new file mode 100644 index 000000000000..8ef1b79534c2 --- /dev/null +++ b/tests/collections/tts/data/test_audio_trimming.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from nemo.collections.tts.data.audio_trimming import get_start_and_end_of_speech_frames, pad_sample_indices + + +class TestAudioTrimming: + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_get_start_and_end_of_speech_frames_frames(self): + # First speech frame is index 2 (inclusive) and last one is index 8 (exclusive). 
+ is_speech = np.array([True, False, True, True, False, True, True, True, False, True, False]) + speech_frame_threshold = 2 + + start_frame, end_frame = get_start_and_end_of_speech_frames( + is_speech=is_speech, speech_frame_threshold=speech_frame_threshold + ) + + assert start_frame == 2 + assert end_frame == 8 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_get_start_and_end_of_speech_frames_not_frames_found(self): + is_speech = np.array([False, True, True, False]) + speech_frame_threshold = 3 + + start_frame, end_frame = get_start_and_end_of_speech_frames( + is_speech=is_speech, speech_frame_threshold=speech_frame_threshold, audio_id="test" + ) + + assert start_frame == 0 + assert end_frame == 4 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_pad_sample_indices(self): + start_sample, end_sample = pad_sample_indices( + start_sample=1000, end_sample=2000, max_sample=5000, sample_rate=100, pad_seconds=3 + ) + assert start_sample == 700 + assert end_sample == 2300 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_pad_sample_indices_boundaries(self): + start_sample, end_sample = pad_sample_indices( + start_sample=100, end_sample=1000, max_sample=1150, sample_rate=100, pad_seconds=2 + ) + assert start_sample == 0 + assert end_sample == 1150 diff --git a/tests/collections/tts/data/test_data_utils.py b/tests/collections/tts/data/test_data_utils.py new file mode 100644 index 000000000000..ff86fc0e5c0a --- /dev/null +++ b/tests/collections/tts/data/test_data_utils.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import numpy as np
+import pytest
+
+from nemo.collections.tts.data.data_utils import normalize_volume
+
+
+class TestDataUtils:
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_normalize_volume(self):
+        input_audio = np.array([0.0, 0.1, 0.3, 0.5])
+        expected_output = np.array([0.0, 0.18, 0.54, 0.9])
+
+        output_audio = normalize_volume(audio=input_audio, volume_level=0.9)
+
+        np.testing.assert_array_almost_equal(output_audio, expected_output)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_normalize_volume_negative_peak(self):
+        input_audio = np.array([0.0, 0.1, -0.3, -1.0, 0.5])
+        expected_output = np.array([0.0, 0.05, -0.15, -0.5, 0.25])
+
+        output_audio = normalize_volume(audio=input_audio, volume_level=0.5)
+
+        np.testing.assert_array_almost_equal(output_audio, expected_output)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_normalize_volume_zero(self):
+        input_audio = np.array([0.0, 0.1, 0.3, 0.5])
+        expected_output = np.array([0.0, 0.0, 0.0, 0.0])
+
+        output_audio = normalize_volume(audio=input_audio, volume_level=0.0)
+
+        np.testing.assert_array_almost_equal(output_audio, expected_output)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_normalize_volume_max(self):
+        input_audio = np.array([0.0, 0.1, 0.3, 0.5])
+        expected_output = np.array([0.0, 0.2, 0.6, 1.0])
+
+        output_audio = normalize_volume(audio=input_audio, volume_level=1.0)
+
+        np.testing.assert_array_almost_equal(output_audio, expected_output)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_normalize_volume_zeros(self):
+        input_audio = np.array([0.0, 0.0, 0.0])
+
+        output_audio = normalize_volume(audio=input_audio, volume_level=0.5)
+
+        np.testing.assert_array_almost_equal(output_audio, input_audio)
+
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    def test_normalize_volume_out_of_range(self):
+        input_audio = np.array([0.0, 0.1, 0.3, 0.5])
+        with pytest.raises(ValueError, match="Volume must be in range"):
+            normalize_volume(audio=input_audio, volume_level=2.0)
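The new pieces also compose into a small standalone pipeline outside of the Hydra script. A rough sketch, assuming mono wav files; the manifest paths, output suffix, and parameter values are placeholders, and constructing VadAudioTrimmer downloads the pretrained vad_multilingual_marblenet checkpoint on first use:

import soundfile as sf

from nemo.collections.tts.data.audio_trimming import VadAudioTrimmer
from nemo.collections.tts.data.data_utils import normalize_volume, read_manifest, write_manifest

trimmer = VadAudioTrimmer(device="cpu", pad_seconds=0.2)

entries = read_manifest("/home/data/manifest.json")
for entry in entries:
    # Read audio as float32 in [-1.0, 1.0], trim leading/trailing silence, then peak-normalize.
    audio, sample_rate = sf.read(entry["audio_filepath"], dtype="float32")
    audio, _, _ = trimmer.trim_audio(audio=audio, sample_rate=sample_rate, audio_id=entry["audio_filepath"])
    audio = normalize_volume(audio, volume_level=0.95)

    # Write the processed audio next to the original and update the manifest entry.
    output_path = entry["audio_filepath"].replace(".wav", "_processed.wav")
    sf.write(output_path, audio, samplerate=sample_rate)
    entry["audio_filepath"] = output_path
    entry["duration"] = audio.shape[0] / sample_rate

write_manifest("/home/data/manifest_processed.json", entries)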