Skip to content

Commit

Permalink
Add audio length sampler balancer
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed May 7, 2022
1 parent 5048aeb commit 074ee35
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
10 changes: 10 additions & 0 deletions TTS/tts/configs/shared_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,14 @@ class BaseTTSConfig(BaseTrainingConfig):
language_weighted_sampler_alpha (float):
Number that controls the influence of the language sampler weights. Defaults to ```1.0```.
use_length_weighted_sampler (bool):
Enable / Disable the batch balancer by audio length. If enabled, the dataset is divided
into 10 buckets spanning the min and max audio length of the dataset. The sampler weights are
computed to give every bucket the same quantity of data in each training batch. Defaults to ```False```.
length_weighted_sampler_alpha (float):
Number that controls the influence of the length sampler weights. Defaults to ```1.0```.
"""

audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
Expand Down Expand Up @@ -279,3 +287,5 @@ class BaseTTSConfig(BaseTrainingConfig):
speaker_weighted_sampler_alpha: float = 1.0
use_language_weighted_sampler: bool = False
language_weighted_sampler_alpha: float = 1.0
use_length_weighted_sampler: bool = False
length_weighted_sampler_alpha: float = 1.0
9 changes: 9 additions & 0 deletions TTS/tts/models/base_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
Expand Down Expand Up @@ -250,6 +251,14 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
else:
weights = get_speaker_balancer_weights(data_items) * alpha

if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
weights = get_length_balancer_weights(data_items) * alpha

if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:
Expand Down
26 changes: 26 additions & 0 deletions TTS/tts/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import bisect

import numpy as np
import torch


def _pad_data(x, length):
Expand Down Expand Up @@ -51,3 +54,26 @@ def prepare_stop_target(inputs, out_steps):

def pad_per_step(inputs, pad_len):
    """Zero-pad the last axis of a 3-D array with ``pad_len`` trailing entries.

    Args:
        inputs: Array-like with three axes; the first two axes are left untouched.
        pad_len (int): Number of zeros appended at the end of the last axis.

    Returns:
        np.ndarray: Padded copy of ``inputs``.
    """
    # (before, after) padding per axis: only the end of the last axis grows.
    pad_widths = [[0, 0], [0, 0], [0, pad_len]]
    return np.pad(inputs, pad_widths, mode="constant", constant_values=0.0)


def get_length_balancer_weights(items: list, num_buckets=10):
    """Compute per-sample weights that balance a dataset across audio-length buckets.

    The span between the shortest and the longest ``audio_length`` is split into
    roughly ``num_buckets`` equal-width buckets (integer truncation of the bucket
    width can yield a few more). Each sample receives a weight inversely
    proportional to the number of samples sharing its bucket, and the resulting
    weight vector is scaled to unit L2 norm.

    Args:
        items (list): Samples, each a mapping with an ``"audio_length"`` entry.
        num_buckets (int): Target number of length buckets. Defaults to 10.

    Returns:
        torch.Tensor: 1-D float32 tensor holding one weight per item, in input order.

    Raises:
        ValueError: If ``items`` is empty or ``num_buckets`` is not positive.
    """
    if num_buckets <= 0:
        raise ValueError(f"num_buckets must be positive, got {num_buckets}")

    # get all durations
    audio_lengths = np.array([item["audio_length"] for item in items])
    if audio_lengths.size == 0:
        raise ValueError("Cannot compute length balancer weights for an empty list of items.")

    # create the bucket classes based on the dataset max and min length
    max_length = int(audio_lengths.max())
    min_length = int(audio_lengths.min())
    # A length span smaller than num_buckets truncates to 0, which range() rejects;
    # fall back to unit-width buckets in that case.
    step = max(1, int((max_length - min_length) / num_buckets))
    # Upper edge of every bucket. When all samples share one length the range is
    # empty, so use a single bucket that holds everything.
    bucket_edges = [start + step for start in range(min_length, max_length, step)] or [max_length]

    # add each sample to its length bucket (side="left" matches bisect.bisect_left)
    bucket_ids = np.searchsorted(bucket_edges, audio_lengths, side="left")
    # Fractional lengths may exceed the last integer edge; keep them in the last bucket.
    bucket_ids = np.minimum(bucket_ids, len(bucket_edges) - 1)

    # count the samples per occupied bucket and weight each sample by 1 / count
    _, sample_to_bucket, bucket_count = np.unique(bucket_ids, return_inverse=True, return_counts=True)
    dataset_samples_weight = 1.0 / bucket_count[sample_to_bucket]

    # normalize
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    return torch.from_numpy(dataset_samples_weight).float()

0 comments on commit 074ee35

Please sign in to comment.