diff --git a/train_lora_w_ti.py b/train_lora_w_ti.py index d2d8f8c..120fdcd 100644 --- a/train_lora_w_ti.py +++ b/train_lora_w_ti.py @@ -899,10 +899,15 @@ def collate_fn(examples): # optimizer = accelerator.prepare(optimizer) for step, batch in enumerate(train_dataloader): - + + # freeze unet and text encoder during ti training if global_step < args.unfreeze_lora_step: optimizer.param_groups[0]["lr"] = 0.0 optimizer.param_groups[1]["lr"] = 0.0 + else: # begin learning with unet and text encoder + optimizer.param_groups[0]["lr"] = args.learning_rate + optimizer.param_groups[1]["lr"] = args.learning_rate_text + optimizer.param_groups[2]["lr"] = 0.0 # stop learning ti # Convert images to latent space latents = vae.encode( @@ -990,86 +995,86 @@ def collate_fn(examples): global_step += 1 - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - if args.save_steps and global_step - last_save >= args.save_steps: - if accelerator.is_main_process: - # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing - # it, the models will be unwrapped, and when they are then used for further training, - # we will crash. pass this, but only to newer versions of accelerate. fixes - # https://github.com/huggingface/diffusers/issues/1566 - accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( - inspect.signature(accelerator.unwrap_model).parameters.keys() - ) - extra_args = ( - {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} - ) - pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet, **extra_args), - text_encoder=accelerator.unwrap_model( - text_encoder, **extra_args - ), - revision=args.revision, - ) - - filename_unet = ( - f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt" - ) - filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt" - print(f"save weights {filename_unet}, {filename_text_encoder}") - save_lora_weight(pipeline.unet, filename_unet) - - save_lora_weight( - pipeline.text_encoder, - filename_text_encoder, - target_replace_module=["CLIPAttention"], - ) - - for _up, _down in extract_lora_ups_down(pipeline.unet): - print("First Unet Layer's Up Weight is now : ", _up.weight.data) - print( - "First Unet Layer's Down Weight is now : ", - _down.weight.data, + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.save_steps and global_step - last_save >= args.save_steps: + if accelerator.is_main_process: + # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing + # it, the models will be unwrapped, and when they are then used for further training, + # we will crash. pass this, but only to newer versions of accelerate. fixes + # https://github.com/huggingface/diffusers/issues/1566 + accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( + inspect.signature(accelerator.unwrap_model).parameters.keys() ) - break - - for _up, _down in extract_lora_ups_down( - pipeline.text_encoder, - target_replace_module=["CLIPAttention"], - ): - print( - "First Text Encoder Layer's Up Weight is now : ", - _up.weight.data, + extra_args = ( + {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} ) - print( - "First Text Encoder Layer's Down Weight is now : ", - _down.weight.data, + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet, **extra_args), + text_encoder=accelerator.unwrap_model( + text_encoder, **extra_args + ), + revision=args.revision, ) - break - filename_ti = ( - f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.ti.pt" - ) + filename_unet = ( + f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt" + ) + filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt" + print(f"save weights {filename_unet}, {filename_text_encoder}") + save_lora_weight(pipeline.unet, filename_unet) + + save_lora_weight( + pipeline.text_encoder, + filename_text_encoder, + target_replace_module=["CLIPAttention"], + ) - save_progress( - pipeline.text_encoder, - placeholder_token_id, - accelerator, - args, - filename_ti, - ) + for _up, _down in extract_lora_ups_down(pipeline.unet): + print("First Unet Layer's Up Weight is now : ", _up.weight.data) + print( + "First Unet Layer's Down Weight is now : ", + _down.weight.data, + ) + break + + for _up, _down in extract_lora_ups_down( + pipeline.text_encoder, + target_replace_module=["CLIPAttention"], + ): + print( + "First Text Encoder Layer's Up Weight is now : ", + _up.weight.data, + ) + print( + "First Text Encoder Layer's Down Weight is now : ", + _down.weight.data, + ) + break + + filename_ti = ( + f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.ti.pt" + ) + + save_progress( + pipeline.text_encoder, + placeholder_token_id, + accelerator, + args, + filename_ti, + ) - last_save = global_step + last_save = global_step - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) - if global_step >= args.max_train_steps: - break + if global_step >= args.max_train_steps: + break - accelerator.wait_for_everyone() + accelerator.wait_for_everyone() # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: