Add device support in TTS and Synthesizer #2855
Changes from all commits: 711459c, af26ffd, 739be29, 7fbe8cb, 5912b56, 0bcc016, 26c7a14, 41de849, ef554f9
`TTS/tts/utils/synthesis.py`

```diff
@@ -5,19 +5,21 @@
 from torch import nn


-def numpy_to_torch(np_array, dtype, cuda=False):
+def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if np_array is None:
         return None
-    tensor = torch.as_tensor(np_array, dtype=dtype)
-    if cuda:
-        return tensor.cuda()
+    tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
     return tensor


-def compute_style_mel(style_wav, ap, cuda=False):
-    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
-    if cuda:
-        return style_mel.cuda()
+def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
+    style_mel = torch.FloatTensor(
+        ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device,
+    ).unsqueeze(0)
     return style_mel

```

**Review comment** on `numpy_to_torch`: Added new `device` parameter.
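A minimal usage sketch (not part of the PR; the shapes and device string are invented for illustration) of the new `device` argument: tensors land on the target device directly, while the boolean `cuda` flag is kept for backwards compatibility.

```python
import numpy as np
import torch

from TTS.tts.utils.synthesis import embedding_to_torch, numpy_to_torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# created directly on ``device``; no ``.cuda()`` round-trip afterwards
text_ids = numpy_to_torch(np.array([1, 2, 3]), torch.long, device=device)

# a d-vector of shape [256] becomes [1, 256] on ``device``
d_vector = embedding_to_torch(np.random.rand(256).astype(np.float32), device=device)
```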
```diff
@@ -73,22 +75,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav


-def id_to_torch(aux_id, cuda=False):
+def id_to_torch(aux_id, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if aux_id is not None:
         aux_id = np.asarray(aux_id)
-        aux_id = torch.from_numpy(aux_id)
-    if cuda:
-        return aux_id.cuda()
+        aux_id = torch.from_numpy(aux_id).to(device)
     return aux_id


-def embedding_to_torch(d_vector, cuda=False):
+def embedding_to_torch(d_vector, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if d_vector is not None:
         d_vector = np.asarray(d_vector)
         d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
-        d_vector = d_vector.squeeze().unsqueeze(0)
-    if cuda:
-        return d_vector.cuda()
+        d_vector = d_vector.squeeze().unsqueeze(0).to(device)
     return d_vector

```
```diff
@@ -162,17 +164,22 @@ def synthesis(
         language_id (int):
             Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # GST or Capacitron processing
     # TODO: need to handle the case of setting both gst and capacitron to true somewhere
     style_mel = None
     if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
-            style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+            style_mel = compute_style_mel(style_wav, model.ap, device=device)

     if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+        style_mel = compute_style_mel(style_wav, model.ap, device=device)
         style_mel = style_mel.transpose(1, 2)  # [1, time, depth]

     language_name = None
```
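Both `synthesis` and `transfer_voice` now derive the target device from the model itself with the `next(model.parameters()).device` idiom, keeping `use_cuda` as an override. A standalone sketch of the pattern, with a toy module standing in for a real TTS model:

```python
import torch
from torch import nn

model = nn.Linear(80, 80)  # toy stand-in for a TTS model

# the first registered parameter tells us where the model lives
device = next(model.parameters()).device
assert device == torch.device("cpu")

# inputs created with this device always land next to the model weights
x = torch.zeros(1, 80, device=device)
y = model(x)
```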
```diff
@@ -188,26 +195,26 @@ def synthesis(
     )
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)

     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)

     if language_id is not None:
-        language_id = id_to_torch(language_id, cuda=use_cuda)
+        language_id = id_to_torch(language_id, device=device)

     if not isinstance(style_mel, dict):
         # GST or Capacitron style mel
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        style_mel = numpy_to_torch(style_mel, torch.float, device=device)
         if style_text is not None:
             style_text = np.asarray(
                 model.tokenizer.text_to_ids(style_text, language=language_id),
                 dtype=np.int32,
             )
-            style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
+            style_text = numpy_to_torch(style_text, torch.long, device=device)
             style_text = style_text.unsqueeze(0)

-    text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
     text_inputs = text_inputs.unsqueeze(0)
     # synthesize voice
     outputs = run_model_torch(
```
```diff
@@ -290,22 +297,27 @@ def transfer_voice(
         do_trim_silence (bool):
             trim silence after synthesis. Defaults to False.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)

     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)

     if reference_d_vector is not None:
-        reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
+        reference_d_vector = embedding_to_torch(reference_d_vector, device=device)

     # load reference_wav audio
     reference_wav = embedding_to_torch(
         model.ap.load_wav(
             reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate
         ),
-        cuda=use_cuda,
+        device=device,
     )

     if hasattr(model, "module"):
```
`TTS/utils/synthesizer.py`
```diff
@@ -5,6 +5,7 @@
 import numpy as np
 import pysbd
 import torch
+from torch import nn

 from TTS.config import load_config
 from TTS.tts.configs.vits_config import VitsConfig
```
```diff
@@ -21,7 +22,7 @@
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input


-class Synthesizer(object):
+class Synthesizer(nn.Module):
     def __init__(
         self,
         tts_checkpoint: str = "",
```
```diff
@@ -60,6 +61,7 @@ def __init__(
             vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
             use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
+        super().__init__()
         self.tts_checkpoint = tts_checkpoint
         self.tts_config_path = tts_config_path
         self.tts_speakers_file = tts_speakers_file
```
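Inheriting from `nn.Module` (together with the `super().__init__()` call above) means any model assigned as an attribute is registered as a submodule, so the whole pipeline can be moved with a single `.to(device)` call. A rough sketch of the mechanism, using placeholder modules rather than real TTS components:

```python
import torch
from torch import nn

class ToyPipeline(nn.Module):  # same idea as Synthesizer(nn.Module)
    def __init__(self):
        super().__init__()  # must run before submodules are assigned
        self.tts_model = nn.Linear(8, 8)      # placeholder for the TTS model
        self.vocoder_model = nn.Linear(8, 8)  # placeholder for the vocoder

pipe = ToyPipeline()
pipe.to("cpu")  # one call moves every registered submodule
print(next(pipe.tts_model.parameters()).device)  # cpu
```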
```diff
@@ -356,7 +358,12 @@ def tts(
         if speaker_wav is not None and self.tts_model.speaker_manager is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

+        vocoder_device = "cpu"
         use_gl = self.vocoder_model is None
+        if not use_gl:
+            vocoder_device = next(self.vocoder_model.parameters()).device
+        if self.use_cuda:
+            vocoder_device = "cuda"

         if not reference_wav:  # not voice conversion
             for sen in sens:
```

**Review comment:** In some obscure use cases, the user could have placed the feature frontend and the vocoder on different devices:

```python
>>> tts.synthesizer.tts_model = tts.synthesizer.tts_model.to("cuda:0")
>>> tts.synthesizer.vocoder_model = tts.synthesizer.vocoder_model.to("cuda:1")
```

We check the device of the vocoder, if it exists.
```diff
@@ -388,7 +395,6 @@ def tts(
                     mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
                     # denormalize tts output based on tts audio config
                     mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-                    device_type = "cuda" if self.use_cuda else "cpu"
                     # renormalize spectrogram based on vocoder config
                     vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                     # compute scale factor for possible sample rate mismatch
```
```diff
@@ -403,8 +409,8 @@ def tts(
                         vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                     # run vocoder model
                     # [1, T, C]
-                    waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-                if self.use_cuda and not use_gl:
+                    waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+                if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
                     waveform = waveform.cpu()
                 if not use_gl:
                     waveform = waveform.numpy()
```
```diff
@@ -453,7 +459,6 @@ def tts(
             mel_postnet_spec = outputs[0].detach().cpu().numpy()
             # denormalize tts output based on tts audio config
             mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-            device_type = "cuda" if self.use_cuda else "cpu"
             # renormalize spectrogram based on vocoder config
             vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
             # compute scale factor for possible sample rate mismatch
```
```diff
@@ -468,8 +473,8 @@ def tts(
                 vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
             # run vocoder model
             # [1, T, C]
-            waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-        if self.use_cuda:
+            waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+        if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
             waveform = waveform.cpu()
         if not use_gl:
             waveform = waveform.numpy()
```
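The new guard copies the waveform back to the host only when it is actually a tensor on a non-CPU device; Griffin-Lim output, already a NumPy array, passes through untouched. A small sketch of the check (the `to_host` helper is hypothetical, not part of the PR):

```python
import torch

def to_host(waveform):
    # mirror of the guard in ``Synthesizer.tts``: only real tensors
    # living off-CPU are copied back to host memory
    if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
        waveform = waveform.cpu()
    return waveform

print(to_host(torch.ones(4)).device)  # cpu; returned unchanged
```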
**Review comment:** Added warning. We could add specific dates or versions to better inform users about future plans, but I left it this way because I didn't have enough context on the future releases roadmap.