Webdataset improvements; CogView4 example with The Simpsons webdataset #305

Merged · 19 commits · Mar 8, 2025
README.md · 22 additions, 17 deletions

@@ -2,31 +2,23 @@

Finetrainers is a work-in-progress library to support (accessible) training of diffusion models. Our first priority is to support LoRA training for all popular video models in [Diffusers](https://github.com/huggingface/diffusers), and eventually other methods like controlnets, control-loras, distillation, etc.

`cogvideox-factory` was renamed to `finetrainers`. If you're looking to train CogVideoX or Mochi with the legacy training scripts, please refer to [this](./training/README.md) README instead. Everything in the `training/` directory will be eventually moved and supported under `finetrainers`.

<table align="center">
<tr>
<td align="center"><video src="https://github.com/user-attachments/assets/aad07161-87cb-4784-9e6b-16d06581e3e5">Your browser does not support the video tag.</video></td>
<td align="center"><video src="https://github.com/user-attachments/assets/c23d53e2-b422-4084-9156-3fce9fd01dad">Your browser does not support the video tag.</video></td>
</tr>
<tr>
<th align="center">CogVideoX LoRA training as the first iteration of this project</th>
<th align="center">Replication of PikaEffects</th>
</tr>
</table>

## News

- 🔥 **2025-03-03**: Wan T2V support added!
- 🔥 **2025-03-03**: We have shipped a complete refactor to support multi-backend distributed training, better precomputation handling for big datasets, a model specification format (externally usable for training custom models), FSDP, & more.
- 🔥 **2025-02-12**: We have shipped a set of tooling to curate small and high-quality video datasets for fine-tuning. See [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) documentation page for details!
- 🔥 **2025-02-12**: Check out [eisneim/ltx_lora_training_i2v_t2v](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)! It builds off of `finetrainers` to support image to video training for LTX-Video and STG guidance for inference.
- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions.
- 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to [@ArEnSc](https://github.com/ArEnSc) for taking up the initiative!
- 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
- 🔥 **2024-12-18**: Support for T2V LoRA finetuning of [LTX Video](https://huggingface.co/docs/diffusers/main/api/pipelines/ltx_video) added!

## Table of Contents

- [Quickstart](#quickstart)
- [News](#news)
- [Support Matrix](#support-matrix)
- [Featured Projects](#featured-projects)
- [Featured Projects](#featured-projects-)
- [Acknowledgements](#acknowledgements)

## Quickstart
@@ -40,7 +32,7 @@ git fetch --all --tags
git checkout tags/v0.0.1
```

Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1) for the release tag.
Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1) for the latest stable release.

#### Using the main branch

@@ -59,6 +51,19 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam
> [!IMPORTANT]
> It is recommended to use PyTorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues, and are not tested. For fully reproducible training, please use the same environment as mentioned in [environment.md](./docs/environment.md).
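A quick way to confirm the version in your environment (a generic check, not part of this diff):

```bash
python -c "import torch; print(torch.__version__)"  # expect 2.5.1 or newer
```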

## News

- 🔥 **2025-03-07**: CogView4 support added!
- 🔥 **2025-03-03**: Wan T2V support added!
- 🔥 **2025-03-03**: We have shipped a complete refactor to support multi-backend distributed training, better precomputation handling for big datasets, a model specification format (externally usable for training custom models), FSDP, & more.
- 🔥 **2025-02-12**: We have shipped a set of tooling to curate small and high-quality video datasets for fine-tuning. See [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) documentation page for details!
- 🔥 **2025-02-12**: Check out [eisneim/ltx_lora_training_i2v_t2v](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)! It builds off of `finetrainers` to support image to video training for LTX-Video and STG guidance for inference.
- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions.
- 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to [@ArEnSc](https://github.com/ArEnSc) for taking up the initiative!
- 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
- 🔥 **2024-12-18**: Support for T2V LoRA finetuning of [LTX Video](https://huggingface.co/docs/diffusers/main/api/pipelines/ltx_video) added!

## Support Matrix

> [!NOTE]
@@ -72,7 +77,7 @@ Please checkout [`docs/models`](./docs/models/) and [`examples/training`](./exam
| [HunyuanVideo](./docs/models/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
| [CogVideoX-5b](./docs/models/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |
| [Wan](./docs/models/wan.md) | Text-to-Video | TODO | TODO |
| [CogView4](./docs/models/cogview4.md) | Text-to-Video | TODO | TODO |
| [CogView4](./docs/models/cogview4.md) | Text-to-Image | TODO | TODO |

</div>

examples/training/sft/cogview4/the_simpsons/README.md · 5 additions, 0 deletions

@@ -0,0 +1,5 @@
# CogView4-6B The Simpsons dataset

This example is only an experiment to verify that webdataset loading and streaming from the HF Hub work as expected. Do not expect meaningful results.

The dataset used for testing is available at [`bigdata-pw/TheSimpsons`](https://huggingface.co/datasets/bigdata-pw/TheSimpsons).
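As a quick smoke test of the streaming path, something like the following sketch can be used. It relies on the generic `datasets` streaming API (shown for illustration only; it is not part of this PR and does not go through `finetrainers` internals):

```python
from datasets import load_dataset

# Stream the webdataset shards directly from the HF Hub,
# without downloading the full dataset first.
ds = load_dataset("bigdata-pw/TheSimpsons", streaming=True, split="train")

for i, sample in enumerate(ds):
    print(sorted(sample.keys()))  # inspect the available columns
    if i == 1:
        break
```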
examples/training/sft/cogview4/the_simpsons/train.sh · 161 additions, 0 deletions

@@ -0,0 +1,161 @@
#!/bin/bash

set -e -x

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL="INFO"

# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences!
# BACKEND="accelerate"
BACKEND="ptd"

# In this setting, I'm using all 8 GPUs on an 8-GPU node for training
NUM_GPUS=8
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

# Check the JSON files for the expected JSON format
TRAINING_DATASET_CONFIG="examples/training/sft/cogview4/the_simpsons/training.json"
VALIDATION_DATASET_FILE="examples/training/sft/cogview4/the_simpsons/validation.json"

# Depending on how many GPUs you have available, choose your degree of parallelism and technique!
DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1"
FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1"
FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1"
HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1"
HSDP_4_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 2 --cp_degree 1 --tp_degree 1"
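# Illustrative sanity check (not enforced by the script): the product of the parallel
# degrees should match NUM_GPUS. HSDP_4_2 uses dp_degree=4 replica groups, each sharded
# across dp_shards=2 ranks, i.e. 4 x 2 = 8 GPUs, consistent with NUM_GPUS=8 above.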

# Parallel arguments
parallel_cmd=(
$HSDP_4_2
)

# Model arguments
model_cmd=(
--model_name "cogview4"
--pretrained_model_name_or_path "THUDM/CogView4-6B"
)

# Dataset arguments
# Here, we know that the dataset size is about ~80 images. In `training.json`, the dataset
# is configured with three resolution buckets for multi-resolution training. Since the
# dataset is small, we can precompute all embeddings at once, instead of doing it
# on-the-fly, which would be slower (the ideal use case for not passing
# `--precomputation_once` is training on large datasets).
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 32
)

# Dataloader arguments
dataloader_cmd=(
--dataloader_num_workers 0
)

# Diffusion arguments
diffusion_cmd=(
--flow_weighting_scheme "logit_normal"
)

# Training arguments
# We target just the attention projection layers for LoRA training here.
# You can modify this as you please and target any layer (regex patterns are supported).
training_cmd=(
--training_type "lora"
--seed 42
--batch_size 1
--train_steps 5000
--rank 128
--lora_alpha 128
--target_modules "transformer_blocks.*(to_q|to_k|to_v|to_out.0)"
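# (The regex above matches the attention q/k/v and output projections, e.g. names
# like `transformer_blocks.0.attn1.to_q`; exact module paths depend on the model
# definition. Feed-forward layers are left untouched.)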
--gradient_accumulation_steps 1
--gradient_checkpointing
--checkpointing_steps 1000
--checkpointing_limit 2
# --resume_from_checkpoint 3000
--enable_slicing
--enable_tiling
)

# Optimizer arguments
optimizer_cmd=(
--optimizer "adamw"
--lr 1e-5
--lr_scheduler "constant_with_warmup"
--lr_warmup_steps 2000
--lr_num_cycles 1
--beta1 0.9
--beta2 0.99
--weight_decay 1e-4
--epsilon 1e-8
--max_grad_norm 1.0
)

# Validation arguments
validation_cmd=(
--validation_dataset_file "$VALIDATION_DATASET_FILE"
--validation_steps 500
)

# Miscellaneous arguments
miscellaneous_cmd=(
--tracker_name "finetrainers-cogview4"
--output_dir "/fsx/aryan/cogview4"
--init_timeout 600
--nccl_timeout 600
--report_to "wandb"
)

# Execute the training script
if [ "$BACKEND" == "accelerate" ]; then

ACCELERATE_CONFIG_FILE=""
if [ "$NUM_GPUS" == 1 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
elif [ "$NUM_GPUS" == 2 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml"
elif [ "$NUM_GPUS" == 4 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml"
elif [ "$NUM_GPUS" == 8 ]; then
ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml"
fi

accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \
"${parallel_cmd[@]}" \
"${model_cmd[@]}" \
"${dataset_cmd[@]}" \
"${dataloader_cmd[@]}" \
"${diffusion_cmd[@]}" \
"${training_cmd[@]}" \
"${optimizer_cmd[@]}" \
"${validation_cmd[@]}" \
"${miscellaneous_cmd[@]}"

elif [ "$BACKEND" == "ptd" ]; then

export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES

torchrun \
--standalone \
--nnodes=1 \
--nproc_per_node=$NUM_GPUS \
--rdzv_backend c10d \
--rdzv_endpoint="localhost:0" \
train.py \
"${parallel_cmd[@]}" \
"${model_cmd[@]}" \
"${dataset_cmd[@]}" \
"${dataloader_cmd[@]}" \
"${diffusion_cmd[@]}" \
"${training_cmd[@]}" \
"${optimizer_cmd[@]}" \
"${validation_cmd[@]}" \
"${miscellaneous_cmd[@]}"
fi

echo -ne "-------------------- Finished executing script --------------------\n\n"
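A typical invocation (assuming the repository root as the working directory, since the dataset config paths above are relative; adjust `--output_dir` to a writable location first):

```bash
bash examples/training/sft/cogview4/the_simpsons/train.sh
```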
examples/training/sft/cogview4/the_simpsons/training.json · 24 additions, 0 deletions

@@ -0,0 +1,24 @@
{
"datasets": [
{
"data_root": "bigdata-pw/TheSimpsons",
"dataset_type": "image",
"id_token": "SMPSN",
"image_resolution_buckets": [
[960, 528],
[720, 528],
[720, 480]
],
"reshape_mode": "bicubic",
"remove_common_llm_caption_prefixes": true,
"caption_options": {
"column_names": ["caption.txt", "detailed_caption.txt", "more_detailed_caption.txt"],
"weights": {
"caption.txt": 0.2,
"detailed_caption.txt": 0.6,
"more_detailed_caption.txt": 0.2
}
}
}
]
}
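The `caption_options` block selects one caption column per sample according to the given weights. Conceptually (an illustrative Python sketch, not `finetrainers` code), the behavior is a weighted draw:

```python
import random

# One caption column is drawn per sample, following the configured weights.
columns = ["caption.txt", "detailed_caption.txt", "more_detailed_caption.txt"]
weights = [0.2, 0.6, 0.2]

chosen = random.choices(columns, weights=weights, k=1)[0]
print(chosen)  # "detailed_caption.txt" about 60% of the time
```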