forked from open-mmlab/mmsegmentation
Commit
Added multitoken training for textual inversion. Issue 369 (open-mmlab#661)

* Added multitoken training for textual inversion
* Updated assertion
* Removed duplicate save code
* Fixed undefined bug
* Fixed save
* Added multitoken clip model +util helper
* Removed code splitting
* Removed class
* Fixed errors
* Fixed errors
* Added loading functionality
* Loading via dict instead
* Fixed bug of invalid index being loaded
* Fixed adding placeholder token only adding 1 token
* Fixed bug when initializing tokens
* Fixed bug when initializing tokens
* Removed flawed logic
* Fixed vector shuffle
* Fixed tokenizer's inconsistent `__call__` method
* Fixed tokenizer's inconsistent `__call__` method
* Handling list input
* Added exception for adding invalid tokens to token map
* Removed unnecessary files and started working on progressive tokens
* Set at minimum load one token
* Changed to global step
* Added method to load automatic1111 tokens
* Fixed bug in load
* Quality+style fixes
* Update quality/style fixes
* Cast embeddings to fp16 when loading
* Fixed quality
* Started moving things over
* Clearing diffs
* Clearing diffs
* Moved everything
* Requested changes
1 parent e09a7d0, commit 8552fd7
Showing 6 changed files with 1,866 additions and 0 deletions.
140 changes: 140 additions & 0 deletions
examples/research_projects/mulit_token_textual_inversion/README.md
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
## Multi Token Textual Inversion | ||
The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issues and PRs as well as @patrickvonplaten. | ||
|
||
We add multi-token support to textual inversion. I added: | ||
1. `num_vec_per_token` for the number of vectors used to reference that token | ||
2. `progressive_tokens` for progressively training the token from 1 token to 2 tokens, etc. | ||
3. `progressive_tokens_max_steps` for the max number of steps until we start full training | ||
4. `vector_shuffle` to shuffle vectors | ||
|
||
Feel free to add these options to your training! In practice, `num_vec_per_token` around 10 combined with `vector_shuffle` works great! | ||
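
To make the first and last options concrete, here is a minimal sketch of how one placeholder can stand in for several vectors. It assumes the `<token>_i` naming used by `multi_token_clip.py` in this commit; `expand_placeholder` is a hypothetical helper written only for illustration and is not part of the training script.

```python
import random


def expand_placeholder(prompt, placeholder, num_vec_per_token, vector_shuffle=False, rng=None):
    """Replace `placeholder` in `prompt` with its per-vector sub-tokens (illustrative only)."""
    if num_vec_per_token == 1:
        # A single vector keeps the placeholder unchanged.
        return prompt
    sub_tokens = [f"{placeholder}_{i}" for i in range(num_vec_per_token)]
    if vector_shuffle:
        # Shuffle the order of the sub-tokens for this prompt only.
        (rng or random).shuffle(sub_tokens)
    return prompt.replace(placeholder, " ".join(sub_tokens))


print(expand_placeholder("A <cat-toy> backpack", "<cat-toy>", 3))
# A <cat-toy>_0 <cat-toy>_1 <cat-toy>_2 backpack
```

With `vector_shuffle=True` the same sub-tokens appear in a random order on every call, so the prompt is never tied to one fixed ordering of the vectors.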
|
||
## Textual Inversion fine-tuning example | ||
|
||
[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples. | ||
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion. | ||
|
||
## Running on Colab | ||
|
||
Colab for training | ||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | ||
|
||
Colab for inference | ||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb) | ||
|
||
## Running locally with PyTorch | ||
### Installing the dependencies | ||
|
||
Before running the scripts, make sure to install the library's training dependencies: | ||
|
||
**Important** | ||
|
||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: | ||
```bash | ||
git clone https://github.com/huggingface/diffusers | ||
cd diffusers | ||
pip install . | ||
``` | ||
|
||
Then `cd` into the example folder and run: | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: | ||
|
||
```bash | ||
accelerate config | ||
``` | ||
|
||
|
||
### Cat toy example | ||
|
||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree. | ||
|
||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). | ||
|
||
Run the following command to authenticate your token | ||
|
||
```bash | ||
huggingface-cli login | ||
``` | ||
|
||
If you have already cloned the repo, then you won't need to go through these steps. | ||
|
||
<br> | ||
|
||
Now let's get our dataset. Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data. | ||
|
||
And launch the training using | ||
|
||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** | ||
|
||
```bash | ||
export MODEL_NAME="runwayml/stable-diffusion-v1-5" | ||
export DATA_DIR="path-to-dir-containing-images" | ||
|
||
accelerate launch textual_inversion.py \ | ||
--pretrained_model_name_or_path=$MODEL_NAME \ | ||
--train_data_dir=$DATA_DIR \ | ||
--learnable_property="object" \ | ||
--placeholder_token="<cat-toy>" --initializer_token="toy" \ | ||
--resolution=512 \ | ||
--train_batch_size=1 \ | ||
--gradient_accumulation_steps=4 \ | ||
--max_train_steps=3000 \ | ||
--learning_rate=5.0e-04 --scale_lr \ | ||
--lr_scheduler="constant" \ | ||
--lr_warmup_steps=0 \ | ||
--output_dir="textual_inversion_cat" | ||
``` | ||
|
||
A full training run takes ~1 hour on one V100 GPU. | ||
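
As a rough guide to the flags above: in the upstream diffusers example scripts, `--scale_lr` multiplies the base learning rate by the gradient accumulation steps, the per-device batch size, and the number of processes. The sketch below assumes this script follows the same convention (it is not shown in this commit view), and `scaled_learning_rate` is a hypothetical helper used only to work through the arithmetic.

```python
def scaled_learning_rate(base_lr, grad_accum_steps, train_batch_size, num_processes=1):
    """Learning rate after --scale_lr, assuming the upstream diffusers convention."""
    return base_lr * grad_accum_steps * train_batch_size * num_processes


# Values from the command above, run on a single GPU (one process).
lr = scaled_learning_rate(5.0e-04, grad_accum_steps=4, train_batch_size=1)
print(lr)  # 0.002
```

Under that assumption, raising `--train_batch_size` or adding GPUs also raises the effective learning rate, so `--learning_rate` may need to be lowered to compensate.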
|
||
### Inference | ||
|
||
Once you have trained a model using the above command, inference can be done with the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt. | ||
|
||
```python | ||
import torch | ||
from diffusers import StableDiffusionPipeline | ||
|
||
model_id = "path-to-your-trained-model" | ||
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") | ||
|
||
prompt = "A <cat-toy> backpack" | ||
|
||
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] | ||
|
||
image.save("cat-backpack.png") | ||
``` | ||
|
||
|
||
## Training with Flax/JAX | ||
|
||
For faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script. | ||
|
||
Before running the scripts, make sure to install the library's training dependencies: | ||
|
||
```bash | ||
pip install -U -r requirements_flax.txt | ||
``` | ||
|
||
```bash | ||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" | ||
export DATA_DIR="path-to-dir-containing-images" | ||
|
||
python textual_inversion_flax.py \ | ||
--pretrained_model_name_or_path=$MODEL_NAME \ | ||
--train_data_dir=$DATA_DIR \ | ||
--learnable_property="object" \ | ||
--placeholder_token="<cat-toy>" --initializer_token="toy" \ | ||
--resolution=512 \ | ||
--train_batch_size=1 \ | ||
--max_train_steps=3000 \ | ||
--learning_rate=5.0e-04 --scale_lr \ | ||
--output_dir="textual_inversion_cat" | ||
``` | ||
It should be at least 70% faster than the PyTorch script with the same configuration. | ||
|
||
### Training with xformers: | ||
You can enable memory-efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation. |
103 changes: 103 additions & 0 deletions
examples/research_projects/mulit_token_textual_inversion/multi_token_clip.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
""" | ||
The main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing | ||
a photo of <concept>_0 <concept>_1 ... and so on | ||
and instead just do | ||
a photo of <concept> | ||
which gets translated to the above. This needs to work for both inference and training. | ||
For inference, | ||
the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with | ||
its underlying vectors | ||
For training, | ||
we would want to abstract away some logic like | ||
1. Adding tokens | ||
2. Updating gradient mask | ||
3. Saving embeddings | ||
to our Util class here. | ||
so | ||
TODO: | ||
1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x | ||
2. have mechanism for adding tokens x | ||
3. have mech for saving embeddings x | ||
4. get mask to update x | ||
5. Loading tokens from embedding x | ||
6. Integrate to training x | ||
7. Test | ||
""" | ||
import copy | ||
import random | ||
|
||
from transformers import CLIPTokenizer | ||
|
||
|
||
class MultiTokenCLIPTokenizer(CLIPTokenizer): | ||
    def __init__(self, *args, **kwargs): | ||
        super().__init__(*args, **kwargs) | ||
        self.token_map = {} | ||
|
||
    def try_adding_tokens(self, placeholder_token, *args, **kwargs): | ||
        num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs) | ||
        if num_added_tokens == 0: | ||
            raise ValueError( | ||
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different" | ||
                " `placeholder_token` that is not already in the tokenizer." | ||
            ) | ||
|
||
    def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs): | ||
        # Reject a new placeholder that contains an already registered placeholder as a substring. | ||
        # Checked before any token is added so a rejected placeholder leaves the tokenizer unchanged. | ||
        for token in self.token_map: | ||
            if token in placeholder_token: | ||
                raise ValueError( | ||
                    f"The tokenizer already has placeholder token {token} that can get confused with" | ||
                    f" {placeholder_token}. Please keep placeholder tokens independent." | ||
                ) | ||
        output = [] | ||
        if num_vec_per_token == 1: | ||
            self.try_adding_tokens(placeholder_token, *args, **kwargs) | ||
            output.append(placeholder_token) | ||
        else: | ||
            for i in range(num_vec_per_token): | ||
                ith_token = placeholder_token + f"_{i}" | ||
                self.try_adding_tokens(ith_token, *args, **kwargs) | ||
                output.append(ith_token) | ||
        self.token_map[placeholder_token] = output | ||
|
||
    def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0): | ||
        """ | ||
        Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder | ||
        can encode them | ||
        vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119 | ||
        where shuffling tokens was found to force the model to learn the concepts more descriptively. | ||
        """ | ||
        if isinstance(text, list): | ||
            output = [] | ||
            for i in range(len(text)): | ||
                # Forward both options so list inputs behave the same as single strings. | ||
                output.append( | ||
                    self.replace_placeholder_tokens_in_text( | ||
                        text[i], vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load | ||
                    ) | ||
                ) | ||
            return output | ||
        for placeholder_token in self.token_map: | ||
            if placeholder_token in text: | ||
                tokens = self.token_map[placeholder_token] | ||
                # Always keep at least one token, even when prop_tokens_to_load is 0. | ||
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)] | ||
                if vector_shuffle: | ||
                    tokens = copy.copy(tokens) | ||
                    random.shuffle(tokens) | ||
                text = text.replace(placeholder_token, " ".join(tokens)) | ||
        return text | ||
|
||
    def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs): | ||
        return super().__call__( | ||
            self.replace_placeholder_tokens_in_text( | ||
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load | ||
            ), | ||
            *args, | ||
            **kwargs, | ||
        ) | ||
|
||
    def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs): | ||
        return super().encode( | ||
            self.replace_placeholder_tokens_in_text( | ||
                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load | ||
            ), | ||
            *args, | ||
            **kwargs, | ||
        ) |
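
The `prop_tokens_to_load` argument above is what progressive training can hook into. The sketch below shows one way the proportion could grow with the global step. The linear schedule in `prop_for_step` is an assumption made for illustration (the training script is not shown here), while `tokens_to_load` reuses the same slice as `replace_placeholder_tokens_in_text`. Both function names are hypothetical.

```python
def prop_for_step(global_step, progressive_tokens_max_steps):
    """Hypothetical linear schedule: fraction of sub-tokens to use at a given step."""
    return min(1.0, global_step / progressive_tokens_max_steps)


def tokens_to_load(tokens, prop_tokens_to_load):
    # Same slice as replace_placeholder_tokens_in_text: at least one token is always kept.
    return tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]


tokens = [f"<cat-toy>_{i}" for i in range(10)]
for step in (0, 500, 1000, 2000):
    prop = prop_for_step(step, progressive_tokens_max_steps=2000)
    print(step, len(tokens_to_load(tokens, prop)))
# 0 1
# 500 3
# 1000 6
# 2000 10
```

Because of the `1 +` in the slice, training starts with a single vector at step 0 and reaches all ten once the proportion hits 1.0. The slice simply stops at the end of the list, so nothing overflows.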
6 changes: 6 additions & 0 deletions
examples/research_projects/mulit_token_textual_inversion/requirements.txt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
accelerate | ||
torchvision | ||
transformers>=4.25.1 | ||
ftfy | ||
tensorboard | ||
Jinja2 |
8 changes: 8 additions & 0 deletions
examples/research_projects/mulit_token_textual_inversion/requirements_flax.txt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
transformers>=4.25.1 | ||
flax | ||
optax | ||
torch | ||
torchvision | ||
ftfy | ||
tensorboard | ||
Jinja2 |