diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py
index 311d087..136004e 100644
--- a/lora_diffusion/cli_lora_pti.py
+++ b/lora_diffusion/cli_lora_pti.py
@@ -361,7 +361,9 @@ def train_inversion(
                     text_encoder=text_encoder,
                     placeholder_token_ids=placeholder_token_ids,
                     placeholder_tokens=placeholder_tokens,
-                    save_path=os.path.join(save_path, f"step_inv_{global_step}.pt"),
+                    save_path=os.path.join(
+                        save_path, f"step_inv_{global_step}.safetensors"
+                    ),
                     save_lora=False,
                 )
                 if log_wandb:
@@ -453,7 +455,9 @@ def perform_tuning(
                 text_encoder,
                 placeholder_token_ids=placeholder_token_ids,
                 placeholder_tokens=placeholder_tokens,
-                save_path=os.path.join(save_path, f"step_{global_step}.pt"),
+                save_path=os.path.join(
+                    save_path, f"step_{global_step}.safetensors"
+                ),
             )
             moved = (
                 torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index b9df63e..41b0ff2 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -671,7 +671,7 @@ def load_learned_embed_in_clip(
 
 def patch_pipe(
     pipe,
-    unet_path,
+    maybe_unet_path,
     token: Optional[str] = None,
     r: int = 4,
     patch_unet=True,
@@ -681,35 +681,52 @@ def patch_pipe(
     unet_target_replace_module=DEFAULT_TARGET_REPLACE,
     text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
 ):
-    if unet_path.endswith(".ti.pt"):
-        unet_path = unet_path[:-6] + ".pt"
-    elif unet_path.endswith(".text_encoder.pt"):
-        unet_path = unet_path[:-16] + ".pt"
-
-    ti_path = _ti_lora_path(unet_path)
-    text_path = _text_lora_path(unet_path)
-
-    if patch_unet:
-        print("LoRA : Patching Unet")
-        monkeypatch_or_replace_lora(
-            pipe.unet,
-            torch.load(unet_path),
-            r=r,
-            target_replace_module=unet_target_replace_module,
-        )
+    if maybe_unet_path.endswith(".pt"):
+        # torch format
+
+        if maybe_unet_path.endswith(".ti.pt"):
+            unet_path = maybe_unet_path[:-6] + ".pt"
+        elif maybe_unet_path.endswith(".text_encoder.pt"):
+            unet_path = maybe_unet_path[:-16] + ".pt"
+        else:
+            unet_path = maybe_unet_path
+
+        ti_path = _ti_lora_path(unet_path)
+        text_path = _text_lora_path(unet_path)
+
+        if patch_unet:
+            print("LoRA : Patching Unet")
+            monkeypatch_or_replace_lora(
+                pipe.unet,
+                torch.load(unet_path),
+                r=r,
+                target_replace_module=unet_target_replace_module,
+            )
 
-    if patch_text:
-        print("LoRA : Patching text encoder")
-        monkeypatch_or_replace_lora(
-            pipe.text_encoder,
-            torch.load(text_path),
-            target_replace_module=text_target_replace_module,
-            r=r,
-        )
-    if patch_ti:
-        print("LoRA : Patching token input")
-        token = load_learned_embed_in_clip(
-            ti_path,
+        if patch_text:
+            print("LoRA : Patching text encoder")
+            monkeypatch_or_replace_lora(
+                pipe.text_encoder,
+                torch.load(text_path),
+                target_replace_module=text_target_replace_module,
+                r=r,
+            )
+        if patch_ti:
+            print("LoRA : Patching token input")
+            token = load_learned_embed_in_clip(
+                ti_path,
+                pipe.text_encoder,
+                pipe.tokenizer,
+                token=token,
+                idempotent=idempotent_token,
+            )
+
+    elif maybe_unet_path.endswith(".safetensors"):
+        safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
+        monkeypatch_or_replace_safeloras(pipe, safeloras)
+        tok_dict = parse_safeloras_embeds(safeloras)
+        apply_learned_embed_in_clip(
+            tok_dict,
             pipe.text_encoder,
             pipe.tokenizer,
             token=token,
@@ -747,34 +764,60 @@ def save_all(
     save_ti=True,
     target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
     target_replace_module_unet=DEFAULT_TARGET_REPLACE,
+    safe_form=True,
 ):
+    if not safe_form:
+        # save ti
+        if save_ti:
+            ti_path = _ti_lora_path(save_path)
+            learned_embeds_dict = {}
+            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
+                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
+                print(
+                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
+                    learned_embeds[:4],
+                )
+                learned_embeds_dict[tok] = learned_embeds.detach().cpu()
+
+            torch.save(learned_embeds_dict, ti_path)
+            print("Ti saved to ", ti_path)
+
+        # save text encoder
+        if save_lora:
+
+            save_lora_weight(
+                unet, save_path, target_replace_module=target_replace_module_unet
+            )
+            print("Unet saved to ", save_path)
 
-    # save ti
-    if save_ti:
-        ti_path = _ti_lora_path(save_path)
-        learned_embeds_dict = {}
-        for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
-            learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
-            print(
-                f"Current Learned Embeddings for {tok}:, id {tok_id} ",
-                learned_embeds[:4],
+            save_lora_weight(
+                text_encoder,
+                _text_lora_path(save_path),
+                target_replace_module=target_replace_module_text,
             )
-            learned_embeds_dict[tok] = learned_embeds.detach().cpu()
+            print("Text Encoder saved to ", _text_lora_path(save_path))
 
-        torch.save(learned_embeds_dict, ti_path)
-        print("Ti saved to ", ti_path)
+    else:
+        assert save_path.endswith(
+            ".safetensors"
+        ), f"Save path : {save_path} should end with .safetensors"
 
-    # save text encoder
-    if save_lora:
+        loras = {}
+        embeds = None
 
-        save_lora_weight(
-            unet, save_path, target_replace_module=target_replace_module_unet
-        )
-        print("Unet saved to ", save_path)
+        if save_lora:
 
-        save_lora_weight(
-            text_encoder,
-            _text_lora_path(save_path),
-            target_replace_module=target_replace_module_text,
-        )
-        print("Text Encoder saved to ", _text_lora_path(save_path))
+            loras["unet"] = (unet, target_replace_module_unet)
+            loras["text_encoder"] = (text_encoder, target_replace_module_text)
+
+        if save_ti:
+            embeds = {}
+            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
+                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
+                print(
+                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
+                    learned_embeds[:4],
+                )
+                embeds[tok] = learned_embeds.detach().cpu()
+
+        save_safeloras_with_embeds(loras, embeds, save_path)
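Usage sketch (for reference, not part of the patch): how the new single-file .safetensors output could be loaded through the patched patch_pipe. The model id, output path, and placeholder token below are hypothetical examples; the sketch assumes a diffusers StableDiffusionPipeline and that patch_pipe is importable from the lora_diffusion package.

    import torch
    from diffusers import StableDiffusionPipeline
    from lora_diffusion import patch_pipe

    # Load a base pipeline (example model id).
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # patch_pipe branches on the file extension: a single .safetensors file
    # produced by save_all(..., safe_form=True) carries the unet LoRA, the
    # text-encoder LoRA, and the learned token embeddings together.
    patch_pipe(
        pipe,
        "./output/step_1000.safetensors",  # hypothetical path from training
        token="<s1>",  # hypothetical placeholder token
    )

    image = pipe("a photo of <s1>", num_inference_steps=30).images[0]
    image.save("sample.png")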