From ae97c8bfd18e4b51bdeae0a72753c8e9ceeff29d Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sun, 24 Mar 2024 14:40:18 +0800
Subject: [PATCH 1/2] [Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initialization (#1178)

* support meta cached dataset
* add cache meta scripts
* random ip_noise_gamma strength
* random noise_offset strength
* use correct settings for parser
* cache path/caption/size only
* revert mess up commit
* revert mess up commit
* Update requirements.txt
* Add arguments for meta cache.
* remove pickle implementation
* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
---
 cache_dataset_meta.py  | 103 +++++++++++++++++++++++++++++++++++++++++
 library/config_util.py |   4 ++
 library/train_util.py  |  83 ++++++++++++++++++++++++---------
 requirements.txt       |   2 +
 train_network.py       |   3 +-
 5 files changed, 173 insertions(+), 22 deletions(-)
 create mode 100644 cache_dataset_meta.py

diff --git a/cache_dataset_meta.py b/cache_dataset_meta.py
new file mode 100644
index 000000000..7e7d96d12
--- /dev/null
+++ b/cache_dataset_meta.py
@@ -0,0 +1,103 @@
+import argparse
+import random
+
+from accelerate.utils import set_seed
+
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def make_dataset(args):
+    train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)
+
+    use_dreambooth_method = args.in_json is None
+    use_user_config = args.dataset_config is not None
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
+
+    # データセットを準備する
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(
+            ConfigSanitizer(True, True, False, True)
+        )
+        if use_user_config:
+            logger.info(f"Loading dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                logger.warning(
+                    "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                logger.info("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                logger.info("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(
+            blueprint.dataset_group
+        )
+    else:
+        # use arbitrary dataset class
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
+    return train_dataset_group
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    add_logging_arguments(parser)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    train_util.add_training_arguments(parser, True)
+    config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args, unknown = parser.parse_known_args()
+    args = train_util.read_config_from_file(args, parser)
+    if args.max_token_length is None:
+        args.max_token_length = 75
+    args.cache_meta = True
+
+    dataset_group = make_dataset(args)
diff --git a/library/config_util.py b/library/config_util.py
index eb652ecf3..58ffa5f4d 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -111,6 +111,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
+    cache_meta: bool = False
+    use_cached_meta: bool = False
 
 
 @dataclass
@@ -228,6 +230,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
+        "cache_meta": bool,
+        "use_cached_meta": bool,
     }
 
     # options handled by argparse but not handled by user config
diff --git a/library/train_util.py b/library/train_util.py
index 99aeea90d..58c0cc14b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -63,6 +63,7 @@ from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
+import imagesize
 import cv2
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1080,8 +1081,7 @@ def cache_text_encoder_outputs(
         )
 
     def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)
 
     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
         img = load_image(image_path)
@@ -1425,6 +1425,8 @@ def __init__(
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
 
@@ -1484,26 +1486,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 logger.warning(f"not directory: {subset.image_dir}")
-                return [], []
+                return [], [], []
 
-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None]*len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
-            # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found.
use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" - ) - captions.append("") - missing_captions.append(img_path) - else: - if cap_for_img is None: - captions.append(subset.class_tokens) + if use_cached_meta: + captions = [x[1] for x in metas] + missing_captions = [x[0] for x in metas if x[1] == ""] + else: + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + missing_captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + if cap_for_img is None and subset.class_tokens is None: + logger.warning( + f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + ) + captions.append("") missing_captions.append(img_path) else: - captions.append(cap_for_img) + if cap_for_img is None: + captions.append(subset.class_tokens) + missing_captions.append(img_path) + else: + captions.append(cap_for_img) self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 @@ -1520,7 +1539,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.warning(missing_caption + f"... and {remaining_missing_captions} more") break logger.warning(missing_caption) - return img_paths, captions + + if cache_meta: + logger.info(f"cache metadata for {subset.image_dir}") + if sizes is None or sizes[0] is None: + sizes = [self.get_image_size(img_path) for img_path in img_paths] + # [img_path, caption, resolution] + data = [ + (img_path, caption, " ".join(str(x) for x in size)) + for img_path, caption, size in zip(img_paths, captions, sizes) + ] + with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f: + f.write("\n".join(["<|##|>".join(x) for x in data])) + logger.info(f"cache metadata done for {subset.image_dir}") + + return img_paths, captions, sizes logger.info("prepare images.") num_train_images = 0 @@ -1539,7 +1572,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): ) continue - img_paths, captions = load_dreambooth_dir(subset) + img_paths, captions, sizes = load_dreambooth_dir(subset) if len(img_paths) < 1: logger.warning( f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" @@ -1551,8 +1584,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): else: num_train_images += subset.num_repeats * len(img_paths) - for img_path, caption in zip(img_paths, captions): + for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if size is not None: + info.image_size = size if subset.is_reg: reg_infos.append((info, subset)) else: @@ -3355,6 +3390,12 @@ def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool ): # dataset common + parser.add_argument( + "--cache_meta", action="store_true" + ) + parser.add_argument( + "--use_cached_meta", action="store_true" + ) parser.add_argument( "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" ) diff --git a/requirements.txt b/requirements.txt index 805f0501d..c7aeb6895 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,8 @@ easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.20.1 +# for Image utils +imagesize==1.4.1 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/train_network.py b/train_network.py index 9e573d9f6..b42daba71 100644 --- 
a/train_network.py +++ b/train_network.py @@ -6,6 +6,7 @@ import random import time import json +import pickle from multiprocessing import Value import toml @@ -23,7 +24,7 @@ import library.train_util as train_util from library.train_util import ( - DreamBoothDataset, + DreamBoothDataset, DatasetGroup ) import library.config_util as config_util from library.config_util import ( From 025347214d761d63c5475fec83e11856f3cdbe9d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 24 Mar 2024 18:09:32 +0900 Subject: [PATCH 2/2] refactor metadata caching for DreamBooth dataset --- cache_dataset_meta.py | 103 --------------------------------------- docs/config_README-en.md | 4 ++ docs/config_README-ja.md | 4 ++ library/config_util.py | 39 +++++++++------ library/train_util.py | 86 ++++++++++++++++++-------------- train_network.py | 8 +-- 6 files changed, 85 insertions(+), 159 deletions(-) delete mode 100644 cache_dataset_meta.py diff --git a/cache_dataset_meta.py b/cache_dataset_meta.py deleted file mode 100644 index 7e7d96d12..000000000 --- a/cache_dataset_meta.py +++ /dev/null @@ -1,103 +0,0 @@ -import argparse -import random - -from accelerate.utils import set_seed - -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.custom_train_functions as custom_train_functions -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - - -def make_dataset(args): - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - use_dreambooth_method = args.in_json is None - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator( - ConfigSanitizer(True, True, False, True) - ) - if use_user_config: - logger.info(f"Loading dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=None) - train_dataset_group = config_util.generate_dataset_group_by_blueprint( - blueprint.dataset_group - ) - else: - # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None) - return train_dataset_group - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - add_logging_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, True) - train_util.add_training_arguments(parser, True) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - return parser - - -if 
__name__ == "__main__": - parser = setup_parser() - - args, unknown = parser.parse_known_args() - args = train_util.read_config_from_file(args, parser) - if args.max_token_length is None: - args.max_token_length = 75 - args.cache_meta = True - - dataset_group = make_dataset(args) diff --git a/docs/config_README-en.md b/docs/config_README-en.md index e99fde216..83bea329b 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -177,6 +177,7 @@ Options related to the configuration of DreamBooth subsets. | `image_dir` | `'C:\hoge'` | - | - | o (required) | | `caption_extension` | `".txt"` | o | o | o | | `class_tokens` | `"sks girl"` | - | - | o | +| `cache_info` | `false` | o | o | o | | `is_reg` | `false` | - | - | o | Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`. @@ -187,6 +188,9 @@ Firstly, note that for `image_dir`, the path to the image files must be specifie * `class_tokens` * Sets the class tokens. * Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur. +* `cache_info` + * Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`. + * Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more. * `is_reg` * Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization. 
diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index b57ae86a7..cc74c341b 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -173,6 +173,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。 | `image_dir` | `‘C:\hoge’` | - | - | o(必須) | | `caption_extension` | `".txt"` | o | o | o | | `class_tokens` | `“sks girl”` | - | - | o | +| `cache_info` | `false` | o | o | o | | `is_reg` | `false` | - | - | o | まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。 @@ -183,6 +184,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。 * `class_tokens` * クラストークンを設定します。 * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。 +* `cache_info` + * 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。 + * キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。 * `is_reg` * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。 diff --git a/library/config_util.py b/library/config_util.py index 58ffa5f4d..e52b7fc02 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -41,12 +41,17 @@ DatasetGroup, ) from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) # TODO: inherit Params class in Subset, Dataset @@ -80,6 +85,7 @@ class DreamBoothSubsetParams(BaseSubsetParams): is_reg: bool = False class_tokens: Optional[str] = None caption_extension: str = ".caption" + cache_info: bool = False @dataclass @@ -91,6 +97,7 @@ class FineTuningSubsetParams(BaseSubsetParams): class ControlNetSubsetParams(BaseSubsetParams): conditioning_data_dir: str = None caption_extension: str = ".caption" + cache_info: bool = False @dataclass @@ -111,8 +118,6 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - cache_meta: bool = False - use_cached_meta: bool = False @dataclass @@ -202,6 +207,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] DB_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, "class_tokens": str, + "cache_info": bool, } DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, @@ -214,6 +220,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] } CN_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, + "cache_info": bool, } CN_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, @@ -230,8 +237,6 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "min_bucket_reso": int, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, - "cache_meta": bool, - "use_cached_meta": bool, } # options handled by argparse but not handled by user config @@ -366,7 +371,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - 
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -551,11 +558,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu " ", ) - logger.info(f'{info}') + logger.info(f"{info}") # make buckets first because it determines the length of dataset # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): logger.info(f"[Dataset {i}]") dataset.make_buckets() @@ -642,13 +649,17 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -675,13 +686,13 @@ def load_user_config(file: str) -> dict: train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') + logger.info(f"{vars(argparse_namespace)}") user_config = load_user_config(config_args.dataset_config) logger.info("") logger.info("[user_config]") - logger.info(f'{user_config}') + logger.info(f"{user_config}") sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout @@ -690,10 +701,10 @@ def load_user_config(file: str) -> dict: logger.info("") logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') + logger.info(f"{sanitized_user_config}") blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) logger.info("") logger.info("[blueprint]") - logger.info(f'{blueprint}') + logger.info(f"{blueprint}") diff --git a/library/train_util.py b/library/train_util.py index 58c0cc14b..743a1147b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -410,6 +410,7 @@ def __init__( is_reg: bool, class_tokens: Optional[str], caption_extension: str, + cache_info: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -458,6 +459,7 @@ def __init__( self.caption_extension = caption_extension if self.caption_extension and not self.caption_extension.startswith("."): self.caption_extension = "." 
+ self.caption_extension
+        self.cache_info = cache_info
 
     def __eq__(self, other) -> bool:
         if not isinstance(other, DreamBoothSubset):
@@ -527,6 +529,7 @@ def __init__(
         image_dir: str,
         conditioning_data_dir: str,
         caption_extension: str,
+        cache_info: bool,
         num_repeats,
         shuffle_caption,
         caption_separator,
@@ -574,6 +577,7 @@ def __init__(
         self.caption_extension = caption_extension
         if self.caption_extension and not self.caption_extension.startswith("."):
             self.caption_extension = "." + self.caption_extension
+        self.cache_info = cache_info
 
     def __eq__(self, other) -> bool:
         if not isinstance(other, ControlNetSubset):
@@ -1410,6 +1414,8 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
 
 class DreamBoothDataset(BaseDataset):
+    IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
+
     def __init__(
         self,
         subsets: Sequence[DreamBoothSubset],
@@ -1425,8 +1431,6 @@ def __init__(
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
-        cache_meta: bool,
-        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
 
@@ -1486,25 +1490,36 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], [], []
 
-            sizes = None
-            if use_cached_meta:
-                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
-                # [img_path, caption, resolution]
-                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
-                    metas = f.readlines()
-                metas = [x.strip().split("<|##|>") for x in metas]
-                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
-
-            if use_cached_meta:
-                img_paths = [x[0] for x in metas]
+            info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
+            use_cached_info_for_subset = subset.cache_info
+            if use_cached_info_for_subset:
+                logger.info(
+                    f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}"
+                )
+                if not os.path.isfile(info_cache_file):
+                    logger.warning(
+                        f"image info file not found. You can ignore this warning if this is the first time to use this subset"
+                        + f" / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {info_cache_file}"
+                    )
+                    use_cached_info_for_subset = False
+
+            if use_cached_info_for_subset:
+                # json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
+                with open(info_cache_file, "r", encoding="utf-8") as f:
+                    metas = json.load(f)
+                img_paths = list(metas.keys())
+                sizes = [meta["resolution"] for meta in metas.values()]
+
+                # we may need to check image size and existence of image files, but it takes time, so user should check it before training
             else:
                 img_paths = glob_images(subset.image_dir, "*")
                 sizes = [None] * len(img_paths)
+
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
-            if use_cached_meta:
-                captions = [x[1] for x in metas]
-                missing_captions = [x[0] for x in metas if x[1] == ""]
+            if use_cached_info_for_subset:
+                captions = [meta["caption"] for meta in metas.values()]
+                missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
             else:
                 # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
                 captions = []
@@ -1540,19 +1555,17 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                     break
                 logger.warning(missing_caption)
 
-            if cache_meta:
-                logger.info(f"cache metadata for {subset.image_dir}")
-                if sizes is None or sizes[0] is None:
-                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
-                # [img_path, caption, resolution]
-                data = [
-                    (img_path, caption, " ".join(str(x) for x in size))
-                    for img_path, caption, size in zip(img_paths, captions, sizes)
-                ]
-                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
-                    f.write("\n".join(["<|##|>".join(x) for x in data]))
-                logger.info(f"cache metadata done for {subset.image_dir}")
-
+            if not use_cached_info_for_subset and subset.cache_info:
+                logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
+                sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
+                metas = {}
+                for img_path, caption, size in zip(img_paths, captions, sizes):
+                    metas[img_path] = {"caption": caption, "resolution": list(size)}
+                with open(info_cache_file, "w", encoding="utf-8") as f:
+                    json.dump(metas, f, ensure_ascii=False, indent=2)
+                logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
+
+            # if sizes are not set, image size will be read in make_buckets
             return img_paths, captions, sizes
 
         logger.info("prepare images.")
@@ -1873,7 +1886,8 @@ def __init__(
                     subset.image_dir,
                     False,
                     None,
-                    subset.caption_extension,
+                    subset.caption_extension,
+                    subset.cache_info,
                     subset.num_repeats,
                     subset.shuffle_caption,
                     subset.caption_separator,
@@ -3391,13 +3405,13 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
     parser.add_argument(
-        "--cache_meta", action="store_true"
-    )
-    parser.add_argument(
-        "--use_cached_meta", action="store_true"
+        "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
     parser.add_argument(
-        "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
+        "--cache_info",
+        action="store_true",
+        help="cache meta information (caption and image size) for faster dataset loading.
only available for DreamBooth" + + " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効", ) parser.add_argument( "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" diff --git a/train_network.py b/train_network.py index b42daba71..7ae9283cb 100644 --- a/train_network.py +++ b/train_network.py @@ -6,7 +6,6 @@ import random import time import json -import pickle from multiprocessing import Value import toml @@ -14,18 +13,15 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() -from torch.nn.parallel import DistributedDataParallel as DDP +init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util import library.train_util as train_util -from library.train_util import ( - DreamBoothDataset, DatasetGroup -) +from library.train_util import DreamBoothDataset import library.config_util as config_util from library.config_util import ( ConfigSanitizer,
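To summarize the final read path after the second commit: with `cache_info = true` on a subset (or `--cache_info` on the command line), the first run writes `metadata_cache.json` and later runs load paths, captions, and sizes from it instead of probing every image file. A condensed sketch of the cached branch of `load_dreambooth_dir` follows, detached from the dataset class for illustration; the standalone function name is hypothetical, while the constant and JSON layout come from the patch.

```python
import json
import os

IMAGE_INFO_CACHE_FILE = "metadata_cache.json"  # same constant as DreamBoothDataset


def load_cached_image_info(image_dir: str):
    """Condensed sketch of the cached branch of load_dreambooth_dir in this patch.

    Returns (img_paths, captions, sizes); the cached sizes let make_buckets skip
    opening every image file on startup.
    """
    info_cache_file = os.path.join(image_dir, IMAGE_INFO_CACHE_FILE)
    with open(info_cache_file, "r", encoding="utf-8") as f:
        metas = json.load(f)  # {img_path: {"caption": ..., "resolution": [w, h]}}
    img_paths = list(metas.keys())
    captions = [meta["caption"] for meta in metas.values()]
    sizes = [tuple(meta["resolution"]) for meta in metas.values()]
    return img_paths, captions, sizes
```

Note that the loader deliberately skips re-validating cached sizes and file existence (see the comment in the hunk above): for the multi-thousand-image datasets this feature targets, re-checking would cost most of what the cache saves.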