Pivotal Tuning with hackable training code for CLI #83

Merged · 10 commits · Dec 25, 2022
482 changes: 482 additions & 0 deletions lora_diffusion/cli_lora_pti.py

Large diffs are not rendered by default.
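The new training script is the bulk of this PR, but its diff is collapsed above. For orientation only, here is a minimal sketch of how a fire-based pivotal-tuning entry point could be wired up. The function and argument names below are placeholders, not the actual contents of cli_lora_pti.py; only the fire dependency and the lora_pti console script registered in setup.py are taken from this diff.

# Hypothetical sketch only -- the real lora_diffusion/cli_lora_pti.py (482 lines)
# is not rendered in this diff. `train` and its arguments are placeholders.
import fire


def train(
    instance_data_dir: str,
    pretrained_model_name_or_path: str,
    output_dir: str,
    placeholder_token: str = "<s>",
    train_text_encoder: bool = True,
    resolution: int = 512,
):
    # Pivotal tuning, roughly: (1) learn an embedding for placeholder_token
    # (textual inversion), (2) fine-tune LoRA adapters injected into the UNet
    # (and optionally the text encoder), (3) save both, e.g. via save_all().
    ...


def main():
    # exposed as the `lora_pti` console script registered in setup.py
    fire.Fire(train)


if __name__ == "__main__":
    main()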

59 changes: 41 additions & 18 deletions lora_diffusion/dataset.py
@@ -185,22 +185,29 @@ def __getitem__(self, index):


class PivotalTuningDatasetCapation(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and tokenizes the prompts.
"""

def __init__(
self,
instance_data_root,
learnable_property,
placeholder_token,
stochastic_attribute,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
h_flip=True,
center_crop=False,
color_jitter=False,
resize=True,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.resize = resize

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
@@ -210,7 +217,6 @@ def __init__(
self.num_instance_images = len(self.instance_images_path)

self.placeholder_token = placeholder_token
self.stochastic_attribute = stochastic_attribute.split(",")

self._length = self.num_instance_images

@@ -224,22 +230,38 @@ def __init__(
else:
self.class_data_root = None

self.image_transforms = transforms.Compose(
[
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.CenterCrop(size)
if center_crop
else transforms.RandomCrop(size),
transforms.ColorJitter(0.2, 0.1)
if color_jitter
else transforms.Lambda(lambda x: x),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
if resize:
self.image_transforms = transforms.Compose(
[
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.ColorJitter(0.2, 0.1)
if color_jitter
else transforms.Lambda(lambda x: x),
transforms.RandomHorizontalFlip()
if h_flip
else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
else:
self.image_transforms = transforms.Compose(
[
transforms.CenterCrop(size)
if center_crop
else transforms.Lambda(lambda x: x),
transforms.ColorJitter(0.2, 0.1)
if color_jitter
else transforms.Lambda(lambda x: x),
transforms.RandomHorizontalFlip()
if h_flip
else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

def __len__(self):
return self._length
@@ -255,6 +277,7 @@ def __getitem__(self, index):

text = self.instance_images_path[index % self.num_instance_images].stem

# print(text)
example["instance_prompt_ids"] = self.tokenizer(
text,
padding="do_not_pad",
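For context, a hedged usage sketch of the updated dataset class (not part of this diff): the tokenizer, paths, and token below are illustrative assumptions. Captions come from each image's filename stem, and the new h_flip and resize flags control the augmentation pipeline built above.

# Illustrative only: paths, model id, and the placeholder token are assumptions.
from transformers import CLIPTokenizer

from lora_diffusion.dataset import PivotalTuningDatasetCapation

tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)

train_dataset = PivotalTuningDatasetCapation(
    instance_data_root="./instance_images",  # caption = filename stem, e.g. "a photo of <s>.jpg"
    learnable_property="object",
    placeholder_token="<s>",
    stochastic_attribute="",  # required positionally by __init__
    tokenizer=tokenizer,
    size=512,
    h_flip=True,  # new flag in this PR
    resize=True,  # new flag: when False, the Resize() step is skipped
)

example = train_dataset[0]  # includes "instance_prompt_ids" built from the caption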
95 changes: 80 additions & 15 deletions lora_diffusion/lora.py
@@ -53,7 +53,7 @@ def _find_children(
yield parent, name, module


def _find_modules(
def _find_modules_v2(
model,
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
search_class: List[Type[nn.Module]] = [nn.Linear],
@@ -92,6 +92,26 @@ def _find_modules(
yield parent, name, module


def _find_modules_old(
model,
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
search_class: List[Type[nn.Module]] = [nn.Linear],
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
):
ret = []
for _module in model.modules():
if _module.__class__.__name__ in ancestor_class:

for name, _child_module in _module.named_modules():
if _child_module.__class__ in search_class:
ret.append((_module, name, _child_module))
print(ret)
return ret


_find_modules = _find_modules_v2


def inject_trainable_lora(
model: nn.Module,
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
@@ -124,6 +144,7 @@ def inject_trainable_lora(
_tmp.linear.bias = bias

# switch the module
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
_module._modules[name] = _tmp

require_grad_params.append(_module._modules[name].lora_up.parameters())
@@ -559,6 +580,8 @@ def patch_pipe(
patch_text=False,
patch_ti=False,
idempotent_token=True,
unet_target_replace_module=DEFAULT_TARGET_REPLACE,
text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
assert (
len(token) > 0
@@ -569,14 +592,19 @@

if patch_unet:
print("LoRA : Patching Unet")
monkeypatch_or_replace_lora(pipe.unet, torch.load(unet_path), r=r)
monkeypatch_or_replace_lora(
pipe.unet,
torch.load(unet_path),
r=r,
target_replace_module=unet_target_replace_module,
)

if patch_text:
print("LoRA : Patching text encoder")
monkeypatch_or_replace_lora(
pipe.text_encoder,
torch.load(text_path),
target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
target_replace_module=text_target_replace_module,
r=r,
)
if patch_ti:
@@ -591,19 +619,56 @@


@torch.no_grad()
def inspect_lora(model, target_replace_module=DEFAULT_TARGET_REPLACE):

    fnorm = {k: [] for k in target_replace_module}

    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for name, _child_module in _module.named_modules():
                if _child_module.__class__.__name__ == "LoraInjectedLinear":
                    ups = _module._modules[name].lora_up.weight
                    downs = _module._modules[name].lora_down.weight

                    wght: torch.Tensor = downs @ ups
                    fnorm[name].append(wght.flatten().pow(2).mean().item())

    for k, v in fnorm.items():
        print(f"F norm on Current LoRA of {k} : {v}")
def inspect_lora(model):
    moved = {}

    for name, _module in model.named_modules():
        if _module.__class__.__name__ == "LoraInjectedLinear":
            ups = _module.lora_up.weight.data.clone()
            downs = _module.lora_down.weight.data.clone()

            wght: torch.Tensor = ups @ downs

            dist = wght.flatten().abs().mean().item()
            if name in moved:
                moved[name].append(dist)
            else:
                moved[name] = [dist]

    return moved


def save_all(
    unet,
    text_encoder,
    placeholder_token_id,
    placeholder_token,
    save_path,
    save_lora=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
):

    # save ti
    ti_path = _ti_lora_path(save_path)
    learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    print("Current Learned Embeddings: ", learned_embeds[:4])

    learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, ti_path)
    print("Ti saved to ", ti_path)

    # save text encoder
    if save_lora:

        save_lora_weight(
            unet, save_path, target_replace_module=target_replace_module_unet
        )
        print("Unet saved to ", save_path)

        save_lora_weight(
            text_encoder,
            _text_lora_path(save_path),
            target_replace_module=target_replace_module_text,
        )
        print("Text Encoder saved to ", _text_lora_path(save_path))
Binary file removed lora_illust.pt
Binary file removed lora_kiriko.pt
Binary file removed lora_kiriko.text_encoder.pt
Binary file added lora_kiriko2.pt
Binary file added lora_kiriko2.text_encoder.pt
Binary file added lora_kiriko2.ti.pt
3 changes: 2 additions & 1 deletion requirements.txt
@@ -2,4 +2,5 @@ diffusers>=0.9.0
transformers
scipy
ftfy
fire
wandb
113 changes: 63 additions & 50 deletions scripts/run_lorpt.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion setup.py
@@ -6,13 +6,14 @@
setup(
name="lora_diffusion",
py_modules=["lora_diffusion"],
version="0.0.6",
version="0.0.7",
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
author="Simo Ryu",
packages=find_packages(),
entry_points={
"console_scripts": [
"lora_add = lora_diffusion.cli_lora_add:main",
"lora_pti = lora_diffusion.cli_lora_pti:main",
],
},
install_requires=[
3 changes: 2 additions & 1 deletion train_lora_pt_caption.py
@@ -76,6 +76,7 @@ def __init__(
class_data_root=None,
class_prompt=None,
size=512,
h_flip=True,
center_crop=False,
color_jitter=False,
resize=False,
@@ -508,7 +509,7 @@ def parse_args(input_args=None):
type=bool,
default=True,
required=False,
help="Should images be resized to --resolution before training?"
help="Should images be resized to --resolution before training?",
)

if input_args is not None: