Skip to content

Commit

Permalink
feat: add adjust_speech_rate function to modify speech speed with mor…
Browse files Browse the repository at this point in the history
…e durable latents. Also added missing TTS speed implementations.
  • Loading branch information
isikhi committed Dec 28, 2024
1 parent dbf1a08 commit 26128be
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
3 changes: 2 additions & 1 deletion TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def tts(
style_text=None,
reference_speaker_name=None,
split_sentences=split_sentences,
speed=speed,
**kwargs,
)
return wav
Expand Down Expand Up @@ -330,13 +331,13 @@ def tts_to_file(
Additional arguments for the model.
"""
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)

wav = self.tts(
text=text,
speaker=speaker,
language=language,
speaker_wav=speaker_wav,
split_sentences=split_sentences,
speed=speed,
**kwargs,
)
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
Expand Down
28 changes: 28 additions & 0 deletions TTS/tts/models/base_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.distributed as dist
import torch.nn.functional as F
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -76,6 +77,33 @@ def _set_model_args(self, config: Coqpit):
else:
raise ValueError("config must be either a *Config or *Args")

def adjust_speech_rate(self, gpt_latents, length_scale):
    """Resample GPT latents along the time axis to change speech speed.

    Args:
        gpt_latents (torch.Tensor): Latent tensor of shape (B, L, D)
            (batch, time steps, latent dim) — inferred from the
            ``transpose(1, 2)`` around ``F.interpolate``.
        length_scale (float): Multiplier applied to the time length.
            Values > 1 stretch the latents (slower speech); values < 1
            compress them (faster speech).

    Returns:
        torch.Tensor: Resampled latents of shape
        ``(B, int(L * length_scale), D)``. The original tensor is
        returned unchanged when the scale is ~1.0, when interpolation
        fails, or when the result contains NaNs (best-effort fallback
        so synthesis still proceeds).

    Raises:
        ValueError: If the computed target length is not positive.
    """
    # No-op for scales indistinguishable from 1.0 — avoids a pointless resample.
    if abs(length_scale - 1.0) < 1e-6:
        return gpt_latents

    seq_len = gpt_latents.shape[1]
    target_length = int(seq_len * length_scale)
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let a zero/negative size reach F.interpolate.
    if target_length <= 0:
        raise ValueError(f"Invalid target length: {target_length}")

    try:
        # F.interpolate resamples the last dim, so swap to (B, D, L) and back.
        # align_corners=True keeps the first and last frames anchored.
        resized = F.interpolate(
            gpt_latents.transpose(1, 2),
            size=target_length,
            mode="linear",
            align_corners=True,
        ).transpose(1, 2)
    except RuntimeError as e:
        # Best-effort: fall back to the unmodified latents rather than aborting TTS.
        print(f"Interpolation failed: {e}")
        return gpt_latents

    if torch.isnan(resized).any():
        print("Warning: NaN values detected on adjust speech rate")
        return gpt_latents

    return resized

def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
Expand Down
13 changes: 6 additions & 7 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def get_conditioning_latents(

return gpt_cond_latents, speaker_embedding

def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed: float = 1.0, **kwargs):
"""Synthesize speech with the given input text.
Args:
Expand Down Expand Up @@ -409,14 +409,14 @@ def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwa
settings.update(kwargs) # allow overriding of preset settings with kwargs
if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
return self.inference(text, language, gpt_cond_latent, speaker_embedding, speed=speed, **settings)
settings.update({
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
})
return self.full_inference(text, speaker_wav, language, **settings)
return self.full_inference(text, speaker_wav, language, speed=speed, **settings)

@torch.inference_mode()
def full_inference(
Expand All @@ -436,6 +436,7 @@ def full_inference(
gpt_cond_chunk_len=6,
max_ref_len=10,
sound_norm_refs=False,
speed: float = 1.0,
**hf_generate_kwargs,
):
"""
Expand Down Expand Up @@ -496,6 +497,7 @@ def full_inference(
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
speed=speed,
**hf_generate_kwargs,
)

Expand Down Expand Up @@ -569,10 +571,7 @@ def inference(
)

if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)

gpt_latents = self.adjust_speech_rate(gpt_latents, length_scale)
gpt_latents_list.append(gpt_latents.cpu())
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())

Expand Down

0 comments on commit 26128be

Please sign in to comment.