diff --git a/analog_svd_distill.pt b/analog_svd_distill.pt
new file mode 100644
index 0000000..8ffc684
Binary files /dev/null and b/analog_svd_distill.pt differ
diff --git a/analog_svd_distill.text_encoder.pt b/analog_svd_distill.text_encoder.pt
new file mode 100644
index 0000000..4abdd2b
Binary files /dev/null and b/analog_svd_distill.text_encoder.pt differ
diff --git a/lora_diffusion/cli_svd.py b/lora_diffusion/cli_svd.py
new file mode 100644
index 0000000..55e9026
--- /dev/null
+++ b/lora_diffusion/cli_svd.py
@@ -0,0 +1,115 @@
+import fire
+from diffusers import StableDiffusionPipeline
+import torch
+import torch.nn as nn
+
+from .lora import _find_modules
+
+
+def _text_lora_path(path: str) -> str:
+    assert path.endswith(".pt"), "Only .pt files are supported"
+    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
+
+
+def _ti_lora_path(path: str) -> str:
+    assert path.endswith(".pt"), "Only .pt files are supported"
+    return ".".join(path.split(".")[:-1] + ["ti", "pt"])
+
+
+def extract_linear_weights(model, target_replace_module):
+    # Collect the weights of every nn.Linear inside the targeted modules.
+    lins = []
+    for _module, _name, _child_module in _find_modules(
+        model, target_replace_module, search_class=[nn.Linear]
+    ):
+        lins.append(_child_module.weight)
+
+    return lins
+
+
+def svd_distill(
+    target_model: str,
+    base_model: str,
+    rank: int = 4,
+    clamp_quantile: float = 0.99,
+    device: str = "cuda:0",
+    save_path: str = "svd_distill.pt",
+):
+    pipe_base = StableDiffusionPipeline.from_pretrained(
+        base_model, torch_dtype=torch.float16
+    ).to(device)
+
+    pipe_tuned = StableDiffusionPipeline.from_pretrained(
+        target_model, torch_dtype=torch.float16
+    ).to(device)
+
+    ori_unet = extract_linear_weights(
+        pipe_base.unet, ["CrossAttention", "Attention", "GEGLU"]
+    )
+    ori_clip = extract_linear_weights(pipe_base.text_encoder, ["CLIPAttention"])
+
+    tuned_unet = extract_linear_weights(
+        pipe_tuned.unet, ["CrossAttention", "Attention", "GEGLU"]
+    )
+    tuned_clip = extract_linear_weights(pipe_tuned.text_encoder, ["CLIPAttention"])
+
+    # The distillation target is the weight delta between the fine-tuned
+    # model and the base model.
+    diffs_unet = []
+    diffs_clip = []
+
+    for ori, tuned in zip(ori_unet, tuned_unet):
+        diffs_unet.append(tuned - ori)
+
+    for ori, tuned in zip(ori_clip, tuned_clip):
+        diffs_clip.append(tuned - ori)
+
+    uds_unet = []
+    uds_clip = []
+    with torch.no_grad():
+        for mat in diffs_unet:
+            mat = mat.float()
+
+            # Rank-r approximation of the delta: keep the top singular
+            # vectors and fold the singular values into U, so U @ Vh ≈ mat.
+            U, S, Vh = torch.linalg.svd(mat)
+
+            U = U[:, :rank]
+            S = S[:rank]
+            U = U @ torch.diag(S)
+
+            Vh = Vh[:rank, :]
+
+            # Clamp outliers so a few extreme entries do not dominate the
+            # distilled weights.
+            dist = torch.cat([U.flatten(), Vh.flatten()])
+            hi_val = torch.quantile(dist, clamp_quantile)
+            low_val = -hi_val
+
+            U = U.clamp(low_val, hi_val)
+            Vh = Vh.clamp(low_val, hi_val)
+
+            uds_unet.append(U)
+            uds_unet.append(Vh)
+
+        for mat in diffs_clip:
+            mat = mat.float()
+
+            U, S, Vh = torch.linalg.svd(mat)
+
+            U = U[:, :rank]
+            S = S[:rank]
+            U = U @ torch.diag(S)
+
+            Vh = Vh[:rank, :]
+
+            dist = torch.cat([U.flatten(), Vh.flatten()])
+            hi_val = torch.quantile(dist, clamp_quantile)
+            low_val = -hi_val
+
+            U = U.clamp(low_val, hi_val)
+            Vh = Vh.clamp(low_val, hi_val)
+
+            uds_clip.append(U)
+            uds_clip.append(Vh)
+
+    torch.save(uds_unet, save_path)
+    torch.save(uds_clip, _text_lora_path(save_path))
+
+
+def main():
+    fire.Fire(svd_distill)
diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 11ac001..dc5a944 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -654,22 +654,25 @@ def save_all(
     placeholder_tokens,
     save_path,
     save_lora=True,
+    save_ti=True,
     target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
     target_replace_module_unet=DEFAULT_TARGET_REPLACE,
 ):
     # save ti
-    ti_path = _ti_lora_path(save_path)
-    learned_embeds_dict = {}
-    for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
-        learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
-        print(
-            f"Current Learned Embeddings for {tok}:, id {tok_id} ", learned_embeds[:4]
-        )
-        learned_embeds_dict[tok] = learned_embeds.detach().cpu()
+    if save_ti:
+        ti_path = _ti_lora_path(save_path)
+        learned_embeds_dict = {}
+        for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
+            learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
+            print(
+                f"Current Learned Embeddings for {tok}: id {tok_id} ",
+                learned_embeds[:4],
+            )
+            learned_embeds_dict[tok] = learned_embeds.detach().cpu()
 
-    torch.save(learned_embeds_dict, ti_path)
-    print("Ti saved to ", ti_path)
+        torch.save(learned_embeds_dict, ti_path)
+        print("Ti saved to ", ti_path)
 
     # save text encoder
     if save_lora:
diff --git a/setup.py b/setup.py
index be6859b..2f92a67 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,7 @@
         "console_scripts": [
             "lora_add = lora_diffusion.cli_lora_add:main",
             "lora_pti = lora_diffusion.cli_lora_pti:main",
+            "lora_distill = lora_diffusion.cli_svd:main",
         ],
     },
     install_requires=[
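For context, a minimal self-contained sketch of the low-rank step that svd_distill applies to each weight delta. The shapes and the synthetic rank-4 delta are toy values chosen for illustration, not taken from this patch:

import torch

# Build a base weight and a fine-tuned weight whose difference is low-rank.
torch.manual_seed(0)
W_base = torch.randn(320, 320)
delta = 0.01 * torch.randn(320, 4) @ torch.randn(4, 320)
W_tuned = W_base + delta

# Rank-r approximation of the delta, singular values folded into U,
# mirroring the U/Vh pair that svd_distill saves for each layer.
rank = 4
U, S, Vh = torch.linalg.svd(W_tuned - W_base)
U = U[:, :rank] @ torch.diag(S[:rank])
Vh = Vh[:rank, :]

# For an exactly rank-4 delta the reconstruction error is near zero.
err = ((U @ Vh - delta).norm() / delta.norm()).item()
print(f"relative reconstruction error: {err:.2e}")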
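The new lora_distill entry point exposes svd_distill through fire, so every function argument becomes a CLI flag. A hypothetical invocation that would reproduce the analog_svd_distill.pt artifacts committed in this diff; the base model is an assumption (Analog-Diffusion is a Stable Diffusion 1.5 fine-tune), not something this patch specifies:

from lora_diffusion.cli_svd import svd_distill

# Writes analog_svd_distill.pt (UNet) and, via _text_lora_path,
# analog_svd_distill.text_encoder.pt (CLIP text encoder).
svd_distill(
    target_model="wavymulder/Analog-Diffusion",
    base_model="runwayml/stable-diffusion-v1-5",  # assumed base checkpoint
    rank=4,
    clamp_quantile=0.99,
    device="cuda:0",
    save_path="analog_svd_distill.pt",
)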