Skip to content

Commit

Permalink
fix to work cond mask and alpha mask
Browse files (browse the repository at this point in the history)
  • Loading branch information
kohya-ss committed May 26, 2024
1 parent da6fea3 commit e8cfd4b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
3 changes: 2 additions & 1 deletion library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class BaseSubsetParams:
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
alpha_mask: bool = False


@dataclass
Expand All @@ -87,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
class_tokens: Optional[str] = None
caption_extension: str = ".caption"
cache_info: bool = False
alpha_mask: bool = False


# Dataclass of per-subset parameters that extends BaseSubsetParams with the
# two fields declared below (a fine-tuning subset, going by the class name).
@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
# Optional string, None by default.
# NOTE(review): presumably a path to the subset's metadata file -- confirm against the caller.
metadata_file: Optional[str] = None
# Boolean flag, False by default.
# NOTE(review): presumably enables using the image alpha channel as a loss mask -- confirm against the dataset code.
alpha_mask: bool = False


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,11 @@ def apply_masked_loss(loss, batch):
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
# print(f"conditioning_image: {mask_image.shape}")
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
else:
return loss

Expand Down
12 changes: 12 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ def __init__(

super().__init__(
image_dir,
False, # alpha_mask
num_repeats,
shuffle_caption,
caption_separator,
Expand Down Expand Up @@ -1947,6 +1948,7 @@ def __init__(
None,
subset.caption_extension,
subset.cache_info,
False,
subset.num_repeats,
subset.shuffle_caption,
subset.caption_separator,
Expand Down Expand Up @@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
return False
if npz["alpha_mask"].shape[0:2] != reso: # HxW
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
Expand Down Expand Up @@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False):
if os.name == "nt":
cv2.imshow("cond_img", cond_img)

if "alpha_masks" in example and example["alpha_masks"] is not None:
alpha_mask = example["alpha_masks"][j]
logger.info(f"alpha mask size: {alpha_mask.size()}")
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
if os.name == "nt":
cv2.imshow("alpha_mask", alpha_mask)

if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
Expand Down

0 comments on commit e8cfd4b

Please sign in to comment.