diff --git a/TTS/api.py b/TTS/api.py
index 7abc188e74..2277ff270c 100644
--- a/TTS/api.py
+++ b/TTS/api.py
@@ -283,6 +283,7 @@ def tts(
             style_text=None,
             reference_speaker_name=None,
             split_sentences=split_sentences,
+            speed=speed,
             **kwargs,
         )
         return wav
@@ -330,13 +331,13 @@ def tts_to_file(
                 Additional arguments for the model.
         """
         self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
-
         wav = self.tts(
             text=text,
             speaker=speaker,
             language=language,
             speaker_wav=speaker_wav,
             split_sentences=split_sentences,
+            speed=speed,
             **kwargs,
         )
         self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 7871cc38c3..371db3c482 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -4,6 +4,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
@@ -76,6 +77,33 @@ def _set_model_args(self, config: Coqpit):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
+    def adjust_speech_rate(self, gpt_latents, length_scale):
+        if abs(length_scale - 1.0) < 1e-6:
+            return gpt_latents
+
+        B, L, D = gpt_latents.shape
+        target_length = int(L * length_scale)
+
+        assert target_length > 0, f"Invalid target length: {target_length}"
+
+        try:
+            resized = F.interpolate(
+                gpt_latents.transpose(1, 2),
+                size=target_length,
+                mode="linear",
+                align_corners=True
+            ).transpose(1, 2)
+
+            if torch.isnan(resized).any():
+                print("Warning: NaN values detected in adjust_speech_rate")
+                return gpt_latents
+
+            return resized
+
+        except RuntimeError as e:
+            print(f"Interpolation failed: {e}")
+            return gpt_latents
+
     def init_multispeaker(self, config: Coqpit, data: List = None):
         """Initialize a speaker embedding layer if needen and define expected embedding
         channel size for defining `in_channels` size of the connected layers.
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 8e9d6bd382..a109c6e78b 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -379,7 +379,7 @@ def get_conditioning_latents(
 
         return gpt_cond_latents, speaker_embedding
 
-    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
+    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed: float = 1.0, **kwargs):
         """Synthesize speech with the given input text.
 
         Args:
@@ -409,14 +409,14 @@ def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwa
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
-            return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
+            return self.inference(text, language, gpt_cond_latent, speaker_embedding, speed=speed, **settings)
         settings.update({
             "gpt_cond_len": config.gpt_cond_len,
             "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         })
-        return self.full_inference(text, speaker_wav, language, **settings)
+        return self.full_inference(text, speaker_wav, language, speed=speed, **settings)
 
     @torch.inference_mode()
     def full_inference(
@@ -436,6 +436,7 @@ def full_inference(
         gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
+        speed: float = 1.0,
         **hf_generate_kwargs,
     ):
         """
@@ -496,6 +497,7 @@ def full_inference(
             top_k=top_k,
             top_p=top_p,
             do_sample=do_sample,
+            speed=speed,
             **hf_generate_kwargs,
         )
 
@@ -569,10 +571,7 @@ def inference(
             )
 
             if length_scale != 1.0:
-                gpt_latents = F.interpolate(
-                    gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
-                ).transpose(1, 2)
-
+                gpt_latents = self.adjust_speech_rate(gpt_latents, length_scale)
             gpt_latents_list.append(gpt_latents.cpu())
             wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
 
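Usage sketch (not part of the patch): how the new `speed` argument would flow through the public API. This assumes `TTS.tts` and `TTS.tts_to_file` gain a `speed` parameter in their signatures (the call sites above pass it, but the signature hunks are not shown here) and that `Xtts.inference` maps `speed` to `length_scale` roughly as its reciprocal, so `speed > 1.0` shortens the latent sequence and plays back faster. The model name below is the stock XTTS v2 checkpoint; the file paths are hypothetical.

    from TTS.api import TTS

    # Load the multilingual XTTS v2 model and move it to GPU.
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda")

    # Clone the voice in reference.wav (hypothetical clip) and render the
    # text about 1.3x faster; speed=0.8 would slow it down instead.
    tts.tts_to_file(
        text="This sentence should come out about thirty percent faster.",
        speaker_wav="reference.wav",
        language="en",
        speed=1.3,
        file_path="output_fast.wav",
    )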