From feefcf256e78a5f8d60c3a940f2be3b5c3ca335d Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Thu, 18 Apr 2024 23:15:36 -0400 Subject: [PATCH 1/2] Display name of error latent file When trying to load stored latents, if an error occurs, this change will tell you what file failed to load Currently it will just tell you that something failed without telling you which file --- library/train_util.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..58527fa00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2123,18 +2123,21 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): if not os.path.exists(npz_path): return False - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: + if npz["latents"].shape[1:3] != expected_latents_size: return False + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + except: + raise RuntimeError(f"Error loading file: {npz_path}") + return True From 040e26ff1d8f855f52cdfb62781e06284c5e9e34 Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Sun, 21 Apr 2024 13:46:31 -0400 Subject: [PATCH 2/2] Regenerate failed file If a latent file fails to load, print out the path and the error, then return false to regenerate it --- library/train_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 58527fa00..4168a41fb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2135,8 +2135,10 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False - except: - raise RuntimeError(f"Error loading file: {npz_path}") + except Exception as e: + print(npz_path) + print(e) + return False return True