Learning rate switching & fix indent (#57)
* Learning rate switching & fix indent

Make the learning rates switch from textual inversion training to unet/text encoder training after unfreeze_lora_step.
I think this is how it is explained in the linked paper(?)

Either way, it might be useful to add another parameter that activates unet/text encoder training at a user-chosen step instead of at unfreeze_lora_step. This would give the user more control (a hypothetical sketch follows after this commit message).

Also fix the indentation so that save_steps checkpointing and logging work properly.

* Fix indent

Fix the accelerator.wait_for_everyone() indentation to match the original DreamBooth training script.
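
A minimal sketch of that suggestion, for illustration only: the flag name --lora_train_start_step and the default values are hypothetical and not part of this commit or of train_lora_w_ti.py.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--unfreeze_lora_step", type=int, default=1000)
    parser.add_argument(
        "--lora_train_start_step",
        type=int,
        default=None,
        help="Step at which unet/text encoder training starts; "
        "falls back to --unfreeze_lora_step when unset.",
    )
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--learning_rate_text", type=float, default=5e-5)
    args = parser.parse_args()

    # Resolve the switch point once, before the training loop.
    switch_step = (
        args.lora_train_start_step
        if args.lora_train_start_step is not None
        else args.unfreeze_lora_step
    )

    # Inside the loop, the commit's comparison against args.unfreeze_lora_step
    # would then become a comparison against switch_step:
    #     if global_step < switch_step:
    #         optimizer.param_groups[0]["lr"] = 0.0
    #         optimizer.param_groups[1]["lr"] = 0.0
    #     else:
    #         optimizer.param_groups[0]["lr"] = args.learning_rate
    #         optimizer.param_groups[1]["lr"] = args.learning_rate_text
    #         optimizer.param_groups[2]["lr"] = 0.0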
hdon96 authored Dec 19, 2022
1 parent d0c4cc5 commit 986626f
1 changed file with 77 additions and 72 deletions: train_lora_w_ti.py
@@ -899,10 +899,15 @@ def collate_fn(examples):

        # optimizer = accelerator.prepare(optimizer)
        for step, batch in enumerate(train_dataloader):

            # freeze unet and text encoder during ti training
            if global_step < args.unfreeze_lora_step:
                optimizer.param_groups[0]["lr"] = 0.0
                optimizer.param_groups[1]["lr"] = 0.0
            else:  # begin learning with unet and text encoder
                optimizer.param_groups[0]["lr"] = args.learning_rate
                optimizer.param_groups[1]["lr"] = args.learning_rate_text
                optimizer.param_groups[2]["lr"] = 0.0  # stop learning ti

            # Convert images to latent space
            latents = vae.encode(
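
For context, a toy illustration of the param-group ordering that the indices above assume (group 0 = unet LoRA, group 1 = text encoder LoRA, group 2 = textual inversion embeddings). The variable names and learning rates here are placeholders, not the ones actually built in train_lora_w_ti.py.

    import torch

    # Stand-in parameters; in the real script these come from the LoRA
    # injection and the learned placeholder-token embedding.
    unet_lora_params = [torch.nn.Parameter(torch.zeros(4, 4))]
    text_encoder_lora_params = [torch.nn.Parameter(torch.zeros(4, 4))]
    ti_embedding_params = [torch.nn.Parameter(torch.zeros(4, 4))]

    optimizer = torch.optim.AdamW(
        [
            {"params": unet_lora_params, "lr": 1e-4},           # param_groups[0]
            {"params": text_encoder_lora_params, "lr": 5e-5},   # param_groups[1]
            {"params": ti_embedding_params, "lr": 5e-4},        # param_groups[2]
        ]
    )

    # During the textual inversion phase the commit zeroes out the first two groups:
    optimizer.param_groups[0]["lr"] = 0.0
    optimizer.param_groups[1]["lr"] = 0.0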
@@ -990,86 +995,86 @@ def collate_fn(examples):

            global_step += 1

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.save_steps and global_step - last_save >= args.save_steps:
                    if accelerator.is_main_process:
                        # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
                        # it, the models will be unwrapped, and when they are then used for further training,
                        # we will crash. pass this, but only to newer versions of accelerate. fixes
                        # https://github.com/huggingface/diffusers/issues/1566
                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                            inspect.signature(accelerator.unwrap_model).parameters.keys()
                        )
                        extra_args = (
                            {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                        )
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet, **extra_args),
                            text_encoder=accelerator.unwrap_model(
                                text_encoder, **extra_args
                            ),
                            revision=args.revision,
                        )

                        filename_unet = (
                            f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
                        )
                        filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
                        print(f"save weights {filename_unet}, {filename_text_encoder}")
                        save_lora_weight(pipeline.unet, filename_unet)

                        save_lora_weight(
                            pipeline.text_encoder,
                            filename_text_encoder,
                            target_replace_module=["CLIPAttention"],
                        )

                        for _up, _down in extract_lora_ups_down(pipeline.unet):
                            print("First Unet Layer's Up Weight is now : ", _up.weight.data)
                            print(
                                "First Unet Layer's Down Weight is now : ",
                                _down.weight.data,
                            )
                            break

                        for _up, _down in extract_lora_ups_down(
                            pipeline.text_encoder,
                            target_replace_module=["CLIPAttention"],
                        ):
                            print(
                                "First Text Encoder Layer's Up Weight is now : ",
                                _up.weight.data,
                            )
                            print(
                                "First Text Encoder Layer's Down Weight is now : ",
                                _down.weight.data,
                            )
                            break

                        filename_ti = (
                            f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.ti.pt"
                        )

                        save_progress(
                            pipeline.text_encoder,
                            placeholder_token_id,
                            accelerator,
                            args,
                            filename_ti,
                        )

                        last_save = global_step

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()
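
A small, self-contained sketch of the checkpoint cadence the re-indented block now produces; the numbers are toy values, and the real loop additionally gates on accelerator.sync_gradients and accelerator.is_main_process.

    # With save_steps=500, a trio of files is written every 500 global steps.
    output_dir, epoch = "output", 0
    save_steps, max_train_steps = 500, 2000
    last_save = 0
    for global_step in range(1, max_train_steps + 1):
        if save_steps and global_step - last_save >= save_steps:
            print(f"{output_dir}/lora_weight_e{epoch}_s{global_step}.pt")
            print(f"{output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt")
            print(f"{output_dir}/lora_weight_e{epoch}_s{global_step}.ti.pt")
            last_save = global_step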

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
