Skip to content

Commit

Permalink
Merge pull request #975 from Linaqruf/dev
Browse files Browse the repository at this point in the history
Add keep_tokens_separator as alternative for keep_tokens
  • Loading branch information
kohya-ss authored Dec 12, 2023
2 parents 3b6825d + 1bdd83a commit 034a49c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
3 changes: 3 additions & 0 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class BaseSubsetParams:
shuffle_caption: bool = False
caption_separator: str = ',',
keep_tokens: int = 0
keep_tokens_separator: str = None,
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
Expand Down Expand Up @@ -160,6 +161,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"random_crop": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
"caption_prefix": str,
Expand Down Expand Up @@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
Expand Down
41 changes: 32 additions & 9 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def __init__(
shuffle_caption: bool,
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
Expand All @@ -368,6 +369,7 @@ def __init__(
self.shuffle_caption = shuffle_caption
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
Expand Down Expand Up @@ -395,6 +397,7 @@ def __init__(
shuffle_caption,
caption_separator: str,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -415,6 +418,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -449,6 +453,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -469,6 +474,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -500,6 +506,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -520,6 +527,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -654,15 +662,29 @@ def process_caption(self, subset: BaseSubset, caption):
caption = ""
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
fixed_tokens = []
flex_tokens = []
if hasattr(subset, 'keep_tokens_separator') and subset.keep_tokens_separator in caption:
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[:subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens:]


if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
flex_tokens = flex_tokens[:tokens_len]


def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
Expand All @@ -673,12 +695,6 @@ def dropout_tags(tokens):
l.append(token)
return l

fixed_tokens = []
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]

if subset.shuffle_caption:
random.shuffle(flex_tokens)

Expand Down Expand Up @@ -1724,6 +1740,7 @@ def __init__(
subset.shuffle_caption,
subset.caption_separator,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
Expand Down Expand Up @@ -3131,6 +3148,12 @@ def add_dataset_arguments(
default=0,
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
)
parser.add_argument(
"--keep_tokens_separator",
type=str,
default="",
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens.",
)
parser.add_argument(
"--caption_prefix",
type=str,
Expand Down

0 comments on commit 034a49c

Please sign in to comment.