synthesis.py (forked from r9y9/deepvoice3_pytorch)
# coding: utf-8
"""
Synthesize waveforms from a trained model.

usage: synthesis.py [options] <checkpoint> <text_list_file> <dst_dir>

options:
    --hparams=<params>                  Hyper parameters [default: ].
    --checkpoint-seq2seq=<path>         Load seq2seq model from checkpoint path.
    --checkpoint-postnet=<path>         Load postnet model from checkpoint path.
    --file-name-suffix=<s>              File name suffix [default: ].
    --max-decoder-steps=<N>             Max decoder steps [default: 500].
    --replace_pronunciation_prob=<N>    Probability of replacing a word with its pronunciation [default: 0.0].
    --speaker_id=<id>                   Speaker ID (for multi-speaker model).
    --output-html                       Output html for blog post.
    -h, --help                          Show help message.
"""
from docopt import docopt
import sys
import os
from os.path import dirname, join, basename, splitext
import audio
import torch
from torch.autograd import Variable
import numpy as np
import nltk
# The deepvoice3 model
from deepvoice3_pytorch import frontend
from hparams import hparams
from tqdm import tqdm
use_cuda = torch.cuda.is_available()
_frontend = None # to be set later
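
# tts() below runs the full synthesis pipeline: the frontend converts text to an
# integer sequence, the model decodes it greedily into mel/linear spectrograms
# and attention alignments, and audio.inv_spectrogram reconstructs a waveform
# from the linear spectrogram (typically Griffin-Lim in this codebase).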
def tts(model, text, p=0, speaker_id=None, fast=False):
    """Convert text to a speech waveform given a deepvoice3 model.

    Args:
        text (str): Input text to be synthesized.
        p (float): Probability of replacing a word with its pronunciation.
            Default is 0 (no replacement).
    """
    if use_cuda:
        model = model.cuda()
    model.eval()
    if fast:
        model.make_generation_fast_()

    # Encode the text and build the 1-based text positions expected by attention
    sequence = np.array(_frontend.text_to_sequence(text, p=p))
    sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0)
    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
    text_positions = Variable(text_positions)
    speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
    if use_cuda:
        sequence = sequence.cuda()
        text_positions = text_positions.cuda()
        speaker_ids = None if speaker_ids is None else speaker_ids.cuda()

    # Greedy decoding
    mel_outputs, linear_outputs, alignments, done = model(
        sequence, text_positions=text_positions, speaker_ids=speaker_ids)

    linear_output = linear_outputs[0].cpu().data.numpy()
    spectrogram = audio._denormalize(linear_output)
    alignment = alignments[0].cpu().data.numpy()
    mel = mel_outputs[0].cpu().data.numpy()
    mel = audio._denormalize(mel)

    # Predicted audio signal
    waveform = audio.inv_spectrogram(linear_output.T)

    return waveform, alignment, spectrogram, mel
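
# Minimal programmatic use of tts(), assuming a model has already been built and
# a checkpoint loaded as in the __main__ block below (the output file name is
# purely illustrative):
#
#   waveform, alignment, _, _ = tts(model, "Hello world.", p=0.0, fast=True)
#   audio.save_wav(waveform, "hello_world.wav")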
if __name__ == "__main__":
    args = docopt(__doc__)
    print("Command line args:\n", args)
    checkpoint_path = args["<checkpoint>"]
    text_list_file_path = args["<text_list_file>"]
    dst_dir = args["<dst_dir>"]
    checkpoint_seq2seq_path = args["--checkpoint-seq2seq"]
    checkpoint_postnet_path = args["--checkpoint-postnet"]
    max_decoder_steps = int(args["--max-decoder-steps"])
    file_name_suffix = args["--file-name-suffix"]
    replace_pronunciation_prob = float(args["--replace_pronunciation_prob"])
    output_html = args["--output-html"]
    speaker_id = args["--speaker_id"]
    if speaker_id is not None:
        speaker_id = int(speaker_id)

    # Override hyper parameters
    hparams.parse(args["--hparams"])
    assert hparams.name == "deepvoice3"

    # Presets
    if hparams.preset is not None and hparams.preset != "":
        preset = hparams.presets[hparams.preset]
        import json
        hparams.parse_json(json.dumps(preset))
        print("Override hyper parameters with preset \"{}\": {}".format(
            hparams.preset, json.dumps(preset, indent=4)))

    # Share the text frontend with the train module before building the model
    _frontend = getattr(frontend, hparams.frontend)
    import train
    train._frontend = _frontend
    from train import plot_alignment, build_model

    # Model
    model = build_model()

    # Load seq2seq and postnet checkpoints separately, or a single full checkpoint
    if checkpoint_postnet_path is not None and checkpoint_seq2seq_path is not None:
        checkpoint = torch.load(checkpoint_seq2seq_path)
        model.seq2seq.load_state_dict(checkpoint["state_dict"])
        checkpoint = torch.load(checkpoint_postnet_path)
        model.postnet.load_state_dict(checkpoint["state_dict"])
        checkpoint_name = splitext(basename(checkpoint_seq2seq_path))[0]
    else:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["state_dict"])
        checkpoint_name = splitext(basename(checkpoint_path))[0]

    model.seq2seq.decoder.max_decoder_steps = max_decoder_steps

    os.makedirs(dst_dir, exist_ok=True)
    with open(text_list_file_path, "rb") as f:
        lines = f.readlines()
        for idx, line in enumerate(lines):
            text = line.decode("utf-8").rstrip("\r\n")
            words = nltk.word_tokenize(text)
            waveform, alignment, _, _ = tts(
                model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True)
            dst_wav_path = join(dst_dir, "{}_{}{}.wav".format(
                idx, checkpoint_name, file_name_suffix))
            dst_alignment_path = join(
                dst_dir, "{}_{}{}_alignment.png".format(idx, checkpoint_name,
                                                        file_name_suffix))
            plot_alignment(alignment.T, dst_alignment_path,
                           info="{}, {}".format(hparams.builder, basename(checkpoint_path)))
            audio.save_wav(waveform, dst_wav_path)
            name = splitext(basename(text_list_file_path))[0]
            if output_html:
                print("""
{}
({} chars, {} words)
<audio controls="controls" >
<source src="/audio/{}/{}/{}" autoplay/>
Your browser does not support the audio element.
</audio>
<div align="center"><img src="/audio/{}/{}/{}" /></div>
""".format(text, len(text), len(words),
           hparams.builder, name, basename(dst_wav_path),
           hparams.builder, name, basename(dst_alignment_path)))
            else:
                print(idx, ": {}\n ({} chars, {} words)".format(text, len(text), len(words)))

    print("Finished! Check out {} for generated audio samples.".format(dst_dir))
    sys.exit(0)