diff --git a/library/train_util.py b/library/train_util.py index 20f8055dc..6cf285903 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -409,6 +409,7 @@ def __init__( self.alpha_mask = alpha_mask + class DreamBoothSubset(BaseSubset): def __init__( self, @@ -417,13 +418,47 @@ def __init__( class_tokens: Optional[str], caption_extension: str, cache_info: bool, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator: str, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.is_reg = is_reg @@ -444,13 +479,47 @@ def __init__( self, image_dir, metadata_file: str, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.metadata_file = metadata_file @@ -468,13 +537,47 @@ def __init__( conditioning_data_dir: str, caption_extension: str, cache_info: bool, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.conditioning_data_dir = conditioning_data_dir @@ -1100,10 +1203,12 @@ def __getitem__(self, index): else: latents = image_info.latents_flipped alpha_mask = image_info.alpha_mask_flipped - + image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk( + image_info.latents_npz + ) if flipped: latents = flipped_latents alpha_mask = flipped_alpha_mask @@ -1116,7 +1221,9 @@ def __getitem__(self, index): image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, image_info.absolute_path, subset.alpha_mask + ) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1157,7 +1264,7 @@ def __getitem__(self, index): if img.shape[2] == 4: alpha_mask = img[:, :, 3] # [W,H] else: - alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] + alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] alpha_mask = transforms.ToTensor()(alpha_mask) else: alpha_mask = None @@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: +) -> Tuple[ + Optional[torch.Tensor], + Optional[List[int]], + Optional[List[int]], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2084,7 +2198,9 @@ def load_latents_from_disk( return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None): +def save_latents_to_disk( + npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None +): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() @@ -2344,10 +2460,10 @@ def cache_batch_latents( image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) if info.use_alpha_mask: if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [W,H] + alpha_mask = image[:, :, 3] # [W,H] image = image[:, :, :3] else: - alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] + alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] alpha_masks.append(transforms.ToTensor()(alpha_mask)) image = IMAGE_TRANSFORMS(image) images.append(image) @@ -2377,13 +2493,23 @@ def cache_batch_latents( flipped_latents = [None] * len(latents) flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks): + for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip( + image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks + ): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask) + save_latents_to_disk( + info.latents_npz, + latent, + info.latents_original_size, + info.latents_crop_ltrb, + flipped_latent, + alpha_mask, + flipped_alpha_mask, + ) else: info.latents = latent if flip_aug: