Develop #66

Merged: 20 commits, Dec 19, 2022

Commits
7e78c8d
Add parameter to control rank of decomposition (#28)
brian6091 Dec 13, 2022
6aee5f3
Merge branch 'master' of https://github.com/cloneofsimo/lora into dev…
cloneofsimo Dec 14, 2022
9f31bd0
feat : statefully monkeypatch different loras + example ipynb + readme
cloneofsimo Dec 14, 2022
fececf3
Fix lora inject, added weight self apply lora (#39)
DavidePaglieri Dec 15, 2022
65438b5
Revert "Fix lora inject, added weight self apply lora (#39)" (#40)
cloneofsimo Dec 15, 2022
4975cfa
Merge branch 'master' of https://github.com/cloneofsimo/lora into dev…
cloneofsimo Dec 15, 2022
9ca7bc8
fix : rank bug in monkeypatch
cloneofsimo Dec 15, 2022
6a3ad97
fix cli fix
cloneofsimo Dec 15, 2022
40ad282
visualization of effect of LR
cloneofsimo Dec 15, 2022
a386525
Fix save_steps, max_train_steps, and logging (#45)
hdon96 Dec 16, 2022
6767142
Enable resuming (#52)
hdon96 Dec 16, 2022
24af4c8
feat : low-rank pivotal tuning
cloneofsimo Dec 16, 2022
046422c
feat : pivotal tuning
cloneofsimo Dec 16, 2022
0a92e62
Merge branch 'develop' of https://github.com/cloneofsimo/lora into de…
cloneofsimo Dec 16, 2022
4abbf90
v 0.0.6
cloneofsimo Dec 16, 2022
d0c4cc5
Merge branch 'master' into develop
cloneofsimo Dec 16, 2022
986626f
Learning rate switching & fix indent (#57)
hdon96 Dec 19, 2022
bbda1e5
Re:Fix indent (#58)
hdon96 Dec 19, 2022
46d9cf6
Merge branch 'master' into develop
cloneofsimo Dec 19, 2022
e1ea114
Merge branch 'develop' of https://github.com/cloneofsimo/lora into de…
cloneofsimo Dec 19, 2022
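Commit 7e78c8d makes the rank of the LoRA decomposition configurable. As a rough sketch of what that rank controls, here is a generic injected linear layer; the class name, initialization, and default r below are illustrative, not copied from this repository's lora.py:

from torch import nn


class LoraInjectedLinear(nn.Module):
    # Frozen pretrained projection W plus a trainable rank-r update: W x + up(down(x)).
    def __init__(self, in_features, out_features, r=4, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.linear.weight.requires_grad_(False)                 # pretrained weight stays frozen
        self.lora_down = nn.Linear(in_features, r, bias=False)   # down-projection to rank r
        self.lora_up = nn.Linear(r, out_features, bias=False)    # up-projection back to out_features
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)                      # LoRA update starts as a no-op

    def forward(self, x):
        return self.linear(x) + self.lora_up(self.lora_down(x))


layer = LoraInjectedLinear(320, 320, r=8)  # larger r: more trainable parameters, more capacity
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)  # 2 * 320 * 8

Only lora_down and lora_up are trained, so the saved weights stay small; r trades capacity against checkpoint size.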
train_lora_dreambooth.py (2 changes: 1 addition & 1 deletion)

The single changed line only re-indents accelerator.wait_for_everyone(); its text is identical before and after.

@@ -922,7 +922,7 @@ def collate_fn(examples):

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
train_lora_w_ti.py (150 changes: 77 additions & 73 deletions)
@@ -178,7 +178,6 @@ def __getitem__(self, index):

The only change here removes a blank line; the surrounding context:

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)
        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
@@ -899,10 +898,15 @@ def collate_fn(examples):

This hunk adds the learning-rate switching from #57. After the change:

        # optimizer = accelerator.prepare(optimizer)
        for step, batch in enumerate(train_dataloader):

            # freeze unet and text encoder during ti training
            if global_step < args.unfreeze_lora_step:
                optimizer.param_groups[0]["lr"] = 0.0
                optimizer.param_groups[1]["lr"] = 0.0
            else:  # begin learning with unet and text encoder
                optimizer.param_groups[0]["lr"] = args.learning_rate
                optimizer.param_groups[1]["lr"] = args.learning_rate_text
                optimizer.param_groups[2]["lr"] = 0.0  # stop learning ti

            # Convert images to latent space
            latents = vae.encode(
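The switch above implements the two phases of pivotal tuning and assumes the optimizer was built with three parameter groups in a fixed order: group 0 for the unet LoRA weights, group 1 for the text-encoder LoRA weights, and group 2 for the textual-inversion embedding. A self-contained sketch of that assumption follows; the parameter lists and learning rates are placeholders, not the script's actual values:

import torch

# Hypothetical stand-ins for the three trainable parameter sets the script optimizes.
unet_lora_params = [torch.nn.Parameter(torch.zeros(4, 4))]
text_encoder_lora_params = [torch.nn.Parameter(torch.zeros(4, 4))]
ti_embedding_params = [torch.nn.Parameter(torch.zeros(4, 4))]

optimizer = torch.optim.AdamW(
    [
        {"params": unet_lora_params, "lr": 1e-4},          # group 0: unet LoRA
        {"params": text_encoder_lora_params, "lr": 5e-5},  # group 1: text-encoder LoRA
        {"params": ti_embedding_params, "lr": 5e-4},       # group 2: TI embedding
    ]
)

# Phase 1 (global_step < unfreeze_lora_step): train only the TI embedding.
optimizer.param_groups[0]["lr"] = 0.0
optimizer.param_groups[1]["lr"] = 0.0

# Phase 2: restore the LoRA learning rates and stop updating the embedding.
optimizer.param_groups[0]["lr"] = 1e-4
optimizer.param_groups[1]["lr"] = 5e-5
optimizer.param_groups[2]["lr"] = 0.0

Setting a group's lr to 0.0 freezes its parameters for that phase without rebuilding the optimizer or toggling requires_grad; restoring the lr re-enables updates.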
@@ -987,86 +991,86 @@ def collate_fn(examples):

The deleted and added lines of this hunk are identical in content; the change only re-indents the code shown below. After the change it reads:

                global_step += 1

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.save_steps and global_step - last_save >= args.save_steps:
                    if accelerator.is_main_process:
                        # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
                        # it, the models will be unwrapped, and when they are then used for further training,
                        # we will crash. pass this, but only to newer versions of accelerate. fixes
                        # https://github.com/huggingface/diffusers/issues/1566
                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                            inspect.signature(accelerator.unwrap_model).parameters.keys()
                        )
                        extra_args = (
                            {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                        )
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet, **extra_args),
                            text_encoder=accelerator.unwrap_model(
                                text_encoder, **extra_args
                            ),
                            revision=args.revision,
                        )

                        filename_unet = (
                            f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
                        )
                        filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
                        print(f"save weights {filename_unet}, {filename_text_encoder}")
                        save_lora_weight(pipeline.unet, filename_unet)

                        save_lora_weight(
                            pipeline.text_encoder,
                            filename_text_encoder,
                            target_replace_module=["CLIPAttention"],
                        )

                        for _up, _down in extract_lora_ups_down(pipeline.unet):
                            print("First Unet Layer's Up Weight is now : ", _up.weight.data)
                            print(
                                "First Unet Layer's Down Weight is now : ",
                                _down.weight.data,
                            )
                            break

                        for _up, _down in extract_lora_ups_down(
                            pipeline.text_encoder,
                            target_replace_module=["CLIPAttention"],
                        ):
                            print(
                                "First Text Encoder Layer's Up Weight is now : ",
                                _up.weight.data,
                            )
                            print(
                                "First Text Encoder Layer's Down Weight is now : ",
                                _down.weight.data,
                            )
                            break

                        filename_ti = (
                            f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.ti.pt"
                        )

                        save_progress(
                            pipeline.text_encoder,
                            placeholder_token_id,
                            accelerator,
                            args,
                            filename_ti,
                        )

                        last_save = global_step

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
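The checkpoint block above guards the keep_fp32_wrapper argument by inspecting the signature of accelerator.unwrap_model, so the same script runs on accelerate versions with and without that argument (huggingface/diffusers#1566). A stand-alone sketch of that guard, with a toy model in place of the real unet:

import inspect

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))  # toy stand-in for the unet

# Newer accelerate releases accept keep_fp32_wrapper; passing it keeps the model
# usable for further training after unwrapping.
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
    inspect.signature(accelerator.unwrap_model).parameters.keys()
)
extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
unwrapped = accelerator.unwrap_model(model, **extra_args)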