Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device support in TTS and Synthesizer #2855

Merged
merged 9 commits into from
Aug 14, 2023
8 changes: 7 additions & 1 deletion TTS/api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import tempfile
import warnings
from pathlib import Path
from typing import Union

import numpy as np
from torch import nn

from TTS.cs_api import CS_API
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer


class TTS:
class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""

def __init__(
Expand Down Expand Up @@ -62,6 +64,7 @@ def __init__(
Defaults to "XTTS".
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
super().__init__()
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)

self.synthesizer = None
Expand All @@ -70,6 +73,9 @@ def __init__(
self.cs_api_model = cs_api_model
self.model_name = None

if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
Copy link
Contributor Author

@jaketae jaketae Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added warning. We could add specific dates or versions to better inform users about future plans, but I left it this way because I didn't have enough context on the future releases roadmap.


if model_name is not None:
if "tts_models" in model_name or "coqui_studio" in model_name:
self.load_tts_model_by_name(model_name, gpu)
Expand Down
66 changes: 39 additions & 27 deletions TTS/tts/utils/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
from torch import nn


def numpy_to_torch(np_array, dtype, cuda=False):
def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added new device argument to functions called in Synthesizer. To retain backwards compatibility, we keep the cuda argument for now; we should probably clean them up in the future and provide a single way of configuring device/enabling CUDA.

if cuda:
device = "cuda"
if np_array is None:
return None
tensor = torch.as_tensor(np_array, dtype=dtype)
if cuda:
return tensor.cuda()
tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
return tensor


def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
if cuda:
return style_mel.cuda()
device = "cuda"
style_mel = torch.FloatTensor(
ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device,
).unsqueeze(0)
return style_mel


Expand Down Expand Up @@ -73,22 +75,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav


def id_to_torch(aux_id, cuda=False):
def id_to_torch(aux_id, cuda=False, device="cpu"):
if cuda:
device = "cuda"
if aux_id is not None:
aux_id = np.asarray(aux_id)
aux_id = torch.from_numpy(aux_id)
if cuda:
return aux_id.cuda()
aux_id = torch.from_numpy(aux_id).to(device)
return aux_id


def embedding_to_torch(d_vector, cuda=False):
def embedding_to_torch(d_vector, cuda=False, device="cpu"):
if cuda:
device = "cuda"
if d_vector is not None:
d_vector = np.asarray(d_vector)
d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
d_vector = d_vector.squeeze().unsqueeze(0)
if cuda:
return d_vector.cuda()
d_vector = d_vector.squeeze().unsqueeze(0).to(device)
return d_vector


Expand Down Expand Up @@ -162,17 +164,22 @@ def synthesis(
language_id (int):
Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
"""
# device
device = next(model.parameters()).device
if use_cuda:
device = "cuda"

# GST or Capacitron processing
# TODO: need to handle the case of setting both gst and capacitron to true somewhere
style_mel = None
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = compute_style_mel(style_wav, model.ap, device=device)

if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = compute_style_mel(style_wav, model.ap, device=device)
style_mel = style_mel.transpose(1, 2) # [1, time, depth]

language_name = None
Expand All @@ -188,26 +195,26 @@ def synthesis(
)
# pass tensors to backend
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
speaker_id = id_to_torch(speaker_id, device=device)

if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
d_vector = embedding_to_torch(d_vector, device=device)

if language_id is not None:
language_id = id_to_torch(language_id, cuda=use_cuda)
language_id = id_to_torch(language_id, device=device)

if not isinstance(style_mel, dict):
# GST or Capacitron style mel
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
style_mel = numpy_to_torch(style_mel, torch.float, device=device)
if style_text is not None:
style_text = np.asarray(
model.tokenizer.text_to_ids(style_text, language=language_id),
dtype=np.int32,
)
style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
style_text = numpy_to_torch(style_text, torch.long, device=device)
style_text = style_text.unsqueeze(0)

text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
outputs = run_model_torch(
Expand Down Expand Up @@ -290,22 +297,27 @@ def transfer_voice(
do_trim_silence (bool):
trim silence after synthesis. Defaults to False.
"""
# device
device = next(model.parameters()).device
if use_cuda:
device = "cuda"

# pass tensors to backend
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
speaker_id = id_to_torch(speaker_id, device=device)

if d_vector is not None:
d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
d_vector = embedding_to_torch(d_vector, device=device)

if reference_d_vector is not None:
reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
reference_d_vector = embedding_to_torch(reference_d_vector, device=device)

# load reference_wav audio
reference_wav = embedding_to_torch(
model.ap.load_wav(
reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate
),
cuda=use_cuda,
device=device,
)

if hasattr(model, "module"):
Expand Down
19 changes: 12 additions & 7 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pysbd
import torch
from torch import nn

from TTS.config import load_config
from TTS.tts.configs.vits_config import VitsConfig
Expand All @@ -21,7 +22,7 @@
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input


class Synthesizer(object):
class Synthesizer(nn.Module):
def __init__(
self,
tts_checkpoint: str = "",
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
"""
super().__init__()
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
Expand Down Expand Up @@ -356,7 +358,12 @@ def tts(
if speaker_wav is not None and self.tts_model.speaker_manager is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

vocoder_device = "cpu"
use_gl = self.vocoder_model is None
if not use_gl:
vocoder_device = next(self.vocoder_model.parameters()).device
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some obscure use cases, the user could have placed the feature frontend and the vocoder on different devices.

>>> tts.synthesizer.tts_model = tts.synthesizer.tts_model.to("cuda:0")
>>> tts.synthesizer.vocoder_model = tts.synthesizer.vocoder_model.to("cuda:1")

We check the device of the vocoder, if it exists.

if self.use_cuda:
vocoder_device = "cuda"

if not reference_wav: # not voice conversion
for sen in sens:
Expand Down Expand Up @@ -388,7 +395,6 @@ def tts(
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
# denormalize tts output based on tts audio config
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
device_type = "cuda" if self.use_cuda else "cpu"
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
Expand All @@ -403,8 +409,8 @@ def tts(
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
# run vocoder model
# [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
if self.use_cuda and not use_gl:
waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
waveform = waveform.cpu()
if not use_gl:
waveform = waveform.numpy()
Expand Down Expand Up @@ -453,7 +459,6 @@ def tts(
mel_postnet_spec = outputs[0].detach().cpu().numpy()
# denormalize tts output based on tts audio config
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
device_type = "cuda" if self.use_cuda else "cpu"
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
Expand All @@ -468,8 +473,8 @@ def tts(
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
# run vocoder model
# [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
if self.use_cuda:
waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
waveform = waveform.cpu()
if not use_gl:
waveform = waveform.numpy()
Expand Down