-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make safetensors properly optional, and support storing textual inversion embeddings (#101)
- Loading branch information
1 parent
4936d0f
commit eaf725a
Showing
5 changed files
with
352 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import os | ||
|
||
import fire | ||
import torch | ||
from lora_diffusion import ( | ||
DEFAULT_TARGET_REPLACE, | ||
TEXT_ENCODER_DEFAULT_TARGET_REPLACE, | ||
UNET_DEFAULT_TARGET_REPLACE, | ||
convert_loras_to_safeloras_with_embeds, | ||
safetensors_available, | ||
) | ||
|
||
# Maps a model name (inferred from the input filename) to the default set of
# module names whose weights the Lora replaces for that model.
_target_by_name = {
    "unet": UNET_DEFAULT_TARGET_REPLACE,
    "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
}
|
||
|
||
def convert(*paths, outpath, overwrite=False, **settings):
    """
    Converts one or more pytorch Lora and/or Textual Embedding pytorch files
    into a safetensor file.

    Pass all the input paths as arguments. Whether they are Textual Embedding
    or Lora models will be auto-detected.

    For Lora models, their name will be taken from the path, i.e.

        "lora_weight.pt" => unet
        "lora_weight.text_encoder.pt" => text_encoder

    You can also set target_modules and/or rank by providing an argument prefixed
    by the name.

    So a complete example might be something like:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
    ```

    Raises:
        ValueError: if ``outpath`` already exists and ``overwrite`` is False.
    """
    modelmap = {}
    embeds = {}

    # Refuse to clobber an existing output file unless explicitly allowed.
    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    for path in paths:
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only run this
        # on checkpoint files you trust.  map_location="cpu" lets checkpoints
        # that were saved on a GPU load on CPU-only machines.
        data = torch.load(path, map_location="cpu")

        if isinstance(data, dict):
            # Textual inversion embeddings are stored as a dict of
            # {token: tensor}; merge them all into one embeds mapping.
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            embeds.update(data)
        else:
            # Lora weights: infer the model name from the second-to-last
            # dotted filename part ("w.text_encoder.pt" -> "text_encoder"),
            # defaulting to "unet" when the filename has no such part.
            name_parts = os.path.split(path)[1].split(".")
            name = name_parts[-2] if len(name_parts) > 2 else "unet"

            model_settings = {
                "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
                "rank": 4,
            }

            # Apply per-model overrides passed as e.g. --unet.rank 8: strip the
            # "<name>." prefix and merge into the defaults.
            prefix = f"{name}."
            model_settings |= {
                k[len(prefix):]: v for k, v in settings.items() if k.startswith(prefix)
            }

            print(f"Loading Lora for {name} from {path} with settings {model_settings}")

            modelmap[name] = (
                path,
                model_settings["target_modules"],
                model_settings["rank"],
            )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)
|
||
|
||
def main():
    """Command-line entry point: hand :func:`convert` to python-fire for CLI parsing."""
    fire.Fire(convert)
|
||
|
||
# Allow running this module directly as a script (python -m ... also works).
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.