Skip to content

Commit

Permalink
feat : svd distillation with CLI (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
cloneofsimo authored Dec 29, 2022
1 parent 35653d2 commit 7dd0467
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 10 deletions.
Binary file added analog_svd_distill.pt
Binary file not shown.
Binary file added analog_svd_distill.text_encoder.pt
Binary file not shown.
115 changes: 115 additions & 0 deletions lora_diffusion/cli_svd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import fire
from diffusers import StableDiffusionPipeline
import torch
import torch.nn as nn

from .lora import save_all, _find_modules


def _text_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def _ti_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["ti", "pt"])


def extract_linear_weights(model, target_replace_module):
lins = []
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear]
):
lins.append(_child_module.weight)

return lins


def svd_distill(
target_model: str,
base_model: str,
rank: int = 4,
clamp_quantile: float = 0.99,
device: str = "cuda:0",
save_path: str = "svd_distill.pt",
):
pipe_base = StableDiffusionPipeline.from_pretrained(
base_model, torch_dtype=torch.float16
).to(device)

model_id = "wavymulder/Analog-Diffusion"
pipe_tuned = StableDiffusionPipeline.from_pretrained(
target_model, torch_dtype=torch.float16
).to(device)

ori_unet = extract_linear_weights(
pipe_base.unet, ["CrossAttention", "Attention", "GEGLU"]
)
ori_clip = extract_linear_weights(pipe_base.text_encoder, ["CLIPAttention"])

tuned_unet = extract_linear_weights(
pipe_tuned.unet, ["CrossAttention", "Attention", "GEGLU"]
)
tuned_clip = extract_linear_weights(pipe_tuned.text_encoder, ["CLIPAttention"])

diffs_unet = []
diffs_clip = []

for ori, tuned in zip(ori_unet, tuned_unet):
diffs_unet.append(tuned - ori)

for ori, tuned in zip(ori_clip, tuned_clip):
diffs_clip.append(tuned - ori)

uds_unet = []
uds_clip = []
with torch.no_grad():
for mat in diffs_unet:
mat = mat.float()

U, S, Vh = torch.linalg.svd(mat)

U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)

Vh = Vh[:rank, :]

dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val

U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)

uds_unet.append(U)
uds_unet.append(Vh)

for mat in diffs_clip:
mat = mat.float()

U, S, Vh = torch.linalg.svd(mat)

U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)

Vh = Vh[:rank, :]

dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val

U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)

uds_clip.append(U)
uds_clip.append(Vh)

torch.save(uds_unet, save_path)
torch.save(uds_clip, _text_lora_path(save_path))


def main():
fire.Fire(svd_distill)
23 changes: 13 additions & 10 deletions lora_diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,22 +654,25 @@ def save_all(
placeholder_tokens,
save_path,
save_lora=True,
save_ti=True,
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
target_replace_module_unet=DEFAULT_TARGET_REPLACE,
):

# save ti
ti_path = _ti_lora_path(save_path)
learned_embeds_dict = {}
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
print(
f"Current Learned Embeddings for {tok}:, id {tok_id} ", learned_embeds[:4]
)
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
if save_ti:
ti_path = _ti_lora_path(save_path)
learned_embeds_dict = {}
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
print(
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
learned_embeds[:4],
)
learned_embeds_dict[tok] = learned_embeds.detach().cpu()

torch.save(learned_embeds_dict, ti_path)
print("Ti saved to ", ti_path)
torch.save(learned_embeds_dict, ti_path)
print("Ti saved to ", ti_path)

# save text encoder
if save_lora:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"console_scripts": [
"lora_add = lora_diffusion.cli_lora_add:main",
"lora_pti = lora_diffusion.cli_lora_pti:main",
"lora_distill = lora_diffusion.cli_svd:main",
],
},
install_requires=[
Expand Down

0 comments on commit 7dd0467

Please sign in to comment.