Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make safetensors properly optional, and support storing TI #101

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions lora_diffusion/cli_pt_to_safetensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os

import fire
import torch
from lora_diffusion import (
DEFAULT_TARGET_REPLACE,
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
UNET_DEFAULT_TARGET_REPLACE,
convert_loras_to_safeloras_with_embeds,
safetensors_available,
)

_target_by_name = {
"unet": UNET_DEFAULT_TARGET_REPLACE,
"text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
}


def convert(*paths, outpath, overwrite=False, **settings):
"""
Converts one or more pytorch Lora and/or Textual Embedding pytorch files
into a safetensor file.
Pass all the input paths as arguments. Whether they are Textual Embedding
or Lora models will be auto-detected.
For Lora models, their name will be taken from the path, i.e.
"lora_weight.pt" => unet
"lora_weight.text_encoder.pt" => text_encoder
You can also set target_modules and/or rank by providing an argument prefixed
by the name.
So a complete example might be something like:
```
python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
```
"""
modelmap = {}
embeds = {}

if os.path.exists(outpath) and not overwrite:
raise ValueError(
f"Output path {outpath} already exists, and overwrite is not True"
)

for path in paths:
data = torch.load(path)

if isinstance(data, dict):
print(f"Loading textual inversion embeds {data.keys()} from {path}")
embeds.update(data)

else:
name_parts = os.path.split(path)[1].split(".")
name = name_parts[-2] if len(name_parts) > 2 else "unet"

model_settings = {
"target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
"rank": 4,
}

prefix = f"{name}."
model_settings |= {
k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix)
}

print(f"Loading Lora for {name} from {path} with settings {model_settings}")

modelmap[name] = (
path,
model_settings["target_modules"],
model_settings["rank"],
)

convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)


def main():
fire.Fire(convert)


if __name__ == "__main__":
main()
146 changes: 116 additions & 30 deletions lora_diffusion/lora.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import math
from itertools import groupby
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
Expand All @@ -8,6 +9,25 @@
import torch.nn as nn
import torch.nn.functional as F

try:
from safetensors.torch import safe_open
from safetensors.torch import save_file as safe_save

safetensors_available = True
except ImportError:
from .safe_open import safe_open

def safe_save(
tensors: Dict[str, torch.Tensor],
filename: str,
metadata: Optional[Dict[str, str]] = None,
) -> None:
raise EnvironmentError(
"Saving safetensors requires the safetensors library. Please install with pip or similar."
)

safetensors_available = False


class LoraInjectedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False, r=4):
Expand Down Expand Up @@ -35,6 +55,8 @@ def forward(self, input):

DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE

EMBED_FLAG = "<embed>"


def _find_children(
model,
Expand Down Expand Up @@ -203,8 +225,9 @@ def save_lora_as_json(model, path="./lora.json"):
json.dump(weights, f)


def save_safeloras(
def save_safeloras_with_embeds(
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
embeds: Dict[str, torch.Tensor] = {},
outpath="./lora.safetensors",
):
"""
Expand All @@ -217,10 +240,6 @@ def save_safeloras(
weights = {}
metadata = {}

import json

from safetensors.torch import save_file

for name, (model, target_replace_module) in modelmap.items():
metadata[name] = json.dumps(list(target_replace_module))

Expand All @@ -231,12 +250,24 @@ def save_safeloras(
weights[f"{name}:{i}:up"] = _up.weight
weights[f"{name}:{i}:down"] = _down.weight

print(f"Saving weights to {outpath} with metadata", metadata)
save_file(weights, outpath, metadata)
for token, tensor in embeds.items():
metadata[token] = EMBED_FLAG
weights[token] = tensor

print(f"Saving weights to {outpath}")
safe_save(weights, outpath, metadata)

def convert_loras_to_safeloras(

def save_safeloras(
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
outpath="./lora.safetensors",
):
return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)


def convert_loras_to_safeloras_with_embeds(
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
embeds: Dict[str, torch.Tensor] = {},
outpath="./lora.safetensors",
):
"""
Expand All @@ -250,10 +281,6 @@ def convert_loras_to_safeloras(
weights = {}
metadata = {}

import json

from safetensors.torch import save_file

for name, (path, target_replace_module, r) in modelmap.items():
metadata[name] = json.dumps(list(target_replace_module))

Expand All @@ -268,8 +295,19 @@ def convert_loras_to_safeloras(
else:
weights[f"{name}:{i}:down"] = weight

print(f"Saving weights to {outpath} with metadata", metadata)
save_file(weights, outpath, metadata)
for token, tensor in embeds.items():
metadata[token] = EMBED_FLAG
weights[token] = tensor

print(f"Saving weights to {outpath}")
safe_save(weights, outpath, metadata)


def convert_loras_to_safeloras(
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
outpath="./lora.safetensors",
):
convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)


def parse_safeloras(
Expand All @@ -288,9 +326,6 @@ def parse_safeloras(
}
"""
loras = {}

import json

metadata = safeloras.metadata()

get_name = lambda k: k.split(":")[0]
Expand All @@ -299,12 +334,24 @@ def parse_safeloras(
keys.sort(key=get_name)

for name, module_keys in groupby(keys, get_name):
info = metadata.get(name)

if not info:
raise ValueError(
f"Tensor {name} has no metadata - is this a Lora safetensor?"
)

# Skip Textual Inversion embeds
if info == EMBED_FLAG:
continue

# Handle Loras
# Extract the targets
target = json.loads(metadata[name])
target = json.loads(info)

# Build the result lists - Python needs us to preallocate lists to insert into them
module_keys = list(module_keys)
ranks = [None] * (len(module_keys) // 2)
ranks = [4] * (len(module_keys) // 2)
weights = [None] * len(module_keys)

for key in module_keys:
Expand All @@ -313,7 +360,7 @@ def parse_safeloras(
idx = int(idx)

# Add the rank
ranks[idx] = json.loads(metadata[f"{name}:{idx}:rank"])
ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

# Insert the weight into the list
idx = idx * 2 + (1 if direction == "down" else 0)
Expand All @@ -324,14 +371,42 @@ def parse_safeloras(
return loras


def load_safeloras(path, device="cpu"):
def parse_safeloras_embeds(
safeloras,
) -> Dict[str, torch.Tensor]:
"""
Converts a loaded safetensor file that contains Textual Inversion embeds into
a dictionary of embed_token: Tensor
"""
embeds = {}
metadata = safeloras.metadata()

from safetensors.torch import safe_open
for key in safeloras.keys():
# Only handle Textual Inversion embeds
meta = metadata.get(key)
if not meta or meta != EMBED_FLAG:
continue

embeds[key] = safeloras.get_tensor(key)

return embeds


def load_safeloras(path, device="cpu"):
safeloras = safe_open(path, framework="pt", device=device)
return parse_safeloras(safeloras)


def load_safeloras_embeds(path, device="cpu"):
safeloras = safe_open(path, framework="pt", device=device)
return parse_safeloras_embeds(safeloras)


def load_safeloras_both(path, device="cpu"):
safeloras = safe_open(path, framework="pt", device=device)
return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)


def weight_apply_lora(
model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0
):
Expand Down Expand Up @@ -535,28 +610,26 @@ def _ti_lora_path(path: str) -> str:
return ".".join(path.split(".")[:-1] + ["ti", "pt"])


def load_learned_embed_in_clip(
learned_embeds_path,
def apply_learned_embed_in_clip(
learned_embeds,
text_encoder,
tokenizer,
token: Union[str, List[str]] = None,
token: Optional[Union[str, List[str]]] = None,
idempotent=False,
):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

if isinstance(token, str):
trained_tokens = [token]
elif isinstance(token, list):
assert len(loaded_learned_embeds.keys()) == len(
assert len(learned_embeds.keys()) == len(
token
), "The number of tokens and the number of embeds should be the same"
trained_tokens = token
else:
trained_tokens = list(loaded_learned_embeds.keys())
trained_tokens = list(learned_embeds.keys())

for token in trained_tokens:
print(token)
embeds = loaded_learned_embeds[token]
embeds = learned_embeds[token]

# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
Expand All @@ -583,6 +656,19 @@ def load_learned_embed_in_clip(
return token


def load_learned_embed_in_clip(
learned_embeds_path,
text_encoder,
tokenizer,
token: Optional[Union[str, List[str]]] = None,
idempotent=False,
):
learned_embeds = torch.load(learned_embeds_path)
apply_learned_embed_in_clip(
learned_embeds, text_encoder, tokenizer, token, idempotent
)


def patch_pipe(
pipe,
unet_path,
Expand Down
Loading