diff --git a/examples/research_projects/mulit_token_textual_inversion/README.md b/examples/research_projects/mulit_token_textual_inversion/README.md
new file mode 100644
index 0000000000..540e4a705f
--- /dev/null
+++ b/examples/research_projects/mulit_token_textual_inversion/README.md
@@ -0,0 +1,140 @@
+## Multi Token Textual Inversion
+The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issues and PRs as well as @patrickvonplaten.
+
+We add multi token support to textual inversion. I added
+1. `num_vec_per_token` for the number of vectors used to represent that token
+2. `progressive_tokens` for progressively training the token from 1 token to 2 tokens and so on
+3. `progressive_tokens_max_steps` for the max number of steps until we start full training
+4. `vector_shuffle` to shuffle vectors
+
+Feel free to add these options to your training! In practice, `num_vec_per_token` around 10 combined with `vector_shuffle` works great!
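+
+For example, the extra flags could be added to the training command shown later in this README like this (a sketch, not a separate recipe - all other arguments stay the same as in the cat toy example below):
+
+```bash
+accelerate launch textual_inversion.py \
+  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+  --train_data_dir="path-to-dir-containing-images" \
+  --placeholder_token="<cat-toy>" --initializer_token="toy" \
+  --num_vec_per_token=10 \
+  --vector_shuffle \
+  --progressive_tokens \
+  --progressive_tokens_max_steps=2000 \
+  --output_dir="textual_inversion_cat"
+```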
+
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running on Colab
+
+Colab for training
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+
+Colab for inference
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
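+
+Alternatively, if you prefer a default configuration without answering questions about your environment, a minimal option (assuming a recent `accelerate` release) is:
+
+```python
+from accelerate.utils import write_basic_config
+
+# Writes a default accelerate config file for single-machine training
+write_basic_config()
+```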
+
+
+### Cat toy example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+Now let's get our dataset. Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
+
+And launch the training using
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATA_DIR="path-to-dir-containing-images"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat"
+```
+
+A full training run takes ~1 hour on one V100 GPU.
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("cat-backpack.png")
+```
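+
+If you only kept the learned embeddings (the `learned_embeds.bin` file the training script saves), here is a minimal sketch for loading them into a fresh pipeline. It assumes you run it from this example folder so that `multi_token_clip.py` and `textual_inversion.py` are importable, and that the base model matches the one used for training:
+
+```python
+import torch
+from transformers import CLIPTextModel
+from diffusers import StableDiffusionPipeline
+
+from multi_token_clip import MultiTokenCLIPTokenizer
+from textual_inversion import load_multitoken_tokenizer
+
+model_id = "runwayml/stable-diffusion-v1-5"
+tokenizer = MultiTokenCLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
+
+# learned_embeds.bin maps each placeholder token to its trained embedding vectors
+learned_embeds_dict = torch.load("textual_inversion_cat/learned_embeds.bin", map_location="cpu")
+load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    model_id, tokenizer=tokenizer, text_encoder=text_encoder
+).to("cuda")
+image = pipe("A <cat-toy> backpack", num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack.png")
+```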
+
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --output_dir="textual_inversion_cat"
+```
+It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
diff --git a/examples/research_projects/mulit_token_textual_inversion/multi_token_clip.py b/examples/research_projects/mulit_token_textual_inversion/multi_token_clip.py
new file mode 100644
index 0000000000..4388771b84
--- /dev/null
+++ b/examples/research_projects/mulit_token_textual_inversion/multi_token_clip.py
@@ -0,0 +1,103 @@
+"""
+The main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing
+a photo of <concept>_0 <concept>_1 ... and so on
+and instead just do
+a photo of <concept>
+which gets translated to the above. This needs to work for both inference and training.
+For inference,
+the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with
+its underlying vectors
+For training,
+we would want to abstract away some logic like
+1. Adding tokens
+2. Updating gradient mask
+3. Saving embeddings
+to our Util class here.
+so
+TODO:
+1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x
+2. have mechanism for adding tokens x
+3. have mechanism for saving embeddings x
+4. get mask to update x
+5. Loading tokens from embedding x
+6. Integrate to training x
+7. Test
+"""
+import copy
+import random
+
+from transformers import CLIPTokenizer
+
+
+class MultiTokenCLIPTokenizer(CLIPTokenizer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.token_map = {}
+
+ def try_adding_tokens(self, placeholder_token, *args, **kwargs):
+ num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
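+    # Registers either `placeholder_token` itself or `placeholder_token_0` ... `placeholder_token_{num_vec_per_token - 1}`
+    # in the tokenizer and records the mapping in `token_map`, so prompts can keep using the single placeholder string.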
+ def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
+ output = []
+ if num_vec_per_token == 1:
+ self.try_adding_tokens(placeholder_token, *args, **kwargs)
+ output.append(placeholder_token)
+ else:
+ output = []
+ for i in range(num_vec_per_token):
+ ith_token = placeholder_token + f"_{i}"
+ self.try_adding_tokens(ith_token, *args, **kwargs)
+ output.append(ith_token)
+ # handle cases where there is a new placeholder token that contains the current placeholder token but is larger
+ for token in self.token_map:
+ if token in placeholder_token:
+ raise ValueError(
+ f"The tokenizer already has placeholder token {token} that can get confused with"
+ f" {placeholder_token}keep placeholder tokens independent"
+ )
+ self.token_map[placeholder_token] = output
+
+ def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
+ """
+ Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder
+ can encode them
+ vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119
+        where shuffling tokens was found to force the model to learn the concepts more descriptively.
+ """
+ if isinstance(text, list):
+ output = []
+ for i in range(len(text)):
+                output.append(
+                    self.replace_placeholder_tokens_in_text(
+                        text[i], vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+                    )
+                )
+ return output
+ for placeholder_token in self.token_map:
+ if placeholder_token in text:
+ tokens = self.token_map[placeholder_token]
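+                # When progressively training, only a leading fraction of the vectors (controlled by prop_tokens_to_load) is substituted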
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
+ if vector_shuffle:
+ tokens = copy.copy(tokens)
+ random.shuffle(tokens)
+ text = text.replace(placeholder_token, " ".join(tokens))
+ return text
+
+ def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
+ return super().__call__(
+ self.replace_placeholder_tokens_in_text(
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+ ),
+ *args,
+ **kwargs,
+ )
+
+ def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
+ return super().encode(
+ self.replace_placeholder_tokens_in_text(
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+ ),
+ *args,
+ **kwargs,
+ )
diff --git a/examples/research_projects/mulit_token_textual_inversion/requirements.txt b/examples/research_projects/mulit_token_textual_inversion/requirements.txt
new file mode 100644
index 0000000000..7d93f3d03b
--- /dev/null
+++ b/examples/research_projects/mulit_token_textual_inversion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/examples/research_projects/mulit_token_textual_inversion/requirements_flax.txt b/examples/research_projects/mulit_token_textual_inversion/requirements_flax.txt
new file mode 100644
index 0000000000..8f85ad523a
--- /dev/null
+++ b/examples/research_projects/mulit_token_textual_inversion/requirements_flax.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000..459d9e65c5
--- /dev/null
+++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
@@ -0,0 +1,941 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from multi_token_clip import MultiTokenCLIPTokenizer
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.14.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
+ """
+ Add tokens to the tokenizer and set the initial value of token embeddings
+ """
+ tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
+ text_encoder.resize_token_embeddings(len(tokenizer))
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ if initializer_token:
+ token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
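+        # Initialize each placeholder vector from the corresponding initializer token embedding
+        # (initializer tokens are spread across the vectors when num_vec_per_token > len(token_ids))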
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
+ else:
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
+ return placeholder_token
+
+
+def save_progress(tokenizer, text_encoder, accelerator, save_path):
+ for placeholder_token in tokenizer.token_map:
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_ids]
+ if len(placeholder_token_ids) == 1:
+ learned_embeds = learned_embeds[None]
+ learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict):
+ for placeholder_token in learned_embeds_dict:
+ placeholder_embeds = learned_embeds_dict[placeholder_token]
+ num_vec_per_token = placeholder_embeds.shape[0]
+ placeholder_embeds = placeholder_embeds.to(dtype=text_encoder.dtype)
+ add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=num_vec_per_token)
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = placeholder_embeds[i]
+
+
+def load_multitoken_tokenizer_from_automatic(tokenizer, text_encoder, automatic_dict, placeholder_token):
+ """
+ Automatic1111's tokens have format
+ {'string_to_token': {'*': 265}, 'string_to_param': {'*': tensor([[ 0.0833, 0.0030, 0.0057, ..., -0.0264, -0.0616, -0.0529],
+ [ 0.0058, -0.0190, -0.0584, ..., -0.0025, -0.0945, -0.0490],
+ [ 0.0916, 0.0025, 0.0365, ..., -0.0685, -0.0124, 0.0728],
+ [ 0.0812, -0.0199, -0.0100, ..., -0.0581, -0.0780, 0.0254]],
+ requires_grad=True)}, 'name': 'FloralMarble-400', 'step': 399, 'sd_checkpoint': '4bdfc29c', 'sd_checkpoint_name': 'SD2.1-768'}
+ """
+ learned_embeds_dict = {}
+ learned_embeds_dict[placeholder_token] = automatic_dict["string_to_param"]["*"]
+ load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)
+
+
+def get_mask(tokenizer, accelerator):
+ # Get the mask of the weights that won't change
+ mask = torch.ones(len(tokenizer)).to(accelerator.device, dtype=torch.bool)
+ for placeholder_token in tokenizer.token_map:
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ for i in range(len(placeholder_token_ids)):
+ mask = mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]).to(accelerator.device)
+ return mask
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--progressive_tokens_max_steps",
+ type=int,
+ default=2000,
+ help="The number of steps until all tokens will be used.",
+ )
+ parser.add_argument(
+ "--progressive_tokens",
+ action="store_true",
+ help="Progressively train the tokens. For example, first train for 1 token, then 2 tokens and so on.",
+ )
+ parser.add_argument("--vector_shuffle", action="store_true", help="Shuffling tokens durint training")
+ parser.add_argument(
+ "--num_vec_per_token",
+ type=int,
+ default=1,
+ help=(
+ "The number of vectors used to represent the placeholder token. The higher the number, the better the"
+ " result at the cost of editability. This can be fixed by prompt editing."
+ ),
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--only_save_embeds",
+ action="store_true",
+ default=False,
+ help="Save only the embeddings for the new concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run validation every X epochs. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ vector_shuffle=False,
+ progressive_tokens=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+ self.vector_shuffle = vector_shuffle
+ self.progressive_tokens = progressive_tokens
+ self.prop_tokens_to_load = 0
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer.encode(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ vector_shuffle=self.vector_shuffle,
+ prop_tokens_to_load=self.prop_tokens_to_load if self.progressive_tokens else 1.0,
+ )[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ logging_dir=logging_dir,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.push_to_hub:
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load tokenizer
+ if args.tokenizer_name:
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ if is_xformers_available():
+ try:
+ unet.enable_xformers_memory_efficient_attention()
+ except Exception as e:
+ logger.warning(
+ "Could not enable memory efficient attention. Make sure xformers is installed"
+ f" correctly and a GPU is available: {e}"
+ )
+ add_tokens(tokenizer, text_encoder, args.placeholder_token, args.num_vec_per_token, args.initializer_token)
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder.text_model.encoder.requires_grad_(False)
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
+ # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
+ unet.train()
+ text_encoder.gradient_checkpointing_enable()
+ unet.enable_gradient_checkpointing()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the unet and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+ if args.progressive_tokens:
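+                # Linearly increase the fraction of placeholder vectors loaded until progressive_tokens_max_steps is reached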
+ train_dataset.prop_tokens_to_load = float(global_step) / args.progressive_tokens_max_steps
+
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = get_mask(tokenizer, accelerator)
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(tokenizer, text_encoder, accelerator, save_path)
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=unet,
+ vae=vae,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = (
+ None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ )
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.push_to_hub and args.only_save_embeds:
+ logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = not args.only_save_embeds
+ if save_full_model:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(tokenizer, text_encoder, accelerator, save_path)
+
+ if args.push_to_hub:
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py
new file mode 100644
index 0000000000..c23fa4f5d3
--- /dev/null
+++ b/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py
@@ -0,0 +1,668 @@
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import PIL
+import torch
+import torch.utils.checkpoint
+import transformers
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.14.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument(
+ "--use_auth_token",
+ action="store_true",
+ help=(
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ " private models)."
+ ),
+ )
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
+ if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
+ return
+ model.config.vocab_size = new_num_tokens
+
+ params = model.params
+ old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+ old_num_tokens, emb_dim = old_embeddings.shape
+
+ initializer = jax.nn.initializers.normal()
+
+ new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
+ new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
+ new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
+ params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings
+
+ model.params = params
+ return model
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if jax.process_index() == 0:
+ if args.push_to_hub:
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "a+") as gitignore:
+ gitignore.seek(0)
+ existing = gitignore.read()
+ if "step_*" not in existing:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in existing:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Set up logging; we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Add the placeholder token to the tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ # Create sampling rng
+ rng = jax.random.PRNGKey(args.seed)
+ rng, _ = jax.random.split(rng)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder = resize_token_embeddings(
+ text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
+ )
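+ # Keep a copy of the original embedding table; after every optimizer step, every row
+ # except the placeholder token's is restored from it (see train_step below).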
+ original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ batch = {"pixel_values": pixel_values, "input_ids": input_ids}
+ batch = {k: v.numpy() for k, v in batch.items()}
+
+ return batch
+
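+ # Each loader batch is later sharded across local devices, so the DataLoader batch size
+ # is the per-device batch size multiplied by the number of local devices.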
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ optimizer = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
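+ # Only the token embedding table should be trained: label each parameter leaf as either
+ # "token_embedding" or "zero", then let optax.multi_transform apply AdamW to the former
+ # and a gradient-zeroing transform to everything else.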
+ def create_mask(params, label_fn):
+ def _map(params, mask, label_fn):
+ for k in params:
+ if label_fn(k):
+ mask[k] = "token_embedding"
+ else:
+ if isinstance(params[k], dict):
+ mask[k] = {}
+ _map(params[k], mask[k], label_fn)
+ else:
+ mask[k] = "zero"
+
+ mask = {}
+ _map(params, mask, label_fn)
+ return mask
+
+ def zero_grads():
+ # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
+ def init_fn(_):
+ return ()
+
+ def update_fn(updates, state, params=None):
+ return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+ # Zero out gradients of layers other than the token embedding layer
+ tx = optax.multi_transform(
+ {"token_embedding": optimizer, "zero": zero_grads()},
+ create_mask(text_encoder.params, lambda s: s == "token_embedding"),
+ )
+
+ state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ # Define gradient train step fn
+ def train_step(state, vae_params, unet_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ def compute_loss(params):
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
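+ # Scale latents by the VAE scaling factor so their magnitude matches what the UNet was trained on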
+ latents = latents * vae.config.scaling_factor
+
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
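+ # Forward diffusion: corrupt the latents with noise according to the sampled timesteps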
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+ encoder_hidden_states = state.apply_fn(
+ batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
+ )[0]
+ # Predict the noise residual and compute loss
+ model_pred = unet.apply(
+ {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
+ grad = jax.lax.pmean(grad, "batch")
+ new_state = state.apply_gradients(grads=grad)
+
+ # Keep the token embeddings fixed except the newly added embeddings for the concept,
+ # as we only want to optimize the concept embeddings
+ token_embeds = original_token_embeds.at[placeholder_token_id].set(
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
+ )
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train and eval step
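+ # donate_argnums=(0,) lets XLA reuse the buffers of the previous train state on each step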
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ vae_params = jax_utils.replicate(vae_params)
+ unet_params = jax_utils.replicate(unet_params)
+
+ # Train!
+ num_update_steps_per_epoch = len(train_dataloader)
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
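+ # Split each array along the batch dimension so every local device receives one shard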
+ batch = shard(batch)
+ state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+ global_step += 1
+
+ if global_step >= args.max_train_steps:
+ break
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+ # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0:
+ scheduler = FlaxPNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ )
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ pipeline.save_pretrained(
+ args.output_dir,
+ params={
+ "text_encoder": get_params_to_save(state.params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(unet_params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ # Also save the newly trained embeddings
+ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+ placeholder_token_id
+ ]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds}
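+ # Saving a dict with jnp.save stores it as a pickled object array; load it back with
+ # np.load(path, allow_pickle=True).item()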
+ jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)
+
+ if args.push_to_hub:
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+
+
+if __name__ == "__main__":
+ main()