Skip to content

Commit

Permalink
Fix Tortoise load (#2697)
Browse files Browse the repository at this point in the history
* Handle missing gpt weights

* Make style

* Fix lint
  • Loading branch information
erogol authored Jun 21, 2023
1 parent d658194 commit 4cf8652
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion TTS/tts/models/tortoise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import random
import re
from contextlib import contextmanager
from dataclasses import dataclass
from time import time
Expand Down Expand Up @@ -871,7 +872,16 @@ def load_checkpoint(
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")

if os.path.exists(ar_path):
self.autoregressive.load_state_dict(torch.load(ar_path), strict=strict)
keys_to_ignore = self.autoregressive.gpt._keys_to_ignore_on_load_missing # pylint: disable=protected-access
# remove keys from the checkpoint that are not in the model
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"))
for key in list(checkpoint.keys()):
for pat in keys_to_ignore:
if re.search(pat, key) is not None:
del checkpoint[key]
break

self.autoregressive.load_state_dict(checkpoint, strict=strict)

if os.path.exists(diff_path):
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
Expand Down

0 comments on commit 4cf8652

Please sign in to comment.