diff --git a/library/config_util.py b/library/config_util.py index ab90fb63b..47868f3ba 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -53,6 +53,7 @@ class BaseSubsetParams: shuffle_caption: bool = False caption_separator: str = ',', keep_tokens: int = 0 + keep_tokens_separator: str = None, color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -160,6 +161,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "random_crop": bool, "shuffle_caption": bool, "keep_tokens": int, + "keep_tokens_separator": str, "token_warmup_min": int, "token_warmup_step": Any(float,int), "caption_prefix": str, @@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu num_repeats: {subset.num_repeats} shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} diff --git a/library/train_util.py b/library/train_util.py index d2eb7cb2d..5adc2310e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -351,6 +351,7 @@ def __init__( shuffle_caption: bool, caption_separator: str, keep_tokens: int, + keep_tokens_separator: str, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], @@ -368,6 +369,7 @@ def __init__( self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator self.keep_tokens = keep_tokens + self.keep_tokens_separator = keep_tokens_separator self.color_aug = color_aug self.flip_aug = flip_aug self.face_crop_aug_range = face_crop_aug_range @@ -395,6 +397,7 @@ def __init__( shuffle_caption, caption_separator: str, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -415,6 +418,7 @@ def __init__( shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -449,6 +453,7 @@ def __init__( shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -469,6 +474,7 @@ def __init__( shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -500,6 +506,7 @@ def __init__( shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -520,6 +527,7 @@ def __init__( shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -654,15 +662,29 @@ def process_caption(self, subset: BaseSubset, caption): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + fixed_tokens = [] + flex_tokens = [] + if hasattr(subset, 'keep_tokens_separator') and subset.keep_tokens_separator in caption: + fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) + fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] + flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] + else: + tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[:subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens:] + + if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: tokens_len = ( - math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + subset.token_warmup_min ) - tokens = tokens[:tokens_len] + flex_tokens = flex_tokens[:tokens_len] + def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -673,12 +695,6 @@ def dropout_tags(tokens): l.append(token) return l - fixed_tokens = [] - flex_tokens = tokens[:] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = tokens[subset.keep_tokens :] - if subset.shuffle_caption: random.shuffle(flex_tokens) @@ -1724,6 +1740,7 @@ def __init__( subset.shuffle_caption, subset.caption_separator, subset.keep_tokens, + subset.keep_tokens_separator, subset.color_aug, subset.flip_aug, subset.face_crop_aug_range, @@ -3131,6 +3148,12 @@ def add_dataset_arguments( default=0, help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", ) + parser.add_argument( + "--keep_tokens_separator", + type=str, + default="", + help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens.", + ) parser.add_argument( "--caption_prefix", type=str,