
Allow negative learning rate
This can be used to train away from a group of images you don't want.
Because this moves the model away from a point instead of towards it, the change in the model is unbounded.
So don't set it too far below zero; -4e-7 seemed to work well.
Cauldrath committed Apr 19, 2024
1 parent 71e2c91 commit fc37437
6 changes: 3 additions & 3 deletions sdxl_train.py
@@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     # prepare for training: put the model into the proper state
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-    train_unet = args.learning_rate > 0
+    train_unet = args.learning_rate != 0
     train_text_encoder1 = False
     train_text_encoder2 = False

@@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2.gradient_checkpointing_enable()
         lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
         lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
-        train_text_encoder1 = lr_te1 > 0
-        train_text_encoder2 = lr_te2 > 0
+        train_text_encoder1 = lr_te1 != 0
+        train_text_encoder2 = lr_te2 != 0

         # caching one text encoder output is not supported
         if not train_text_encoder1:
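For intuition, here is a minimal sketch of what a negative learning rate does. This is hypothetical code, not from sdxl_train.py: the toy model, x, and y stand in for the UNet and a batch of unwanted images. A plain SGD step updates theta <- theta - lr * grad; with lr < 0 this ascends the loss instead of descending it.

import torch

# Hypothetical toy setup, standing in for the UNet and unwanted images.
model = torch.nn.Linear(4, 1)
lr = -4e-7  # negative learning rate, as enabled by this commit

x = torch.randn(8, 4)
y = torch.randn(8, 1)

loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Manual SGD step: theta <- theta - lr * grad.
# With lr < 0 this increases the loss, i.e. it trains *away* from (x, y).
with torch.no_grad():
    for p in model.parameters():
        p -= lr * p.grad

Descent toward a minimum is self-limiting (the gradient shrinks as the loss bottoms out), but ascent away from one is not, which is why the commit message warns against setting the value too far below zero. Note also that some optimizer constructors validate the sign, e.g. torch.optim.SGD raises ValueError for lr < 0, so a negative rate only reaches optimizers that accept it.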
