revert kwargs to explicit declaration
kohya-ss committed May 19, 2024
1 parent db67529 commit f2dd43e
Showing 1 changed file with 142 additions and 16 deletions.
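For background on the change, here is a minimal sketch of the two styles (hypothetical Base/ChildKwargs/ChildExplicit classes, not the real subset classes in this diff): forwarding **kwargs keeps a subclass signature short but hides the accepted parameter names from readers, type checkers, and IDEs, while explicit declaration surfaces every name and makes a misspelled keyword fail at the constructor itself.

# Minimal sketch; Base, ChildKwargs and ChildExplicit are hypothetical,
# not classes from library/train_util.py.
class Base:
    def __init__(self, num_repeats, shuffle_caption):
        self.num_repeats = num_repeats
        self.shuffle_caption = shuffle_caption


class ChildKwargs(Base):  # the style being reverted: accepted names are invisible here
    def __init__(self, image_dir, **kwargs):
        super().__init__(**kwargs)
        self.image_dir = image_dir


class ChildExplicit(Base):  # the style this commit restores: explicit and checkable
    def __init__(self, image_dir, num_repeats, shuffle_caption):
        super().__init__(num_repeats, shuffle_caption)
        self.image_dir = image_dir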
158 changes: 142 additions & 16 deletions library/train_util.py
@@ -409,6 +409,7 @@ def __init__(
 
         self.alpha_mask = alpha_mask
 
+
 class DreamBoothSubset(BaseSubset):
     def __init__(
         self,
@@ -417,13 +418,47 @@ def __init__(
         class_tokens: Optional[str],
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator: str,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.is_reg = is_reg
@@ -444,13 +479,47 @@ def __init__(
         self,
         image_dir,
         metadata_file: str,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.metadata_file = metadata_file
@@ -468,13 +537,47 @@ def __init__(
         conditioning_data_dir: str,
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.conditioning_data_dir = conditioning_data_dir
@@ -1100,10 +1203,12 @@ def __getitem__(self, index):
                 else:
                     latents = image_info.latents_flipped
                     alpha_mask = image_info.alpha_mask_flipped
 
                 image = None
             elif image_info.latents_npz is not None:  # for FineTuningDataset, or when cache_latents_to_disk=True
-                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
+                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
+                    image_info.latents_npz
+                )
                 if flipped:
                     latents = flipped_latents
                     alpha_mask = flipped_alpha_mask
@@ -1116,7 +1221,9 @@ def __getitem__(self, index):
                 image = None
             else:
                 # load the image and crop it if necessary
-                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
+                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
+                    subset, image_info.absolute_path, subset.alpha_mask
+                )
                 im_h, im_w = img.shape[0:2]
 
                 if self.enable_bucket:
@@ -1157,7 +1264,7 @@ def __getitem__(self, index):
                     if img.shape[2] == 4:
                         alpha_mask = img[:, :, 3]  # [W,H]
                     else:
-                        alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
+                        alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8)  # [W,H]
                     alpha_mask = transforms.ToTensor()(alpha_mask)
                 else:
                     alpha_mask = None
@@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
 # return values are latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+) -> Tuple[
+    Optional[torch.Tensor],
+    Optional[List[int]],
+    Optional[List[int]],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2084,7 +2198,9 @@ def load_latents_from_disk(
     return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
 
 
-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
+def save_latents_to_disk(
+    npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
+):
     kwargs = {}
     if flipped_latents_tensor is not None:
         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
@@ -2344,10 +2460,10 @@ def cache_batch_latents(
         image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
         if info.use_alpha_mask:
             if image.shape[2] == 4:
-                alpha_mask = image[:, :, 3] # [W,H]
+                alpha_mask = image[:, :, 3]  # [W,H]
                 image = image[:, :, :3]
             else:
-                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
+                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8)  # [W,H]
             alpha_masks.append(transforms.ToTensor()(alpha_mask))
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
@@ -2377,13 +2493,23 @@ def cache_batch_latents(
         flipped_latents = [None] * len(latents)
         flipped_alpha_masks = [None] * len(image_infos)
 
-    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
+    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
+        image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
+    ):
         # check NaN
         if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
             raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
 
         if cache_to_disk:
-            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
+            save_latents_to_disk(
+                info.latents_npz,
+                latent,
+                info.latents_original_size,
+                info.latents_crop_ltrb,
+                flipped_latent,
+                alpha_mask,
+                flipped_alpha_mask,
+            )
         else:
             info.latents = latent
             if flip_aug:
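As a usage note for the two helpers reformatted above, a minimal round-trip sketch based only on the signatures visible in this diff; the path, sizes, and tensor shape are illustrative, and the import assumes the sd-scripts repository layout:

import torch
from library import train_util

latents = torch.randn(4, 64, 64)  # illustrative latent tensor
train_util.save_latents_to_disk(
    "sample.npz",
    latents,
    original_size=(512, 512),
    crop_ltrb=(0, 0, 0, 0),  # illustrative crop values
    flipped_latents_tensor=None,
    alpha_mask=None,
    flipped_alpha_mask=None,
)

# load_latents_from_disk returns a 6-tuple; per its annotation the optional
# entries (the flipped latents and alpha masks here) come back as None.
latents2, original_size, crop_ltrb, flipped, alpha_mask, flipped_alpha_mask = train_util.load_latents_from_disk(
    "sample.npz"
)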
