Feat/allsave (#119)
* feat : save_all now saves safetensor

* feat : allsave with safetensors
cloneofsimo authored Jan 8, 2023
1 parent 603761a commit 72e7cb8
Showing 2 changed files with 100 additions and 55 deletions.
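
For context, a minimal usage sketch of what this commit enables: training checkpoints are now written as single .safetensors files, and patch_pipe (defined in lora_diffusion/lora.py, assumed importable from the lora_diffusion package) applies them to a pipeline. The model id, output path, and placeholder token below are illustrative assumptions, not taken from this commit.

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import patch_pipe

# Hypothetical paths: the training loop in cli_lora_pti.py below now writes
# step_inv_<N>.safetensors / step_<N>.safetensors instead of .pt files.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# A single .safetensors checkpoint bundles the UNet LoRA, the text-encoder
# LoRA, and the learned token embeddings; patch_pipe applies all of them.
patch_pipe(pipe, "./output/step_1000.safetensors")

# "<s1>" is a hypothetical learned placeholder token.
image = pipe("a photo of <s1>", num_inference_steps=30).images[0]
image.save("sample.png")
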
8 changes: 6 additions & 2 deletions lora_diffusion/cli_lora_pti.py
@@ -361,7 +361,9 @@ def train_inversion(
text_encoder=text_encoder,
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
save_path=os.path.join(save_path, f"step_inv_{global_step}.pt"),
save_path=os.path.join(
save_path, f"step_inv_{global_step}.safetensors"
),
save_lora=False,
)
if log_wandb:
@@ -453,7 +455,9 @@ def perform_tuning(
text_encoder,
placeholder_token_ids=placeholder_token_ids,
placeholder_tokens=placeholder_tokens,
save_path=os.path.join(save_path, f"step_{global_step}.pt"),
save_path=os.path.join(
save_path, f"step_{global_step}.safetensors"
),
)
moved = (
torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
147 changes: 94 additions & 53 deletions lora_diffusion/lora.py
@@ -671,7 +671,7 @@ def load_learned_embed_in_clip(

def patch_pipe(
pipe,
unet_path,
maybe_unet_path,
token: Optional[str] = None,
r: int = 4,
patch_unet=True,
@@ -681,35 +681,50 @@ def patch_pipe(
unet_target_replace_module=DEFAULT_TARGET_REPLACE,
text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
if unet_path.endswith(".ti.pt"):
unet_path = unet_path[:-6] + ".pt"
elif unet_path.endswith(".text_encoder.pt"):
unet_path = unet_path[:-16] + ".pt"

ti_path = _ti_lora_path(unet_path)
text_path = _text_lora_path(unet_path)

if patch_unet:
print("LoRA : Patching Unet")
monkeypatch_or_replace_lora(
pipe.unet,
torch.load(unet_path),
r=r,
target_replace_module=unet_target_replace_module,
)
if maybe_unet_path.endswith(".pt"):
# torch format

if maybe_unet_path.endswith(".ti.pt"):
unet_path = maybe_unet_path[:-6] + ".pt"
elif maybe_unet_path.endswith(".text_encoder.pt"):
unet_path = maybe_unet_path[:-16] + ".pt"

ti_path = _ti_lora_path(unet_path)
text_path = _text_lora_path(unet_path)

if patch_unet:
print("LoRA : Patching Unet")
monkeypatch_or_replace_lora(
pipe.unet,
torch.load(unet_path),
r=r,
target_replace_module=unet_target_replace_module,
)

if patch_text:
print("LoRA : Patching text encoder")
monkeypatch_or_replace_lora(
pipe.text_encoder,
torch.load(text_path),
target_replace_module=text_target_replace_module,
r=r,
)
if patch_ti:
print("LoRA : Patching token input")
token = load_learned_embed_in_clip(
ti_path,
if patch_text:
print("LoRA : Patching text encoder")
monkeypatch_or_replace_lora(
pipe.text_encoder,
torch.load(text_path),
target_replace_module=text_target_replace_module,
r=r,
)
if patch_ti:
print("LoRA : Patching token input")
token = load_learned_embed_in_clip(
ti_path,
pipe.text_encoder,
pipe.tokenizer,
token=token,
idempotent=idempotent_token,
)

elif maybe_unet_path.endswith(".safetensors"):
safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
monkeypatch_or_replace_safeloras(pipe, safeloras)
tok_dict = parse_safeloras_embeds(safeloras)
apply_learned_embed_in_clip(
tok_dict,
pipe.text_encoder,
pipe.tokenizer,
token=token,
@@ -747,34 +762,60 @@ def save_all(
save_ti=True,
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
target_replace_module_unet=DEFAULT_TARGET_REPLACE,
safe_form=True,
):
if not safe_form:
# save ti
if save_ti:
ti_path = _ti_lora_path(save_path)
learned_embeds_dict = {}
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
print(
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
learned_embeds[:4],
)
learned_embeds_dict[tok] = learned_embeds.detach().cpu()

torch.save(learned_embeds_dict, ti_path)
print("Ti saved to ", ti_path)

# save text encoder
if save_lora:

save_lora_weight(
unet, save_path, target_replace_module=target_replace_module_unet
)
print("Unet saved to ", save_path)

# save ti
if save_ti:
ti_path = _ti_lora_path(save_path)
learned_embeds_dict = {}
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
print(
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
learned_embeds[:4],
save_lora_weight(
text_encoder,
_text_lora_path(save_path),
target_replace_module=target_replace_module_text,
)
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
print("Text Encoder saved to ", _text_lora_path(save_path))

torch.save(learned_embeds_dict, ti_path)
print("Ti saved to ", ti_path)
else:
assert save_path.endswith(
".safetensors"
), f"Save path : {save_path} should end with .safetensors"

# save text encoder
if save_lora:
loras = {}
embeds = None

save_lora_weight(
unet, save_path, target_replace_module=target_replace_module_unet
)
print("Unet saved to ", save_path)
if save_lora:

save_lora_weight(
text_encoder,
_text_lora_path(save_path),
target_replace_module=target_replace_module_text,
)
print("Text Encoder saved to ", _text_lora_path(save_path))
loras["unet"] = (unet, target_replace_module_unet)
loras["text_encoder"] = (text_encoder, target_replace_module_text)

if save_ti:
embeds = {}
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
print(
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
learned_embeds[:4],
)
embeds[tok] = learned_embeds.detach().cpu()

save_safeloras_with_embeds(loras, embeds, save_path)
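
The safe_form branch above delegates to save_safeloras_with_embeds, which packs the UNet LoRA, the text-encoder LoRA, and the learned embeddings into one .safetensors file. A hedged sketch for inspecting such a checkpoint with the stock safetensors API follows; the file path is hypothetical, and the key/metadata naming scheme is the repo's own and is not shown in this hunk.

from safetensors import safe_open

path = "./output/step_1000.safetensors"  # hypothetical checkpoint from the training loop

with safe_open(path, framework="pt", device="cpu") as f:
    # Metadata is a flat str->str dict; the repo presumably records details
    # such as LoRA rank and target modules here, but the exact schema is not
    # visible in this diff.
    print("metadata:", f.metadata())
    for key in f.keys():
        t = f.get_tensor(key)
        print(key, tuple(t.shape), t.dtype)
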
