From e8cfd4ba1d4734c4dd37c9b5fdc0633378879d9b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 May 2024 22:01:37 +0900 Subject: [PATCH] fix to work cond mask and alpha mask --- library/config_util.py | 3 ++- library/custom_train_functions.py | 4 +++- library/train_util.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 964270dbb..10b2457f3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -78,7 +78,6 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 - alpha_mask: bool = False @dataclass @@ -87,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams): class_tokens: Optional[str] = None caption_extension: str = ".caption" cache_info: bool = False + alpha_mask: bool = False @dataclass class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None + alpha_mask: bool = False @dataclass diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index af5813a1d..2a513dc5b 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -484,9 +484,11 @@ def apply_masked_loss(loss, batch): # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel mask_image = mask_image / 2 + 0.5 + # print(f"conditioning_image: {mask_image.shape}") elif "alpha_masks" in batch and batch["alpha_masks"] is not None: # alpha mask is 0 to 1 - mask_image = batch["alpha_masks"].to(dtype=loss.dtype) + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") else: return loss diff --git a/library/train_util.py b/library/train_util.py index e7a50f04d..1f9f3c5df 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -561,6 +561,7 @@ def __init__( super().__init__( image_dir, + False, # alpha_mask num_repeats, shuffle_caption, caption_separator, @@ -1947,6 +1948,7 @@ def __init__( None, subset.caption_extension, subset.cache_info, + False, subset.num_repeats, subset.shuffle_caption, subset.caption_separator, @@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph return False if npz["alpha_mask"].shape[0:2] != reso: # HxW return False + else: + if "alpha_mask" in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False): if os.name == "nt": cv2.imshow("cond_img", cond_img) + if "alpha_masks" in example and example["alpha_masks"] is not None: + alpha_mask = example["alpha_masks"][j] + logger.info(f"alpha mask size: {alpha_mask.size()}") + alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + if os.name == "nt": + cv2.imshow("alpha_mask", alpha_mask) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey()