Enable resuming #52

Merged
merged 2 commits on Dec 16, 2022
Binary file added contents/lora_diff_lrs.jpg
Binary file added contents/lora_diff_lrs_0.6.jpg
35 changes: 25 additions & 10 deletions lora_diffusion/cli_lora_add.py
@@ -26,34 +26,48 @@ def add(
] = "lpl",
with_text_lora: bool = False,
):
print("Lora Add, mode " + mode)
if mode == "lpl":
assert output_path.endswith(".pt"), "Only .pt files are supported"

for _path_1, _path_2 in (
[(path_1, path_2)] + [(_text_lora_path(path_1), _text_lora_path(path_2))]
for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
[(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
if with_text_lora
else []
):
print("Loading", _path_1, _path_2)
out_list = []
if opt == "text_encoder":
if not os.path.exists(_path_1):
print(f"No text encoder found in {_path_1}, skipping...")
continue
if not os.path.exists(_path_2):
print(f"No text encoder found in {_path_2}, skipping...")
continue

l1 = torch.load(_path_1)
l2 = torch.load(_path_2)

l1pairs = zip(l1[::2], l1[1::2])
l2pairs = zip(l2[::2], l2[1::2])

for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
# print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
x1.data = alpha * x1.data + (1 - alpha) * x2.data
y1.data = alpha * y1.data + (1 - alpha) * y2.data

out_list.append(x1)
out_list.append(y1)

torch.save(out_list, output_path)
if with_text_lora:
torch.save(
out_list,
_text_lora_path(output_path),
)
if opt == "unet":

print("Saving merged UNET to", output_path)
torch.save(out_list, output_path)

elif opt == "text_encoder":
print("Saving merged text encoder to", _text_lora_path(output_path))
torch.save(
out_list,
_text_lora_path(output_path),
)

elif mode == "upl":

@@ -96,6 +110,7 @@ def add(
shutil.rmtree(_tmp_output)

else:
print("Unknown mode", mode)
raise ValueError(f"Unknown mode {mode}")


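For reference, the "lpl" branch above merges two checkpoints pair by pair: each .pt file stores a flat list of tensors ordered [up_0, down_0, up_1, down_1, ...], and every up/down pair is linearly blended with weight alpha. A minimal standalone sketch of that blend using dummy tensors rather than the CLI itself (the shapes and output path are made up):

import torch

alpha = 0.6
# stand-ins for torch.load(path_1) / torch.load(path_2): two up/down pairs each
l1 = [torch.randn(320, 4), torch.randn(4, 320), torch.randn(640, 4), torch.randn(4, 640)]
l2 = [torch.randn(320, 4), torch.randn(4, 320), torch.randn(640, 4), torch.randn(4, 640)]

out_list = []
for (x1, y1), (x2, y2) in zip(zip(l1[::2], l1[1::2]), zip(l2[::2], l2[1::2])):
    x1.data = alpha * x1.data + (1 - alpha) * x2.data  # blend lora_up weights
    y1.data = alpha * y1.data + (1 - alpha) * y2.data  # blend lora_down weights
    out_list += [x1, y1]

torch.save(out_list, "merged_lora.pt")  # same flat [up, down, ...] layout as the inputs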
21 changes: 15 additions & 6 deletions lora_diffusion/lora.py
@@ -13,9 +13,9 @@ class LoraInjectedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False, r=4):
super().__init__()

if r >= min(in_features, out_features):
if r > min(in_features, out_features):
raise ValueError(
f"LoRA rank {r} must be less than {min(in_features, out_features)}"
f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
)

self.linear = nn.Linear(in_features, out_features, bias)
@@ -34,6 +34,7 @@ def inject_trainable_lora(
model: nn.Module,
target_replace_module: List[str] = ["CrossAttention", "Attention"],
r: int = 4,
loras=None,  # path to lora .pt
):
"""
inject lora into model, and returns lora parameter groups.
@@ -42,6 +43,9 @@ def inject_trainable_lora(
require_grad_params = []
names = []

if loras is not None:
loras = torch.load(loras)

for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:

@@ -62,18 +66,21 @@

# switch the module
_module._modules[name] = _tmp

require_grad_params.append(
_module._modules[name].lora_up.parameters()
)
require_grad_params.append(
_module._modules[name].lora_down.parameters()
)

if loras is not None:
_module._modules[name].lora_up.weight = loras.pop(0)
_module._modules[name].lora_down.weight = loras.pop(0)

_module._modules[name].lora_up.weight.requires_grad = True
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)

return require_grad_params, names


@@ -138,7 +145,7 @@ def weight_apply_lora(


def monkeypatch_lora(
model, loras, target_replace_module=["CrossAttention", "Attention"]
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
@@ -151,6 +158,7 @@ def monkeypatch_lora(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r=r,
)
_tmp.linear.weight = weight

@@ -174,7 +182,7 @@ def monkeypatch_lora(


def monkeypatch_replace_lora(
model, loras, target_replace_module=["CrossAttention", "Attention"]
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
@@ -187,6 +195,7 @@ def monkeypatch_replace_lora(
_child_module.linear.in_features,
_child_module.linear.out_features,
_child_module.linear.bias is not None,
r=r,
)
_tmp.linear.weight = weight

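The new loras argument is what enables resuming: when it points at a previously saved LoRA .pt file, the injected lora_up / lora_down weights are initialised from that checkpoint instead of starting from scratch. A hedged usage sketch — the model id, checkpoint path, and rank are illustrative assumptions, not part of this PR:

import torch
from diffusers import UNet2DConditionModel
from lora_diffusion import inject_trainable_lora

# load a UNet and freeze its base weights before injecting LoRA layers
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.requires_grad_(False)

# fresh run: loras=None keeps the default initialisation
# resumed run: point loras at a checkpoint written by save_lora_weight
unet_lora_params, names = inject_trainable_lora(
    unet, r=4, loras="output/lora_weight_e0_s500.pt"
)

Note that r has to match the rank the checkpoint was trained with, because the loaded tensors replace the freshly created lora_up / lora_down weights shape-for-shape.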
853 changes: 853 additions & 0 deletions scripts/lora_lr_effects.ipynb

Large diffs are not rendered by default.

58 changes: 33 additions & 25 deletions scripts/run_img2img.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
setup(
name="lora_diffusion",
py_modules=["lora_diffusion"],
version="0.0.4",
version="0.0.5",
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
author="Simo Ryu",
packages=find_packages(),
139 changes: 78 additions & 61 deletions train_lora_dreambooth.py
@@ -413,6 +413,22 @@ def parse_args(input_args=None):
default=-1,
help="For distributed training: local_rank",
)
parser.add_argument(
"--resume_unet",
type=str,
default=None,
help=(
"File path for unet lora to resume training."
)
)
parser.add_argument(
"--resume_text_encoder",
type=str,
default=None,
help=(
"File path for text encoder lora to resume training."
)
)

if input_args is not None:
args = parser.parse_args(input_args)
@@ -576,7 +592,7 @@ def main(args):
revision=args.revision,
)
unet.requires_grad_(False)
unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank)
unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet)

for _up, _down in extract_lora_ups_down(unet):
print("Before training: Unet First Layer lora up", _up.weight.data)
@@ -590,6 +606,7 @@ def main(args):
text_encoder_lora_params, _ = inject_trainable_lora(
text_encoder, target_replace_module=["CLIPAttention"],
r=args.lora_rank,
loras=args.resume_text_encoder,
)
for _up, _down in extract_lora_ups_down(
text_encoder, target_replace_module=["CLIPAttention"]
@@ -864,74 +881,74 @@ def collate_fn(examples):

global_step += 1

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.save_steps and global_step - last_save >= args.save_steps:
if accelerator.is_main_process:
# newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
# it, the models will be unwrapped, and when they are then used for further training,
# we will crash. pass this, but only to newer versions of accelerate. fixes
# https://github.com/huggingface/diffusers/issues/1566
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
inspect.signature(accelerator.unwrap_model).parameters.keys()
)
extra_args = (
{"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
)
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet, **extra_args),
text_encoder=accelerator.unwrap_model(
text_encoder, **extra_args
),
revision=args.revision,
)

filename_unet = (
f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
)
filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
print(f"save weights {filename_unet}, {filename_text_encoder}")
save_lora_weight(pipeline.unet, filename_unet)
if args.train_text_encoder:
save_lora_weight(
pipeline.text_encoder,
filename_text_encoder,
target_replace_module=["CLIPAttention"],
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.save_steps and global_step - last_save >= args.save_steps:
if accelerator.is_main_process:
# newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
# it, the models will be unwrapped, and when they are then used for further training,
# we will crash. pass this, but only to newer versions of accelerate. fixes
# https://github.com/huggingface/diffusers/issues/1566
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
inspect.signature(accelerator.unwrap_model).parameters.keys()
)
extra_args = (
{"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
)
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet, **extra_args),
text_encoder=accelerator.unwrap_model(
text_encoder, **extra_args
),
revision=args.revision,
)

for _up, _down in extract_lora_ups_down(pipeline.unet):
print("First Unet Layer's Up Weight is now : ", _up.weight.data)
print(
"First Unet Layer's Down Weight is now : ",
_down.weight.data,
filename_unet = (
f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
)
break
if args.train_text_encoder:
for _up, _down in extract_lora_ups_down(
pipeline.text_encoder,
target_replace_module=["CLIPAttention"],
):
print(
"First Text Encoder Layer's Up Weight is now : ",
_up.weight.data,
filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
print(f"save weights {filename_unet}, {filename_text_encoder}")
save_lora_weight(pipeline.unet, filename_unet)
if args.train_text_encoder:
save_lora_weight(
pipeline.text_encoder,
filename_text_encoder,
target_replace_module=["CLIPAttention"],
)

for _up, _down in extract_lora_ups_down(pipeline.unet):
print("First Unet Layer's Up Weight is now : ", _up.weight.data)
print(
"First Text Encoder Layer's Down Weight is now : ",
"First Unet Layer's Down Weight is now : ",
_down.weight.data,
)
break

last_save = global_step

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

accelerator.wait_for_everyone()
if args.train_text_encoder:
for _up, _down in extract_lora_ups_down(
pipeline.text_encoder,
target_replace_module=["CLIPAttention"],
):
print(
"First Text Encoder Layer's Up Weight is now : ",
_up.weight.data,
)
print(
"First Text Encoder Layer's Down Weight is now : ",
_down.weight.data,
)
break

last_save = global_step

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

accelerator.wait_for_everyone()

# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
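The checkpoints this script saves during training (lora_weight_e{epoch}_s{global_step}.pt and the matching .text_encoder.pt) are exactly what the new --resume_unet and --resume_text_encoder flags expect on a later run; they can also be loaded for inference with monkeypatch_lora. A hedged inference sketch — the model id, file names, rank, and prompt are illustrative assumptions, not part of this PR:

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# patch the attention Linears with the saved LoRA weights (rank must match training)
monkeypatch_lora(pipe.unet, torch.load("output/lora_weight_e0_s500.pt"), r=4)
monkeypatch_lora(
    pipe.text_encoder,
    torch.load("output/lora_weight_e0_s500.text_encoder.pt"),
    target_replace_module=["CLIPAttention"],
    r=4,
)

pipe = pipe.to("cuda")
image = pipe("a photo of sks dog", num_inference_steps=30).images[0]
image.save("sample.png")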