Skip to content

Commit

Permalink
simplify and update alpha mask to work with various cases
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed May 19, 2024
1 parent f2dd43e commit da6fea3
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 104 deletions.
33 changes: 27 additions & 6 deletions finetune/prepare_buckets_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@

import torch
from library.device_utils import init_ipex, get_preferred_device

init_ipex()

from torchvision import transforms

import library.model_util as model_util
import library.train_util as train_util
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

DEVICE = get_preferred_device()
Expand Down Expand Up @@ -89,7 +92,9 @@ def main(args):

# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
Expand All @@ -107,7 +112,7 @@ def main(args):
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False)
bucket.clear()

# 読み込みの高速化のためにDataLoaderを使うオプション
Expand Down Expand Up @@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
Expand All @@ -231,18 +238,32 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
)
parser.add_argument(
"--alpha_mask",
type=str,
default="",
help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する",
)
parser.add_argument(
"--skip_existing",
Expand Down
2 changes: 2 additions & 0 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,13 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
"is_reg": bool,
"alpha_mask": bool,
}
# FT means FineTuning
FT_SUBSET_DISTINCT_SCHEMA = {
Required("metadata_file"): str,
"image_dir": str,
"alpha_mask": bool,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
Expand Down
15 changes: 10 additions & 5 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,14 +479,19 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def apply_masked_loss(loss, mask_image):
# mask image is -1 to 1. we need to convert it to 0 to 1
# mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image.to(dtype=loss.dtype)
def apply_masked_loss(loss, batch):
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
else:
return loss

# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss

Expand Down
Loading

0 comments on commit da6fea3

Please sign in to comment.