v 0.0.5 #42

Merged
merged 9 commits on Dec 15, 2022
Binary file added contents/lora_diff_lrs.jpg
Binary file added contents/lora_diff_lrs_0.6.jpg
35 changes: 25 additions & 10 deletions lora_diffusion/cli_lora_add.py
@@ -26,34 +26,48 @@ def add(
] = "lpl",
with_text_lora: bool = False,
):
print("Lora Add, mode " + mode)
if mode == "lpl":
assert output_path.endswith(".pt"), "Only .pt files are supported"

for _path_1, _path_2 in (
[(path_1, path_2)] + [(_text_lora_path(path_1), _text_lora_path(path_2))]
for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
[(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
if with_text_lora
else []
):
print("Loading", _path_1, _path_2)
out_list = []
if opt == "text_encoder":
if not os.path.exists(_path_1):
print(f"No text encoder found in {_path_1}, skipping...")
continue
if not os.path.exists(_path_2):
print(f"No text encoder found in {_path_1}, skipping...")
continue

l1 = torch.load(_path_1)
l2 = torch.load(_path_2)

l1pairs = zip(l1[::2], l1[1::2])
l2pairs = zip(l2[::2], l2[1::2])

for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
# print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
x1.data = alpha * x1.data + (1 - alpha) * x2.data
y1.data = alpha * y1.data + (1 - alpha) * y2.data

out_list.append(x1)
out_list.append(y1)

torch.save(out_list, output_path)
if with_text_lora:
torch.save(
out_list,
_text_lora_path(output_path),
)
if opt == "unet":

print("Saving merged UNET to", output_path)
torch.save(out_list, output_path)

elif opt == "text_encoder":
print("Saving merged text encoder to", _text_lora_path(output_path))
torch.save(
out_list,
_text_lora_path(output_path),
)

elif mode == "upl":

@@ -96,6 +110,7 @@ def add(
shutil.rmtree(_tmp_output)

else:
print("Unknown mode", mode)
raise ValueError(f"Unknown mode {mode}")
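
For reference, the "lpl" branch above merges two LoRA checkpoints by linearly interpolating corresponding weight pairs. Below is a minimal sketch of just that step, assuming each .pt file holds a flat list of tensors in which consecutive entries form one pair per injected layer (as the l1[::2] / l1[1::2] slicing implies); the helper name blend_lora_lists is hypothetical, not part of this PR:

```python
import torch

def blend_lora_lists(path_1: str, path_2: str, output_path: str, alpha: float = 0.5):
    """Blend two saved LoRA weight lists as alpha * l1 + (1 - alpha) * l2."""
    l1 = torch.load(path_1)
    l2 = torch.load(path_2)

    merged = []
    # Consecutive tensors in each list form one LoRA weight pair per layer.
    for (a1, b1), (a2, b2) in zip(
        zip(l1[::2], l1[1::2]), zip(l2[::2], l2[1::2])
    ):
        merged.append(alpha * a1 + (1 - alpha) * a2)
        merged.append(alpha * b1 + (1 - alpha) * b2)

    torch.save(merged, output_path)
```

With alpha=1.0 the output is the first LoRA unchanged; alpha=0.5 gives an even blend of the two.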


10 changes: 6 additions & 4 deletions lora_diffusion/lora.py
@@ -13,9 +13,9 @@ class LoraInjectedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False, r=4):
super().__init__()

if r >= min(in_features, out_features):
if r > min(in_features, out_features):
raise ValueError(
f"LoRA rank {r} must be less than {min(in_features, out_features)}"
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
)

self.linear = nn.Linear(in_features, out_features, bias)
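
The relaxed check allows r up to and including the smaller of the two dimensions: a rank-r adapter factors the weight update into an (r x in_features) down projection and an (out_features x r) up projection, so r == min(in_features, out_features) is simply a full-rank update. A rough sketch of such an injected layer follows; everything beyond the rank check and the base nn.Linear is an assumption for illustration, not taken from this diff:

```python
import torch.nn as nn

class LoraLinearSketch(nn.Module):
    # Illustrative only; the class name and initialization details are assumptions.
    def __init__(self, in_features, out_features, bias=False, r=4):
        super().__init__()
        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less than or equal to "
                f"{min(in_features, out_features)}"
            )
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)  # (r x in_features)
        self.lora_up = nn.Linear(r, out_features, bias=False)   # (out_features x r)
        nn.init.zeros_(self.lora_up.weight)  # low-rank update starts at zero

    def forward(self, x):
        # Base projection plus the low-rank residual.
        return self.linear(x) + self.lora_up(self.lora_down(x))
```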
@@ -138,7 +138,7 @@ def weight_apply_lora(


def monkeypatch_lora(
model, loras, target_replace_module=["CrossAttention", "Attention"]
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
@@ -151,6 +151,7 @@ def monkeypatch_lora(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r=r,
)
_tmp.linear.weight = weight

@@ -174,7 +175,7 @@ def monkeypatch_lora(


def monkeypatch_replace_lora(
model, loras, target_replace_module=["CrossAttention", "Attention"]
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
@@ -187,6 +188,7 @@ def monkeypatch_replace_lora(
_child_module.linear.in_features,
_child_module.linear.out_features,
_child_module.linear.bias is not None,
r=r,
)
_tmp.linear.weight = weight

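
With the new r keyword on monkeypatch_lora and monkeypatch_replace_lora, LoRA weights trained at a non-default rank can now be injected. A hedged usage sketch; the model id, file name, and rank value below are placeholders, not part of this PR:

```python
import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

loras = torch.load("lora_weight_rank8.pt")  # placeholder path to a rank-8 LoRA
monkeypatch_lora(pipe.unet, loras, r=8)     # r must match the rank used at training time
```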
853 changes: 853 additions & 0 deletions scripts/lora_lr_effects.ipynb

Large diffs are not rendered by default.

58 changes: 33 additions & 25 deletions scripts/run_img2img.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
setup(
name="lora_diffusion",
py_modules=["lora_diffusion"],
version="0.0.4",
version="0.0.5",
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
author="Simo Ryu",
packages=find_packages(),