diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index b6e9ac8a14..95d79d2cbc 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -398,7 +398,7 @@ def _forward_encoder( """ if hasattr(self, "emb_g"): g = g.type(torch.LongTensor) - g = self.emb_g(g) # [B, C, 1] + g = self.emb_g(g.to("cuda") if torch.cuda.is_available() else g.to("cpu")) # [B, C, 1] if g is not None: g = g.unsqueeze(-1) # [B, T, C]