Add audio length sampler balancer #1561

Merged (2 commits) on May 12, 2022
10 changes: 10 additions & 0 deletions TTS/tts/configs/shared_configs.py
@@ -232,6 +232,14 @@ class BaseTTSConfig(BaseTrainingConfig):

language_weighted_sampler_alpha (float):
Number that controls the influence of the language sampler weights. Defaults to ```1.0```.

use_length_weighted_sampler (bool):
Enable / Disable the batch balancer by audio length. If enabled, the dataset is split into 10 buckets
based on the minimum and maximum audio lengths in the dataset, and the sampler weights are computed so
that each training batch contains roughly the same amount of data from every bucket. Defaults to ```False```.

length_weighted_sampler_alpha (float):
Number that controls the influence of the length sampler weights. Defaults to ```1.0```.
"""

audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -279,3 +287,5 @@ class BaseTTSConfig(BaseTrainingConfig):
speaker_weighted_sampler_alpha: float = 1.0
use_language_weighted_sampler: bool = False
language_weighted_sampler_alpha: float = 1.0
use_length_weighted_sampler: bool = False
length_weighted_sampler_alpha: float = 1.0
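
For reference, a minimal sketch (not part of this diff) of how the new options could be switched on; the two field names come from the lines above, everything else is illustrative:

```python
from TTS.tts.configs.shared_configs import BaseTTSConfig

# Hypothetical config enabling the audio-length batch balancer introduced in this PR.
config = BaseTTSConfig(
    use_length_weighted_sampler=True,   # turn on the length-bucket balancer
    length_weighted_sampler_alpha=1.0,  # scales how strongly the length weights count
)
```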
9 changes: 9 additions & 0 deletions TTS/tts/models/base_tts.py
@@ -12,6 +12,7 @@

from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
@@ -250,6 +251,14 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
else:
weights = get_speaker_balancer_weights(data_items) * alpha

if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
weights = get_length_balancer_weights(data_items) * alpha

if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:
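
Note that when several balancers are enabled, their weight vectors are summed, each scaled by its own alpha. A rough sketch (not from the diff) of how the sampler returned by `get_sampler` would typically feed a `DataLoader`; `model`, `dataset`, and `config` are placeholders for a real `BaseTTS` setup:

```python
from torch.utils.data import DataLoader

# Assumes `model` is a BaseTTS subclass, `dataset` a TTSDataset, and `config` a
# BaseTTSConfig with a `batch_size` field (standard in the training config).
sampler = model.get_sampler(config, dataset)  # WeightedRandomSampler, or None if no balancer is enabled
loader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    sampler=sampler,
    shuffle=sampler is None,  # a sampler and shuffle=True cannot be combined
)
```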
26 changes: 26 additions & 0 deletions TTS/tts/utils/data.py
@@ -1,4 +1,7 @@
import bisect

import numpy as np
import torch


def _pad_data(x, length):
@@ -51,3 +54,26 @@ def prepare_stop_target(inputs, out_steps):

def pad_per_step(inputs, pad_len):
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)


def get_length_balancer_weights(items: list, num_buckets=10):
# get all durations
audio_lengths = np.array([item["audio_length"] for item in items])
# create the $num_buckets bucket classes based on the dataset min and max lengths
max_length = int(max(audio_lengths))
min_length = int(min(audio_lengths))
step = int((max_length - min_length) / num_buckets) + 1
buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)]
# assign each sample to its length bucket
buckets_names = np.array(
[buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items]
)
# count the samples in each bucket and compute a weight for each sample
unique_buckets_names = np.unique(buckets_names).tolist()
bucket_ids = [unique_buckets_names.index(l) for l in buckets_names]
bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names])
weight_bucket = 1.0 / bucket_count
dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids])
# normalize
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
return torch.from_numpy(dataset_samples_weight).float()
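
As a quick illustration of the bucketing above, consider a hypothetical toy dataset with three short clips and one long clip split into two buckets; before normalization the per-sample weights come out to [1/3, 1/3, 1/3, 1], so the two length groups are drawn at roughly the same rate:

```python
import torch

from TTS.tts.utils.data import get_length_balancer_weights

# Toy input (lengths in audio samples); not taken from the PR.
items = [
    {"audio_length": 10_000},
    {"audio_length": 11_000},
    {"audio_length": 12_000},
    {"audio_length": 200_000},
]
weights = get_length_balancer_weights(items, num_buckets=2)
# The lone long clip gets ~3x the weight of each short clip, so a
# WeightedRandomSampler picks short and long audio about equally often.
sampler = torch.utils.data.WeightedRandomSampler(weights, len(items))
```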
27 changes: 27 additions & 0 deletions tests/data_tests/test_samplers.py
@@ -1,11 +1,13 @@
import functools
import random
import unittest

import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights

@@ -136,3 +138,28 @@ def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"

def test_length_weighted_random_sampler(self): # pylint: disable=no-self-use
for _ in range(1000):
# generate a length-unbalanced dataset with random max/min audio lengths
min_audio = random.randrange(1, 22050)
max_audio = random.randrange(44100, 220500)
for idx, item in enumerate(train_samples):
# increase the diversity of durations
random_increase = random.randrange(100, 1000)
if idx < 5:
item["audio_length"] = min_audio + random_increase
else:
item["audio_length"] = max_audio + random_increase

weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples)
)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
len1, len2 = 0, 0
for index in ids:
if train_samples[index]["audio_length"] < max_audio:
len1 += 1
else:
len2 += 1
assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"