Skip to content

Commit

Permalink
Make safetensors properly optional, and support storing textual inve…
Browse files Browse the repository at this point in the history
…rsion embeddings (#101)
  • Loading branch information
hafriedlander authored Dec 30, 2022
1 parent 4936d0f commit eaf725a
Show file tree
Hide file tree
Showing 5 changed files with 352 additions and 53 deletions.
85 changes: 85 additions & 0 deletions lora_diffusion/cli_pt_to_safetensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os

import fire
import torch
from lora_diffusion import (
DEFAULT_TARGET_REPLACE,
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
UNET_DEFAULT_TARGET_REPLACE,
convert_loras_to_safeloras_with_embeds,
safetensors_available,
)

# Default LoRA target-module sets, keyed by the model name that `convert`
# derives from each input file name ("unet" or "text_encoder").
_target_by_name = {
    "unet": UNET_DEFAULT_TARGET_REPLACE,
    "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
}


def convert(*paths, outpath, overwrite=False, **settings):
    """
    Converts one or more pytorch Lora and/or Textual Embedding pytorch files
    into a safetensor file.

    Pass all the input paths as arguments. Whether they are Textual Embedding
    or Lora models will be auto-detected.

    For Lora models, their name will be taken from the path, i.e.

        "lora_weight.pt" => unet
        "lora_weight.text_encoder.pt" => text_encoder

    You can also set target_modules and/or rank by providing an argument prefixed
    by the name.

    So a complete example might be something like:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
    ```

    Raises:
        ValueError: if outpath already exists and overwrite is not set.
        EnvironmentError: if the safetensors library is not installed.
    """
    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    # Fail before loading any (potentially large) input: without the library
    # the final save step can only raise.
    if not safetensors_available:
        raise EnvironmentError(
            "Saving safetensors requires the safetensors library. Please install with pip or similar."
        )

    modelmap = {}
    embeds = {}

    for path in paths:
        # Map storages to the CPU so checkpoints saved from a GPU still load
        # on machines without one; the data is only inspected / re-saved here.
        data = torch.load(path, map_location="cpu")

        # Textual inversion files are a dict of token -> tensor; anything
        # else is treated as a Lora weight file.
        if isinstance(data, dict):
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            embeds.update(data)
            continue

        # "lora_weight.text_encoder.pt" -> "text_encoder"; a file name with
        # no middle component falls back to "unet".
        name_parts = os.path.basename(path).split(".")
        name = name_parts[-2] if len(name_parts) > 2 else "unet"

        model_settings = {
            "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
            "rank": 4,
        }

        # Apply per-model overrides passed as e.g. `--unet.rank 8`.
        prefix = f"{name}."
        model_settings.update(
            (key[len(prefix):], value)
            for key, value in settings.items()
            if key.startswith(prefix)
        )

        print(f"Loading Lora for {name} from {path} with settings {model_settings}")

        modelmap[name] = (
            path,
            model_settings["target_modules"],
            model_settings["rank"],
        )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)


def main():
    """Command-line entry point: hand `convert` to python-fire."""
    fire.Fire(convert)


# Only start the CLI when this module is executed as a script.
if __name__ == "__main__":
    main()
146 changes: 116 additions & 30 deletions lora_diffusion/lora.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import math
from itertools import groupby
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
Expand All @@ -8,6 +9,25 @@
import torch.nn as nn
import torch.nn.functional as F

# safetensors is an optional dependency: use it when importable, otherwise
# fall back to the bundled reader and a save function that only raises.
try:
    from safetensors.torch import safe_open, save_file as safe_save
except ImportError:
    from .safe_open import safe_open

    def safe_save(
        tensors: Dict[str, torch.Tensor],
        filename: str,
        metadata: Optional[Dict[str, str]] = None,
    ) -> None:
        """Stand-in for safetensors' save_file; always raises EnvironmentError."""
        raise EnvironmentError(
            "Saving safetensors requires the safetensors library. Please install with pip or similar."
        )

    safetensors_available = False
else:
    safetensors_available = True


class LoraInjectedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False, r=4):
Expand Down Expand Up @@ -35,6 +55,8 @@ def forward(self, input):

DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE

EMBED_FLAG = "<embed>"


def _find_children(
model,
Expand Down Expand Up @@ -203,8 +225,9 @@ def save_lora_as_json(model, path="./lora.json"):
json.dump(weights, f)


def save_safeloras(
def save_safeloras_with_embeds(
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
embeds: Dict[str, torch.Tensor] = {},
outpath="./lora.safetensors",
):
"""
Expand All @@ -217,10 +240,6 @@ def save_safeloras(
weights = {}
metadata = {}

import json

from safetensors.torch import save_file

for name, (model, target_replace_module) in modelmap.items():
metadata[name] = json.dumps(list(target_replace_module))

Expand All @@ -231,12 +250,24 @@ def save_safeloras(
weights[f"{name}:{i}:up"] = _up.weight
weights[f"{name}:{i}:down"] = _down.weight

print(f"Saving weights to {outpath} with metadata", metadata)
save_file(weights, outpath, metadata)
for token, tensor in embeds.items():
metadata[token] = EMBED_FLAG
weights[token] = tensor

print(f"Saving weights to {outpath}")
safe_save(weights, outpath, metadata)

def convert_loras_to_safeloras(

def save_safeloras(
    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
    outpath="./lora.safetensors",
):
    """
    Lora-only variant of save_safeloras_with_embeds: forwards modelmap and
    outpath, leaves the embeds argument at its default, and returns the
    delegate's result.
    """
    result = save_safeloras_with_embeds(outpath=outpath, modelmap=modelmap)
    return result


def convert_loras_to_safeloras_with_embeds(
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
embeds: Dict[str, torch.Tensor] = {},
outpath="./lora.safetensors",
):
"""
Expand All @@ -250,10 +281,6 @@ def convert_loras_to_safeloras(
weights = {}
metadata = {}

import json

from safetensors.torch import save_file

for name, (path, target_replace_module, r) in modelmap.items():
metadata[name] = json.dumps(list(target_replace_module))

Expand All @@ -268,8 +295,19 @@ def convert_loras_to_safeloras(
else:
weights[f"{name}:{i}:down"] = weight

print(f"Saving weights to {outpath} with metadata", metadata)
save_file(weights, outpath, metadata)
for token, tensor in embeds.items():
metadata[token] = EMBED_FLAG
weights[token] = tensor

print(f"Saving weights to {outpath}")
safe_save(weights, outpath, metadata)


def convert_loras_to_safeloras(
    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
    outpath="./lora.safetensors",
):
    """
    Lora-only variant of convert_loras_to_safeloras_with_embeds: forwards
    modelmap and outpath and leaves the embeds argument at its default.
    """
    convert_loras_to_safeloras_with_embeds(outpath=outpath, modelmap=modelmap)


def parse_safeloras(
Expand All @@ -288,9 +326,6 @@ def parse_safeloras(
}
"""
loras = {}

import json

metadata = safeloras.metadata()

get_name = lambda k: k.split(":")[0]
Expand All @@ -299,12 +334,24 @@ def parse_safeloras(
keys.sort(key=get_name)

for name, module_keys in groupby(keys, get_name):
info = metadata.get(name)

if not info:
raise ValueError(
f"Tensor {name} has no metadata - is this a Lora safetensor?"
)

# Skip Textual Inversion embeds
if info == EMBED_FLAG:
continue

# Handle Loras
# Extract the targets
target = json.loads(metadata[name])
target = json.loads(info)

# Build the result lists - Python needs us to preallocate lists to insert into them
module_keys = list(module_keys)
ranks = [None] * (len(module_keys) // 2)
ranks = [4] * (len(module_keys) // 2)
weights = [None] * len(module_keys)

for key in module_keys:
Expand All @@ -313,7 +360,7 @@ def parse_safeloras(
idx = int(idx)

# Add the rank
ranks[idx] = json.loads(metadata[f"{name}:{idx}:rank"])
ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

# Insert the weight into the list
idx = idx * 2 + (1 if direction == "down" else 0)
Expand All @@ -324,14 +371,42 @@ def parse_safeloras(
return loras


def load_safeloras(path, device="cpu"):
def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Converts a loaded safetensor file that contains Textual Inversion embeds into
    a dictionary of embed_token: Tensor

    Only tensors whose metadata entry equals EMBED_FLAG are returned; Lora
    weights and untagged tensors are ignored.
    """
    # A file written without any metadata reports None here rather than an
    # empty mapping; treat that as "no embeds" instead of failing on .get().
    metadata = safeloras.metadata() or {}

    return {
        key: safeloras.get_tensor(key)
        for key in safeloras.keys()
        if metadata.get(key) == EMBED_FLAG
    }


def load_safeloras(path, device="cpu"):
    """Open the safetensors file at `path` on `device` and return parse_safeloras' result."""
    handle = safe_open(path, framework="pt", device=device)
    loras = parse_safeloras(handle)
    return loras


def load_safeloras_embeds(path, device="cpu"):
    """Open the safetensors file at `path` on `device` and return parse_safeloras_embeds' result."""
    handle = safe_open(path, framework="pt", device=device)
    embeds = parse_safeloras_embeds(handle)
    return embeds


def load_safeloras_both(path, device="cpu"):
    """
    Open the safetensors file at `path` on `device` once and return a
    (parse_safeloras result, parse_safeloras_embeds result) tuple.
    """
    handle = safe_open(path, framework="pt", device=device)
    # Parse the Loras first, then the embeds, from the same open handle.
    loras = parse_safeloras(handle)
    embeds = parse_safeloras_embeds(handle)
    return loras, embeds


def weight_apply_lora(
model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0
):
Expand Down Expand Up @@ -535,28 +610,26 @@ def _ti_lora_path(path: str) -> str:
return ".".join(path.split(".")[:-1] + ["ti", "pt"])


def load_learned_embed_in_clip(
learned_embeds_path,
def apply_learned_embed_in_clip(
learned_embeds,
text_encoder,
tokenizer,
token: Union[str, List[str]] = None,
token: Optional[Union[str, List[str]]] = None,
idempotent=False,
):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

if isinstance(token, str):
trained_tokens = [token]
elif isinstance(token, list):
assert len(loaded_learned_embeds.keys()) == len(
assert len(learned_embeds.keys()) == len(
token
), "The number of tokens and the number of embeds should be the same"
trained_tokens = token
else:
trained_tokens = list(loaded_learned_embeds.keys())
trained_tokens = list(learned_embeds.keys())

for token in trained_tokens:
print(token)
embeds = loaded_learned_embeds[token]
embeds = learned_embeds[token]

# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
Expand All @@ -583,6 +656,19 @@ def load_learned_embed_in_clip(
return token


def load_learned_embed_in_clip(
    learned_embeds_path,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    """
    Load learned embeddings from a pytorch file and apply them to the text
    encoder and tokenizer via apply_learned_embed_in_clip.

    Returns whatever apply_learned_embed_in_clip returns, so callers that
    relied on this function's return value before it was split keep working.
    """
    # Load onto the CPU regardless of the device the file was saved from, so
    # GPU-saved embeddings can be read on CPU-only machines.
    learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    return apply_learned_embed_in_clip(
        learned_embeds, text_encoder, tokenizer, token, idempotent
    )


def patch_pipe(
pipe,
unet_path,
Expand Down
Loading

0 comments on commit eaf725a

Please sign in to comment.