diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index 5c81638262..888177252e 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -1,5 +1,6 @@ import os import random +import re from contextlib import contextmanager from dataclasses import dataclass from time import time @@ -871,7 +872,16 @@ def load_checkpoint( vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") if os.path.exists(ar_path): - self.autoregressive.load_state_dict(torch.load(ar_path), strict=strict) + keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access + # remove keys from the checkpoint that are not in the model + checkpoint = torch.load(ar_path, map_location=torch.device("cpu")) + for key in list(checkpoint.keys()): + for pat in keys_to_ignore: + if re.search(pat, key) is not None: + del checkpoint[key] + break + + self.autoregressive.load_state_dict(checkpoint, strict=strict) if os.path.exists(diff_path): self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)