Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes, bit readmes #70

Merged
merged 26 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7e78c8d
Add parameter to control rank of decomposition (#28)
brian6091 Dec 13, 2022
6aee5f3
Merge branch 'master' of https://github.com/cloneofsimo/lora into dev…
cloneofsimo Dec 14, 2022
9f31bd0
feat : statefully monkeypatch different loras + example ipynb + readme
cloneofsimo Dec 14, 2022
fececf3
Fix lora inject, added weight self apply lora (#39)
DavidePaglieri Dec 15, 2022
65438b5
Revert "Fix lora inject, added weight self apply lora (#39)" (#40)
cloneofsimo Dec 15, 2022
4975cfa
Merge branch 'master' of https://github.com/cloneofsimo/lora into dev…
cloneofsimo Dec 15, 2022
9ca7bc8
fix : rank bug in monkeypatch
cloneofsimo Dec 15, 2022
6a3ad97
fix cli fix
cloneofsimo Dec 15, 2022
40ad282
visualizatio on effect of LR
cloneofsimo Dec 15, 2022
a386525
Fix save_steps, max_train_steps, and logging (#45)
hdon96 Dec 16, 2022
6767142
Enable resuming (#52)
hdon96 Dec 16, 2022
24af4c8
feat : low-rank pivotal tuning
cloneofsimo Dec 16, 2022
046422c
feat : pivotal tuning
cloneofsimo Dec 16, 2022
0a92e62
Merge branch 'develop' of https://github.com/cloneofsimo/lora into de…
cloneofsimo Dec 16, 2022
4abbf90
v 0.0.6
cloneofsimo Dec 16, 2022
d0c4cc5
Merge branch 'master' into develop
cloneofsimo Dec 16, 2022
986626f
Learning rate switching & fix indent (#57)
hdon96 Dec 19, 2022
bbda1e5
Re:Fix indent (#58)
hdon96 Dec 19, 2022
46d9cf6
Merge branch 'master' into develop
cloneofsimo Dec 19, 2022
e1ea114
Merge branch 'develop' of https://github.com/cloneofsimo/lora into de…
cloneofsimo Dec 19, 2022
24617ea
ff now training default
cloneofsimo Dec 21, 2022
283f4bd
feat : dataset
cloneofsimo Dec 21, 2022
27145c3
feat : utils to back training
cloneofsimo Dec 21, 2022
7faef9f
readme : more contents. citations, etc.
cloneofsimo Dec 21, 2022
0e799a9
fix : weight init
cloneofsimo Dec 21, 2022
1abfc58
Merge branch 'master' of https://github.com/cloneofsimo/lora into dev…
cloneofsimo Dec 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
## Main Features

- Fine-tune Stable diffusion models twice as faster than dreambooth method, by Low-rank Adaptation
- Get insanely small end result (3MB for just unet, 6MB for both unet + clip), easy to share and download.
- Get insanely small end result (3MB for just unet, 4MB for both unet + clip + token embedding), easy to share and download.
- Easy to use, compatible with `diffusers`
- Sometimes _even better performance_ than full fine-tuning (but left as future work for extensive comparisons)
- Merge checkpoints + Build recipes by merging LoRAs together
- Fine-tune both CLIP & Unet to gain better results.
- Pipeline to fine-tune CLIP + Unet + token to gain better results.

# Web Demo

Expand All @@ -49,6 +49,12 @@

# UPDATES & Notes

### 2022/12/22

- Pivotal Tuning now available with

### 2022/12/10

- **You can now fine-tune text_encoder as well! Enabled with simple `--train_text_encoder`**
- **Converting to CKPT format for A1111's repo consumption!** (Thanks to [jachiam](https://github.com/jachiam)'s conversion script)
- Img2Img Examples added.
Expand All @@ -71,6 +77,21 @@ This is the key idea of LoRA. We can then fine-tune $A$ and $B$ instead of $W$.

Also, not all of the parameters need tuning: they found that often, $Q, K, V, O$ (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.

Now, how would we actually use this to update diffusion model? First, we will use Stable-diffusion from [stability-ai](https://stability.ai/). Their model is nicely ported through Huggingface API, so this repo has built various fine-tuning methods around them. In detail, there are three subtle but important distictions in methods to make this work out.

1. [Dreambooth](https://arxiv.org/abs/2208.12242)

First, there is LoRA applied to Dreambooth. The idea is to use prior-preservation class images to regularize the training process, and use low-occuring tokens. This will keep the model's generalization capability while keeping high fidelity. If you turn off prior preservation, and train text encoder embedding as well, it will become naive fine tuning.

2. [Textual Inversion](https://arxiv.org/abs/2208.01618)

Second, there is Textual inversion. There is no room to apply LoRA here, but it is worth mensioning. The idea is to instantiate new token, and learn the token embedding via gradient descent. This is a very powerful method, and it is worth trying out if your use case is not focused on fidelity but rather on inverting conceptual ideas.

3. [Pivotal Tuning](https://arxiv.org/abs/2106.05744)

Last method (although originally proposed for GANs) takes the best of both worlds to further benefit. Wken combined together, this can be implemented as a strict generalization of both methods.
Simply you apply textual inversion to get a matching token embedding. Then, you use the token embedding + prior-preserving class image to fine-tune the model. This two-fold nature make this strict generalization of both methods.

Enough of the lengthy introduction, let's get to the code.

# Installation
Expand Down Expand Up @@ -102,7 +123,7 @@ optimizer = optim.Adam(
)
```

An example of this can be found in `train_lora_dreambooth.py`. Run this example with
A working example of this, applied on [Dreambooth](https://arxiv.org/abs/2208.12242) can be found in `train_lora_dreambooth.py`. Run this example with

```bash
run_lora_db.sh
Expand Down Expand Up @@ -318,3 +339,47 @@ TODOS
- Adaptor-guidance
- Time-aware fine-tuning.
- Test alpha scheduling. I think it will be meaningful.

# References

This work was heavily influenced by, and originated by these awesome researches. I'm just applying them here.

```bibtex
@article{roich2022pivotal,
title={Pivotal tuning for latent-based editing of real images},
author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel},
journal={ACM Transactions on Graphics (TOG)},
volume={42},
number={1},
pages={1--13},
year={2022},
publisher={ACM New York, NY}
}
```

```bibtex
@article{ruiz2022dreambooth,
title={Dreambooth: Fine tuning text-to-image diffusion models for subject-driven generation},
author={Ruiz, Nataniel and Li, Yuanzhen and Jampani, Varun and Pritch, Yael and Rubinstein, Michael and Aberman, Kfir},
journal={arXiv preprint arXiv:2208.12242},
year={2022}
}
```

```bibtex
@article{gal2022image,
title={An image is worth one word: Personalizing text-to-image generation using textual inversion},
author={Gal, Rinon and Alaluf, Yuval and Atzmon, Yuval and Patashnik, Or and Bermano, Amit H and Chechik, Gal and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2208.01618},
year={2022}
}
```

```
@article{hu2021lora,
title={Lora: Low-rank adaptation of large language models},
author={Hu, Edward J and Shen, Yelong and Wallis, Phillip and Allen-Zhu, Zeyuan and Li, Yuanzhi and Wang, Shean and Wang, Lu and Chen, Weizhu},
journal={arXiv preprint arXiv:2106.09685},
year={2021}
}
```
Binary file added contents/1e-5-krk-pt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/2e-5-krk-pt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/2e-6-krk-pt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/5e-6-krk-pt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/pt-krk-caption-rank1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/pt-krk-caption-rank2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/pt-krk-caption-rank4-we1e-3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/pt-krk-caption-rank4.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added contents/pt-krk-caption-rank8.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .lora import *
from .dataset import *
279 changes: 279 additions & 0 deletions lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
from torch.utils.data import Dataset


from PIL import Image
from torchvision import transforms
from pathlib import Path

import random

imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]

imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]


def _randomset(lis):
ret = []
for i in range(len(lis)):
if random.random() < 0.5:
ret.append(lis[i])
return ret


def _shuffle(lis):

return random.sample(lis, len(lis))


class PivotalTuningDatasetTemplate(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
"""

def __init__(
self,
instance_data_root,
learnable_property,
placeholder_token,
stochastic_attribute,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
color_jitter=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")

self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)

self.placeholder_token = placeholder_token
self.stochastic_attribute = stochastic_attribute.split(",")

self.templates = (
imagenet_style_templates_small
if learnable_property == "style"
else imagenet_templates_small
)

self._length = self.num_instance_images

if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None

self.image_transforms = transforms.Compose(
[
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.CenterCrop(size)
if center_crop
else transforms.RandomCrop(size),
transforms.ColorJitter(0.2, 0.1)
if color_jitter
else transforms.Lambda(lambda x: x),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

def __len__(self):
return self._length

def __getitem__(self, index):
example = {}
instance_image = Image.open(
self.instance_images_path[index % self.num_instance_images]
)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)

text = random.choice(self.templates).format(
", ".join(
[self.placeholder_token]
+ _shuffle(_randomset(self.stochastic_attribute))
)
)

example["instance_prompt_ids"] = self.tokenizer(
text,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids

if self.class_data_root:
class_image = Image.open(
self.class_images_path[index % self.num_class_images]
)
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids

return example


class PivotalTuningDatasetCapation(Dataset):
def __init__(
self,
instance_data_root,
learnable_property,
placeholder_token,
stochastic_attribute,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
color_jitter=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")

self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)

self.placeholder_token = placeholder_token
self.stochastic_attribute = stochastic_attribute.split(",")

self._length = self.num_instance_images

if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None

self.image_transforms = transforms.Compose(
[
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.CenterCrop(size)
if center_crop
else transforms.RandomCrop(size),
transforms.ColorJitter(0.2, 0.1)
if color_jitter
else transforms.Lambda(lambda x: x),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

def __len__(self):
return self._length

def __getitem__(self, index):
example = {}
instance_image = Image.open(
self.instance_images_path[index % self.num_instance_images]
)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)

text = self.instance_images_path[index % self.num_instance_images].stem

example["instance_prompt_ids"] = self.tokenizer(
text,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids

if self.class_data_root:
class_image = Image.open(
self.class_images_path[index % self.num_class_images]
)
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids

return example
Loading