diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index a31d6c4548..09582cea7c 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -1,4 +1,4 @@ -"""Search a good noise schedule for WaveGrad for a given number of inferece iterations""" +"""Search a good noise schedule for WaveGrad for a given number of inference iterations""" import argparse from itertools import product as cartesian_product @@ -7,94 +7,97 @@ from torch.utils.data import DataLoader from tqdm import tqdm +from TTS.config import load_config from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_config from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset -from TTS.vocoder.utils.generic_utils import setup_generator +from TTS.vocoder.models import setup_model -parser = argparse.ArgumentParser() -parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") -parser.add_argument("--config_path", type=str, help="Path to model config file.") -parser.add_argument("--data_path", type=str, help="Path to data directory.") -parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") -parser.add_argument( - "--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for." -) -parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.") -parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") -parser.add_argument( - "--search_depth", - type=int, - default=3, - help="Search granularity. Increasing this increases the run-time exponentially.", -) +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") + parser.add_argument("--config_path", type=str, help="Path to model config file.") + parser.add_argument("--data_path", type=str, help="Path to data directory.") + parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") + parser.add_argument( + "--num_iter", + type=int, + help="Number of model inference iterations that you like to optimize noise schedule for.", + ) + parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.") + parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") + parser.add_argument( + "--search_depth", + type=int, + default=3, + help="Search granularity. Increasing this increases the run-time exponentially.", + ) -# load config -args = parser.parse_args() -config = load_config(args.config_path) + # load config + args = parser.parse_args() + config = load_config(args.config_path) -# setup audio processor -ap = AudioProcessor(**config.audio) + # setup audio processor + ap = AudioProcessor(**config.audio) -# load dataset -_, train_data = load_wav_data(args.data_path, 0) -train_data = train_data[: args.num_samples] -dataset = WaveGradDataset( - ap=ap, - items=train_data, - seq_len=-1, - hop_len=ap.hop_length, - pad_short=config.pad_short, - conv_pad=config.conv_pad, - is_training=True, - return_segments=False, - use_noise_augment=False, - use_cache=False, - verbose=True, -) -loader = DataLoader( - dataset, - batch_size=1, - shuffle=False, - collate_fn=dataset.collate_full_clips, - drop_last=False, - num_workers=config.num_loader_workers, - pin_memory=False, -) + # load dataset + _, train_data = load_wav_data(args.data_path, 0) + train_data = train_data[: args.num_samples] + dataset = WaveGradDataset( + ap=ap, + items=train_data, + seq_len=-1, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=True, + return_segments=False, + use_noise_augment=False, + use_cache=False, + verbose=True, + ) + loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collate_full_clips, + drop_last=False, + num_workers=config.num_loader_workers, + pin_memory=False, + ) -# setup the model -model = setup_generator(config) -if args.use_cuda: - model.cuda() + # setup the model + model = setup_model(config) + if args.use_cuda: + model.cuda() -# setup optimization parameters -base_values = sorted(10 * np.random.uniform(size=args.search_depth)) -print(base_values) -exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) -best_error = float("inf") -best_schedule = None -total_search_iter = len(base_values) ** args.num_iter -for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): - beta = exponents * base - model.compute_noise_level(beta) - for data in loader: - mel, audio = data - y_hat = model.inference(mel.cuda() if args.use_cuda else mel) + # setup optimization parameters + base_values = sorted(10 * np.random.uniform(size=args.search_depth)) + print(f" > base values: {base_values}") + exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) + best_error = float("inf") + best_schedule = None # pylint: disable=C0103 + total_search_iter = len(base_values) ** args.num_iter + for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): + beta = exponents * base + model.compute_noise_level(beta) + for data in loader: + mel, audio = data + y_hat = model.inference(mel.cuda() if args.use_cuda else mel) - if args.use_cuda: - y_hat = y_hat.cpu() - y_hat = y_hat.numpy() + if args.use_cuda: + y_hat = y_hat.cpu() + y_hat = y_hat.numpy() - mel_hat = [] - for i in range(y_hat.shape[0]): - m = ap.melspectrogram(y_hat[i, 0])[:, :-1] - mel_hat.append(torch.from_numpy(m)) + mel_hat = [] + for i in range(y_hat.shape[0]): + m = ap.melspectrogram(y_hat[i, 0])[:, :-1] + mel_hat.append(torch.from_numpy(m)) - mel_hat = torch.stack(mel_hat) - mse = torch.sum((mel - mel_hat) ** 2).mean() - if mse.item() < best_error: - best_error = mse.item() - best_schedule = {"beta": beta} - print(f" > Found a better schedule. - MSE: {mse.item()}") - np.save(args.output_path, best_schedule) + mel_hat = torch.stack(mel_hat) + mse = torch.sum((mel - mel_hat) ** 2).mean() + if mse.item() < best_error: + best_error = mse.item() + best_schedule = {"beta": beta} + print(f" > Found a better schedule. - MSE: {mse.item()}") + np.save(args.output_path, best_schedule) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index 6b0778c5a7..067c32d97d 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -62,7 +62,7 @@ def _process_model_name(config_dict: Dict) -> str: return model_name -def load_config(config_path: str) -> None: +def load_config(config_path: str) -> Coqpit: """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name to find the corresponding Config class. Then initialize the Config. diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 05e0fae887..d941eab33e 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -149,4 +149,4 @@ def collate_full_clips(batch): mels[idx, :, : mel.shape[1]] = mel audios[idx, : audio.shape[0]] = audio - return audios, mels + return mels, audios