From f2c727fc8cadf0971c24fdb42c8684032e7e6f80 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 26 Feb 2024 23:19:58 +0900 Subject: [PATCH 1/6] add minimal impl for masked loss --- library/config_util.py | 38 +++++++++++++++++++++++++------------- library/train_util.py | 3 +++ train_network.py | 18 +++++++++++++++++- 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index eb652ecf3..edc6a5385 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -41,12 +41,17 @@ DatasetGroup, ) from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) # TODO: inherit Params class in Subset, Dataset @@ -248,9 +253,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] } def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert ( - support_dreambooth or support_finetuning or support_controlnet - ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." + + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -362,7 +368,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -547,11 +555,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu " ", ) - logger.info(f'{info}') + logger.info(f"{info}") # make buckets first because it determines the length of dataset # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): logger.info(f"[Dataset {i}]") dataset.make_buckets() @@ -638,13 +646,17 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -671,13 +683,13 @@ def load_user_config(file: str) -> dict: train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') + logger.info(f"{vars(argparse_namespace)}") user_config = load_user_config(config_args.dataset_config) logger.info("") logger.info("[user_config]") - logger.info(f'{user_config}') + logger.info(f"{user_config}") sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout @@ -686,10 +698,10 @@ def load_user_config(file: str) -> dict: logger.info("") logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') + logger.info(f"{sanitized_user_config}") blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) logger.info("") logger.info("[blueprint]") - logger.info(f'{blueprint}') + logger.info(f"{blueprint}") diff --git a/library/train_util.py b/library/train_util.py index b71e4edc6..7fe5bc56e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1810,6 +1810,9 @@ def __init__( db_subsets = [] for subset in subsets: + assert ( + not subset.random_crop + ), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません" db_subset = DreamBoothSubset( subset.image_dir, False, diff --git a/train_network.py b/train_network.py index e5b26d8a2..e3ce7bd36 100644 --- a/train_network.py +++ b/train_network.py @@ -13,6 +13,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -157,7 +158,7 @@ def train(self, args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) if use_user_config: logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -834,6 +835,16 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + + if args.masked_loss: + # mask image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + mask_image = mask_image / 2 + 0.5 + loss = loss * mask_image + loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -1050,6 +1061,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--masked_loss", + action="store_true", + help="apply mask for caclulating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + ) return parser From 175193623b39027ffcfe0c0ae250dbce564ed6ef Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 26 Feb 2024 23:29:41 +0900 Subject: [PATCH 2/6] update readme --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index e1b6a26c3..9cc79cc09 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,13 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Masked loss + +`train_network.py` and `sdxl_train_network.py` now support the masked loss. `--masked_loss` option is added. + +ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset). + + ### Working in progress - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). From 4a5546d40e6de5789be78dd16373d2b820b8754e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 26 Feb 2024 23:39:56 +0900 Subject: [PATCH 3/6] fix typo --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index e3ce7bd36..f5617986c 100644 --- a/train_network.py +++ b/train_network.py @@ -1064,7 +1064,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--masked_loss", action="store_true", - help="apply mask for caclulating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", ) return parser From a9b64ffba8efbb0991a094e38b1f5d5c56680caf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Feb 2024 21:43:55 +0900 Subject: [PATCH 4/6] support masked loss in sdxl_train ref #589 --- README.md | 4 +++- sdxl_train.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9cc79cc09..354983c38 100644 --- a/README.md +++ b/README.md @@ -251,7 +251,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ### Masked loss -`train_network.py` and `sdxl_train_network.py` now support the masked loss. `--masked_loss` option is added. +`train_network.py`, `sdxl_train_network.py` and `sdxl_train.py` now support the masked loss. `--masked_loss` option is added. + +NOTE: `train_network.py` and `sdxl_train.py` are not tested yet. ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset). diff --git a/sdxl_train.py b/sdxl_train.py index e0df263d6..448a160f6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,6 +11,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -124,7 +125,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -579,6 +580,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + + if args.masked_loss: + # mask image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + mask_image = mask_image / 2 + 0.5 + loss = loss * mask_image + loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -780,6 +791,13 @@ def setup_parser() -> argparse.ArgumentParser: + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + # TODO common masked_loss argument + parser.add_argument( + "--masked_loss", + action="store_true", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + ) + return parser From 7081a0cf0f1ca1a543edf7cab10c4c7d497348ca Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 17 Mar 2024 18:09:15 +0900 Subject: [PATCH 5/6] extension of src image could be different than target image --- library/train_util.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 7fe5bc56e..0f8cf9eea 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1863,7 +1863,7 @@ def __init__( # assert all conditioning data exists missing_imgs = [] - cond_imgs_with_img = set() + cond_imgs_with_pair = set() for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] subset = None @@ -1877,23 +1877,29 @@ def __init__( logger.warning(f"not directory: {subset.conditioning_data_dir}") continue - img_basename = os.path.basename(info.absolute_path) - ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename) - if not os.path.exists(ctrl_img_path): + img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0] + ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename) + if len(ctrl_img_path) < 1: missing_imgs.append(img_basename) + continue + ctrl_img_path = ctrl_img_path[0] + ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path info.cond_img_path = ctrl_img_path - cond_imgs_with_img.add(ctrl_img_path) + cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive extra_imgs = [] for subset in subsets: conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") - extra_imgs.extend( - [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] - ) + conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path + extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS From 3419c3de0d0ff8cba1d74444ece23608614f3c5b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 17 Mar 2024 19:30:20 +0900 Subject: [PATCH 6/6] common masked loss func, apply to all training script --- docs/train_lllite_README-ja.md | 8 ++++++-- docs/train_lllite_README.md | 4 +++- library/config_util.py | 5 ++++- library/custom_train_functions.py | 24 ++++++++++++++++++++---- library/train_util.py | 16 ++++++++++++++++ sdxl_train.py | 21 ++++----------------- train_db.py | 7 ++++++- train_network.py | 17 +++-------------- train_textual_inversion.py | 7 ++++++- train_textual_inversion_XTI.py | 7 ++++++- 10 files changed, 74 insertions(+), 42 deletions(-) diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md index dbdc1fea2..1f6a78d5c 100644 --- a/docs/train_lllite_README-ja.md +++ b/docs/train_lllite_README-ja.md @@ -21,9 +21,13 @@ ComfyUIのカスタムノードを用意しています。: https://github.com/k ## モデルの学習 ### データセットの準備 -通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。 -たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。 +(finetuning 方式の dataset はサポートしていません。) + +conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 + +たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。 ```toml [[datasets.subsets]] diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index 04dc12da2..a05f87f5f 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1 ### Preparing the dataset -In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file. +In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file. + +(We do not support the finetuning method dataset.) ```toml [[datasets.subsets]] diff --git a/library/config_util.py b/library/config_util.py index edc6a5385..26daeb472 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -323,7 +323,10 @@ def validate_flex_dataset(dataset_config: dict): self.dataset_schema = validate_flex_dataset elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema + if support_controlnet: + self.dataset_schema = self.cn_dataset_schema + else: + self.dataset_schema = self.db_dataset_schema elif support_finetuning: self.dataset_schema = self.ft_dataset_schema elif support_controlnet: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index a56474622..406e0e36e 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,11 +3,14 @@ import random import re from typing import List, Optional, Union -from .utils import setup_logging +from .utils import setup_logging + setup_logging() -import logging +import logging + logger = logging.getLogger(__name__) + def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): return @@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: - snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) + snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) else: snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) loss = loss * snr_weight @@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los loss = loss + loss / scale * v_pred_like_loss return loss + def apply_debiased_estimation(loss, timesteps, noise_scheduler): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1/torch.sqrt(snr_t) + weight = 1 / torch.sqrt(snr_t) loss = weight * loss return loss + # TODO train_utilと分散しているのでどちらかに寄せる @@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise +def apply_masked_loss(loss, batch): + # mask image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + mask_image = mask_image / 2 + 0.5 + loss = loss * mask_image + return loss + + """ ########################################## # Perlin Noise diff --git a/library/train_util.py b/library/train_util.py index 0f8cf9eea..1d9f8bf82 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3028,6 +3028,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") + parser.add_argument( "--ddp_timeout", type=int, @@ -3090,6 +3091,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", ) + parser.add_argument( "--noise_offset", type=float, @@ -3252,6 +3254,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) +def add_masked_loss_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--masked_loss", + action="store_true", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + ) + + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable diff --git a/sdxl_train.py b/sdxl_train.py index 448a160f6..f8aa46081 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -40,6 +40,7 @@ scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, + apply_masked_loss, ) from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -577,19 +578,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss + or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - # mask image is -1 to 1. we need to convert it to 0 to 1 - mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel - - # resize to the same size as the loss - mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") - mask_image = mask_image / 2 + 0.5 - loss = loss * mask_image - + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -755,6 +749,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) @@ -790,14 +785,6 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) - - # TODO common masked_loss argument - parser.add_argument( - "--masked_loss", - action="store_true", - help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", - ) - return parser diff --git a/train_db.py b/train_db.py index 8d36097a5..213df1516 100644 --- a/train_db.py +++ b/train_db.py @@ -12,6 +12,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -32,6 +33,7 @@ apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, + apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments @@ -57,7 +59,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -339,6 +341,8 @@ def train(args): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -464,6 +468,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) diff --git a/train_network.py b/train_network.py index f5617986c..05522070b 100644 --- a/train_network.py +++ b/train_network.py @@ -40,6 +40,7 @@ scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, + apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments @@ -835,16 +836,8 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - # mask image is -1 to 1. we need to convert it to 0 to 1 - mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel - - # resize to the same size as the loss - mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") - mask_image = mask_image / 2 + 0.5 - loss = loss * mask_image - + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -968,6 +961,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) @@ -1061,11 +1055,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--masked_loss", - action="store_true", - help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", - ) return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index df1d8485a..7697b9672 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,6 +8,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -29,6 +30,7 @@ scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, + apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments @@ -268,7 +270,7 @@ def train(self, args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False)) if args.dataset_config is not None: accelerator.print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -586,6 +588,8 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -749,6 +753,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser, False) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 695fad2a8..72b79da46 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -9,6 +9,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -31,6 +32,7 @@ apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, + apply_masked_loss, ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -200,7 +202,7 @@ def train(args): logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -471,6 +473,8 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -662,6 +666,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser, False)