support SD3 #1374 (Draft)

435 commits into base `dev` · changes shown from 46 commits

Commits
822fe57
add workaround for 'Some tensors share memory' error #1614
kohya-ss Sep 28, 2024
1a0f5b0
re-fix sample generation is not working in FLUX1 split mode #1647
kohya-ss Sep 28, 2024
d050638
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
e0c3630
Support Sdxl Controlnet (#1648)
sdbds Sep 29, 2024
56a63f0
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Sep 29, 2024
8919b31
use original ControlNet instead of Diffusers
kohya-ss Sep 29, 2024
0243c65
fix typo
kohya-ss Sep 29, 2024
8bea039
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
d78f6a7
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Sep 29, 2024
793999d
sample generation in SDXL ControlNet training
kohya-ss Sep 30, 2024
33e942e
Merge branch 'sd3' into fast_image_sizes
kohya-ss Sep 30, 2024
c2440f9
fix cond image normalization, add independent LR for control
kohya-ss Oct 3, 2024
3028027
Update train_network.py
gesen2egee Oct 4, 2024
dece2c3
Update train_db.py
gesen2egee Oct 4, 2024
ba08a89
call optimizer eval/train for sample_at_first, also set train after r…
kohya-ss Oct 4, 2024
83e3048
load Diffusers format, check schnell/dev
kohya-ss Oct 6, 2024
126159f
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Oct 7, 2024
886f753
support weighted captions for sdxl LoRA and fine tuning
kohya-ss Oct 9, 2024
3de42b6
fix: distributed training in windows
Akegarasu Oct 10, 2024
9f4dac5
torch 2.4
Akegarasu Oct 10, 2024
f2bc820
support weighted captions for SD/SDXL
kohya-ss Oct 10, 2024
035c4a8
update docs and help text
kohya-ss Oct 11, 2024
43bfeea
Merge pull request #1655 from kohya-ss/sdxl-ctrl-net
kohya-ss Oct 11, 2024
d005652
Merge pull request #1686 from Akegarasu/sd3
kohya-ss Oct 12, 2024
0d3058b
update README
kohya-ss Oct 12, 2024
ff4083b
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 12, 2024
c80c304
Refactor caching in train scripts
kohya-ss Oct 12, 2024
ecaea90
update README
kohya-ss Oct 12, 2024
e277b57
Update FLUX.1 support for compact models
kohya-ss Oct 12, 2024
5bb9f7f
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
74228c9
update cache_latents/text_encoder_outputs
kohya-ss Oct 13, 2024
c65cf38
Merge branch 'sd3' into fast_image_sizes
kohya-ss Oct 13, 2024
2244cf5
load images in parallel when caching latents
kohya-ss Oct 13, 2024
bfc3a65
fix to work cache latents/text encoder outputs
kohya-ss Oct 13, 2024
d02a6ef
Merge pull request #1660 from kohya-ss/fast_image_sizes
kohya-ss Oct 13, 2024
886ffb4
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
2d5f7fa
update README
kohya-ss Oct 13, 2024
1275e14
Merge pull request #1690 from kohya-ss/multi-gpu-caching
kohya-ss Oct 13, 2024
2500f5a
fix latents caching not working closes #1696
kohya-ss Oct 14, 2024
3cc5b8d
Diff Output Preservation loss for SDXL
kohya-ss Oct 18, 2024
d8d7142
fix to work caching latents #1696
kohya-ss Oct 18, 2024
ef70aa7
add FLUX.1 support
kohya-ss Oct 18, 2024
2c45d97
update README, remove unnecessary autocast
kohya-ss Oct 19, 2024
09b4d1e
Merge branch 'sd3' into diff_output_prsv
kohya-ss Oct 19, 2024
aa93242
Merge pull request #1710 from kohya-ss/diff_output_prsv
kohya-ss Oct 19, 2024
7fe8e16
fix to work ControlNetSubset with custom_attributes
kohya-ss Oct 19, 2024
138dac4
update README
kohya-ss Oct 20, 2024
623017f
refactor SD3 CLIP to transformers etc.
kohya-ss Oct 24, 2024
e3c43bd
reduce memory usage in sample image generation
kohya-ss Oct 24, 2024
0286114
support SD3.5L, fix final saving
kohya-ss Oct 24, 2024
f8c5146
support block swap with fused_optimizer_pass
kohya-ss Oct 24, 2024
5fba6f5
Merge branch 'dev' into sd3
kohya-ss Oct 25, 2024
f52fb66
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 25, 2024
d2c549d
support SD3 LoRA
kohya-ss Oct 25, 2024
0031d91
add latent scaling/shifting
kohya-ss Oct 25, 2024
56bf761
fix errors in SD3 LoRA training with Text Encoders close #1724
kohya-ss Oct 26, 2024
014064f
fix sample image generation without seed failed close #1726
kohya-ss Oct 26, 2024
8549669
Merge branch 'dev' into sd3
kohya-ss Oct 26, 2024
150579d
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 26, 2024
731664b
Merge branch 'dev' into sd3
kohya-ss Oct 27, 2024
b649bbf
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 27, 2024
db2b4d4
Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encod…
kohya-ss Oct 27, 2024
a1255d6
Fix SD3 LoRA training to work (WIP)
kohya-ss Oct 27, 2024
d4f7849
prevent unintended cast for disk cached TE outputs
kohya-ss Oct 27, 2024
1065dd1
Fix to work dropout_rate for TEs
kohya-ss Oct 27, 2024
af8e216
Fix sample image gen to work with block swap
kohya-ss Oct 28, 2024
7555486
Fix error on saving T5XXL
kohya-ss Oct 28, 2024
0af4edd
Fix split_qkv
kohya-ss Oct 29, 2024
d4e19fb
Support Lora
kohya-ss Oct 29, 2024
80bb3f4
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 29, 2024
1e2f7b0
Support for checkpoint files with a mysterious prefix "model.diffusio…
kohya-ss Oct 29, 2024
ce5b532
Fix additional LoRA to work
kohya-ss Oct 29, 2024
c9a1417
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 29, 2024
b502f58
Fix emb_dim to work.
kohya-ss Oct 29, 2024
bdddc20
support SD3.5M
kohya-ss Oct 30, 2024
8c3c825
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 30, 2024
70a179e
Fix to use SDPA instead of xformers
kohya-ss Oct 30, 2024
1434d85
Support SD3.5M multi resolutional training
kohya-ss Oct 31, 2024
9e23368
Update SD3 training
kohya-ss Oct 31, 2024
830df4a
Fix crashing if image is too tall or wide.
kohya-ss Oct 31, 2024
9aa6f52
Fix memory leak in latent caching. bmp failed to cache
kohya-ss Nov 1, 2024
82daa98
remove duplicate resolution for scaled pos embed
kohya-ss Nov 1, 2024
264328d
Merge pull request #1719 from kohya-ss/sd3_5_support
kohya-ss Nov 1, 2024
e0db596
update multi-res training in SD3.5M
kohya-ss Nov 2, 2024
5e32ee2
fix crashing in DDP training closes #1751
kohya-ss Nov 2, 2024
81c0c96
faster block swap
kohya-ss Nov 5, 2024
aab943c
remove unused weight swapping functions from utils.py
kohya-ss Nov 5, 2024
4384903
Fix to work without latent cache #1758
kohya-ss Nov 6, 2024
40ed54b
Simplify Timestep weighting
Dango233 Nov 7, 2024
e54462a
Fix SD3 trained lora loading and merging
Dango233 Nov 7, 2024
bafd10d
Fix typo
Dango233 Nov 7, 2024
588ea9e
Merge pull request #1768 from Dango233/dango/timesteps_fix
kohya-ss Nov 7, 2024
5e86323
Update README and clean-up the code for SD3 timesteps
kohya-ss Nov 7, 2024
f264f40
Update README.md
kohya-ss Nov 7, 2024
5eb6d20
Update README.md
Dango233 Nov 7, 2024
387b40e
Merge pull request #1769 from Dango233/patch-1
kohya-ss Nov 7, 2024
e877b30
Merge branch 'dev' into sd3
kohya-ss Nov 7, 2024
123474d
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
kohya-ss Nov 7, 2024
b8d3fec
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 7, 2024
186aa5b
fix illegal block being swapped #1764
kohya-ss Nov 7, 2024
b3248a8
fix: sort order when getting image size from cache file
feffy380 Nov 7, 2024
2a2042a
Merge pull request #1770 from feffy380/fix-size-from-cache
kohya-ss Nov 9, 2024
8fac3c3
update README
kohya-ss Nov 9, 2024
26bd454
init
sdbds Nov 11, 2024
02bd76e
Refactor block swapping to utilize custom offloading utilities
kohya-ss Nov 11, 2024
7feaae5
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
92482c7
Merge pull request #1774 from sdbds/avif_get_imagesize
kohya-ss Nov 11, 2024
3fe94b0
update comment
kohya-ss Nov 11, 2024
cde90b8
feat: implement block swapping for FLUX.1 LoRA (WIP)
kohya-ss Nov 11, 2024
17cf249
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
2cb7a6d
feat: add block swap for FLUX.1/SD3 LoRA training
kohya-ss Nov 12, 2024
2bb0f54
update grad hook creation to fix TE lr in sd3 fine tuning
kohya-ss Nov 14, 2024
5c5b544
refactor: remove unused prepare_split_model method from FluxNetworkTr…
kohya-ss Nov 14, 2024
fd2d879
docs: update README
kohya-ss Nov 14, 2024
0047bb1
Merge pull request #1779 from kohya-ss/faster-block-swap
kohya-ss Nov 14, 2024
ccfaa00
add flux controlnet base module
minux302 Nov 15, 2024
42f6edf
fix for adding controlnet
minux302 Nov 15, 2024
e358b11
fix dataloader
minux302 Nov 16, 2024
2a188f0
Fix DOP to work with block swap
kohya-ss Nov 17, 2024
b2660bb
train run
minux302 Nov 17, 2024
35778f0
fix sample_images type
minux302 Nov 17, 2024
4dd4cd6
work cn load and validation
minux302 Nov 18, 2024
31ca899
fix depth value
minux302 Nov 18, 2024
2a61fc0
docs: fix typo from block_to_swap to blocks_to_swap in README
kohya-ss Nov 20, 2024
0b5229a
save cn
minux302 Nov 21, 2024
420a180
Implement pseudo Huber loss for Flux and SD3
recris Nov 27, 2024
740ec1d
Fix issues found in review
recris Nov 28, 2024
9dff44d
fix device
minux302 Nov 29, 2024
575f583
add README
minux302 Nov 29, 2024
be5860f
add schnell option to load_cn
minux302 Nov 29, 2024
f40632b
rm abundant arg
minux302 Nov 29, 2024
928b939
Allow unknown schedule-free optimizers to continue to module loader
rockerBOO Nov 20, 2024
87f5224
Support d*lr for ProdigyPlus optimizer
rockerBOO Nov 20, 2024
6593cfb
Fix d * lr step log
rockerBOO Nov 21, 2024
c7cadbc
Add pytest testing
rockerBOO Nov 29, 2024
2dd063a
add torch torchvision accelerate versions
rockerBOO Nov 29, 2024
e59e276
Add dadaptation
rockerBOO Nov 29, 2024
dd3b846
Install pytorch first to pin version
rockerBOO Nov 29, 2024
89825d6
Run typos workflows once where appropriate
rockerBOO Nov 29, 2024
4f7f248
Bump typos action
rockerBOO Nov 29, 2024
9c885e5
fix: improve pos_embed handling for oversized images and update resol…
kohya-ss Nov 30, 2024
7b61e9e
Fix issues found in review (pt 2)
recris Nov 30, 2024
a5a27fe
Merge pull request #1808 from recris/huber-loss-flux
kohya-ss Dec 1, 2024
14f642f
fix: huber_schedule exponential not working on sd3_train.py
kohya-ss Dec 1, 2024
0fe6320
fix flux_train.py is not working
kohya-ss Dec 1, 2024
cc11989
fix: refactor huber-loss calculation in multiple training scripts
kohya-ss Dec 1, 2024
1476040
fix: update help text for huber loss parameters in train_util.py
kohya-ss Dec 1, 2024
bdf9a8c
Merge pull request #1815 from kohya-ss/flux-huber-loss
kohya-ss Dec 1, 2024
34e7f50
docs: update README for huber loss
kohya-ss Dec 1, 2024
14c9ba9
Merge pull request #1811 from rockerBOO/schedule-free-prodigy
kohya-ss Dec 1, 2024
1dc873d
update README and clean up code for schedulefree optimizer
kohya-ss Dec 1, 2024
e3fd6c5
Merge pull request #1812 from rockerBOO/tests
kohya-ss Dec 2, 2024
09a3740
Merge pull request #1813 from minux302/flux-controlnet
kohya-ss Dec 2, 2024
e369b9a
docs: update README with FLUX.1 ControlNet training details and impro…
kohya-ss Dec 2, 2024
5ab00f9
Update workflow tests with cleanup and documentation
rockerBOO Dec 2, 2024
63738ec
Add tests documentation
rockerBOO Dec 2, 2024
2610e96
Pytest
rockerBOO Dec 2, 2024
3e5d89c
Add more resources
rockerBOO Dec 2, 2024
8b36d90
feat: support block_to_swap for FLUX.1 ControlNet training
kohya-ss Dec 2, 2024
6bee18d
fix: resolve model corruption issue with pos_embed when using --enabl…
kohya-ss Dec 7, 2024
2be3366
Merge pull request #1817 from rockerBOO/workflow-tests-fixes
kohya-ss Dec 7, 2024
abff4b0
Unify controlnet parameters name and change scripts name. (#1821)
sdbds Dec 7, 2024
e425996
feat: unify ControlNet model name option and deprecate old training s…
kohya-ss Dec 7, 2024
3cb8cb2
Prevent git credentials from leaking into other actions
rockerBOO Dec 9, 2024
8e378cf
add RAdamScheduleFree support
nhamanasu Dec 11, 2024
d3305f9
Merge pull request #1828 from rockerBOO/workflow-security-audit
kohya-ss Dec 15, 2024
f2d38e6
Merge pull request #1830 from nhamanasu/sd3
kohya-ss Dec 15, 2024
e896539
update requirements.txt and README to include RAdamScheduleFree optim…
kohya-ss Dec 15, 2024
05bb918
Add Validation loss for LoRA training
hinablue Dec 27, 2024
62164e5
Change val loss calculate method
hinablue Dec 27, 2024
64bd531
Split val latents/batch and pick up val latents shape size which equa…
hinablue Dec 28, 2024
cb89e02
Change val latent loss compare
hinablue Dec 28, 2024
8743532
val
gesen2egee Mar 9, 2024
449c1c5
Adding modified train_util and config_util
rockerBOO Jan 2, 2025
7f6e124
Merge branch 'gesen2egee/val' into validation-loss-upstream
rockerBOO Jan 3, 2025
d23c732
Merge remote-tracking branch 'hina/feature/val-loss' into validation-…
rockerBOO Jan 3, 2025
7470173
Remove defunct code for train_controlnet.py
rockerBOO Jan 3, 2025
534059d
Typos and lingering is_train
rockerBOO Jan 3, 2025
c8c3569
Cleanup order, types, print to logger
rockerBOO Jan 3, 2025
fbfc275
Update text for train/reg with repeats
rockerBOO Jan 3, 2025
58bfa36
Add seed help clarifying info
rockerBOO Jan 3, 2025
6604b36
Remove duplicate assignment
rockerBOO Jan 3, 2025
0522070
Fix training, validation split, revert to using upstream implementation
rockerBOO Jan 3, 2025
695f389
Move get_huber_threshold_if_needed
rockerBOO Jan 3, 2025
1f9ba40
Add step break for validation epoch. Remove unused variable
rockerBOO Jan 3, 2025
1c0ae30
Add missing functions for training batch
rockerBOO Jan 3, 2025
bbf6bbd
Use self.get_noise_pred_and_target and drop fixed timesteps
rockerBOO Jan 6, 2025
f4840ef
Revert train_db.py
rockerBOO Jan 6, 2025
1c63e7c
Cleanup unused code and formatting
rockerBOO Jan 6, 2025
c64d1a2
Add validate_every_n_epochs, change name validate_every_n_steps
rockerBOO Jan 6, 2025
f885029
Fix validate epoch, cleanup imports
rockerBOO Jan 6, 2025
fcb2ff0
Clean up some validation help documentation
rockerBOO Jan 6, 2025
742bee9
Set validation steps in multiple lines for readability
rockerBOO Jan 6, 2025
1231f51
Remove unused train_util code, fix accelerate.log for wandb, add init…
rockerBOO Jan 8, 2025
556f3f1
Fix documentation, remove unused function, fix bucket reso for sd1.5,…
rockerBOO Jan 8, 2025
9fde0d7
Handle tuple return from generate_dataset_group_by_blueprint
rockerBOO Jan 8, 2025
1e61392
Revert bucket_reso_steps to correct 64
rockerBOO Jan 8, 2025
d6f158d
Fix incorrect destructuring for load_arbitrary_dataset
rockerBOO Jan 8, 2025
264167f
Apply is_training_dataset only to DreamBoothDataset. Add validation_s…
rockerBOO Jan 9, 2025
4c61adc
Add divergence to logs
rockerBOO Jan 12, 2025
2bbb40c
Fix regularization images with validation
rockerBOO Jan 12, 2025
0456858
Fix validate_every_n_steps always running first step
rockerBOO Jan 12, 2025
ee9265c
Fix validate_every_n_steps for gradient accumulation
rockerBOO Jan 12, 2025
25929dd
Remove Validating... print to fix output layout
rockerBOO Jan 12, 2025
b489082
Disable repeats for validation datasets
rockerBOO Jan 12, 2025
c04e5df
Fix loss recorder on 0. Fix validation for cached runs. Assert on val…
rockerBOO Jan 23, 2025
6acdbed
Merge branch 'dev' into sd3
kohya-ss Jan 26, 2025
23ce75c
Merge branch 'dev' into sd3
kohya-ss Jan 26, 2025
b833d47
Merge pull request #1864 from rockerBOO/validation-loss-upstream
kohya-ss Jan 26, 2025
58b82a5
Fix to work with validation dataset
kohya-ss Jan 26, 2025
e852961
README.md: Update recent updates section to include validation loss s…
kohya-ss Jan 26, 2025
f1ac81e
Merge pull request #1899 from kohya-ss/val-loss
kohya-ss Jan 26, 2025
59b3b94
README.md: Update limitation for validation loss support to include s…
kohya-ss Jan 26, 2025
532f5c5
formatting
kohya-ss Jan 27, 2025
86a2f3f
Fix gradient handling when Text Encoders are trained
kohya-ss Jan 27, 2025
b6a3093
call optimizer eval/train fn before/after validation
kohya-ss Jan 27, 2025
29f31d0
add network.train()/eval() for validation
kohya-ss Jan 27, 2025
0750859
validation: Implement timestep-based validation processing
kohya-ss Jan 27, 2025
0778dd9
fix Text Encoder only LoRA training
kohya-ss Jan 27, 2025
42c0a9e
Merge branch 'sd3' into val-loss-improvement
kohya-ss Jan 27, 2025
45ec02b
use same noise for every validation
kohya-ss Jan 27, 2025
de830b8
Move progress bar to account for sampling image first
rockerBOO Jan 29, 2025
c5b803c
rng state management: Implement functions to get and set RNG states f…
kohya-ss Feb 4, 2025
a24db1d
fix: validation timestep generation fails on SD/SDXL training
kohya-ss Feb 4, 2025
0911683
set python random state
kohya-ss Feb 9, 2025
344845b
fix: validation with block swap
kohya-ss Feb 9, 2025
1772038
fix: unpause training progress bar after validation
kohya-ss Feb 11, 2025
cd80752
fix: remove unused parameter 'accelerator' from encode_images_to_late…
kohya-ss Feb 11, 2025
76b7619
fix: simplify validation step condition in NetworkTrainer
kohya-ss Feb 11, 2025
ab88b43
Fix validation epoch divergence
rockerBOO Feb 14, 2025
ee295c7
Merge pull request #1935 from rockerBOO/validation-epoch-fix
kohya-ss Feb 15, 2025
63337d9
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 15, 2025
4671e23
Fix validation epoch loss to check epoch average
rockerBOO Feb 16, 2025
3c7496a
Fix sizes for validation split
rockerBOO Feb 17, 2025
f3a0109
Clear sizes for validation reg images to be consistent
rockerBOO Feb 17, 2025
6051fa8
Merge pull request #1940 from rockerBOO/split-size-fix
kohya-ss Feb 17, 2025
7c22e12
Merge pull request #1938 from rockerBOO/validation-epoch-loss-recorder
kohya-ss Feb 17, 2025
9436b41
Fix validation split and add test
rockerBOO Feb 17, 2025
894037f
Merge pull request #1943 from rockerBOO/validation-split-test
kohya-ss Feb 18, 2025
dc7d5fb
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 18, 2025
4a36996
modify log step calculation
kohya-ss Feb 18, 2025
efb2a12
fix wandb val logging
kohya-ss Feb 21, 2025
905f081
Merge branch 'dev' into sd3
kohya-ss Feb 24, 2025
67fde01
Merge branch 'dev' into sd3
kohya-ss Feb 24, 2025
6e90c0f
Merge pull request #1909 from rockerBOO/progress_bar
kohya-ss Feb 24, 2025
ae409e8
fix: FLUX/SD3 network training not working without caching latents cl…
kohya-ss Feb 26, 2025
1fcac98
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 26, 2025
4965189
Merge pull request #1903 from kohya-ss/val-loss-improvement
kohya-ss Feb 26, 2025
ec350c8
Merge branch 'dev' into sd3
kohya-ss Feb 26, 2025
3d79239
docs: update README to include recent improvements in validation loss…
kohya-ss Feb 26, 2025
129 changes: 129 additions & 0 deletions README.md
@@ -1,5 +1,134 @@
This repository contains training, generation and utility scripts for Stable Diffusion.

## FLUX.1 LoRA training (WIP)

This feature is experimental. The options and the training script may change in the future. Please let us know if you have any ideas for improving the training.

__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__

The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

Aug 16, 2024:

FLUX.1 schnell model-based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model.

Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell.

Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training.

Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified.

Aug 13, 2024:

__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (when omitted) is `all`.

This argument is available even if `--split_mode` is not specified.

__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments.

This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default.

Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before.

Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.

We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. A sample command is below; the settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2
```

The training can be done on 16GB VRAM GPUs with the Adafactor optimizer. Please use settings like the following:

```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
```

The training can be done on 12GB VRAM GPUs with the Adafactor optimizer and the `--split_mode` and `train_blocks=single` options. Please use settings like the following:

```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single"
```

LoRAs for Text Encoders are not tested yet.

We have added some new options (Aug 10, 2024): `--timestep_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows:

- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). A rough sketch of these modes is shown after this list.
- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when `--timestep_sampling` is `sigmoid`). The default is 1.0. Larger values will make the sampling more uniform.
- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3).
- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3).
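
The following is a minimal sketch of how these sampling modes could behave (illustration only, not the script's actual implementation; the helper names `sample_t` and `shift_t` are made up for this example):

```python
import torch

def sample_t(batch_size: int, mode: str = "sigmoid", sigmoid_scale: float = 1.0) -> torch.Tensor:
    # Returns timesteps t in [0, 1].
    if mode == "uniform":
        return torch.rand(batch_size)  # uniform random
    if mode == "sigmoid":
        # sigmoid of a random normal; sigmoid_scale spreads the samples across [0, 1]
        return torch.sigmoid(sigmoid_scale * torch.randn(batch_size))
    raise ValueError(f"unknown mode: {mode}")  # "sigma" follows the SD3 scheduler sigmas

def shift_t(t: torch.Tensor, discrete_flow_shift: float = 3.0) -> torch.Tensor:
    # SD3-style discrete flow shift; values > 1 push sampling toward noisier timesteps
    return discrete_flow_shift * t / (1.0 + (discrete_flow_shift - 1.0) * t)
```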

`--loss_type` may be useful for FLUX.1 training. The default is `l2`.

In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The LoRA multiplier should be adjusted accordingly.

Additional note (Aug 11): a quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seem to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work!

Other settings may work better, so please try different settings.

We are also not sure how many epochs are needed for convergence, or how the learning rate should be adjusted.

The trained LoRA model can be used with ComfyUI.

The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.

Aug 12: `--interactive` option is now working.

```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
```

## SD3 training

SD3 training is done with `sd3_train.py`.

__Jul 27, 2024__:
- Latents and text encoder outputs caching mechanism is refactored significantly.
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
- With this change, dataset initialization is significantly faster, especially for large datasets.

- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures.

- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.

---

`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.

`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are currently required.

`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.

t5xxl works with `fp16` now.

There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.

`text_encoder_batch_size` has been added experimentally to speed up caching.

```toml
learning_rate = 1e-6 # seems to depend on the batch size
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true
vae_batch_size = 1
text_encoder_batch_size = 4
cache_latents = true
cache_latents_to_disk = true
```
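
For reference, a launch command might look like the hypothetical sketch below, combined with the settings above (passed on the command line or via a config file). File names such as `sd3_medium.safetensors` and `dataset_1024_bs1.toml` are placeholders, and the exact options should be checked with `sd3_train.py --help`:

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train.py --pretrained_model_name_or_path sd3_medium.safetensors --clip_l clip_l.safetensors --clip_g clip_g.safetensors --t5xxl t5xxl_fp16.safetensors --save_model_as safetensors --sdpa --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name sd3-model
```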


---

[__Change History__](#change-history) is moved to the bottom of the page.

54 changes: 39 additions & 15 deletions fine_tune.py
@@ -10,7 +10,7 @@
from tqdm import tqdm

import torch
from library import deepspeed_utils
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
@@ -39,6 +39,7 @@
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd


def train(args):
@@ -52,7 +53,15 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する

tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

# データセットを準備する
if args.dataset_class is None:
@@ -81,10 +90,10 @@ def train(args):
]
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)

current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -165,8 +174,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)

train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)

vae.to("cpu")
clean_memory_on_device(accelerator.device)

@@ -192,6 +202,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
text_encoder.eval()

text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -214,7 +227,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()

# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -317,7 +334,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -342,19 +361,22 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
# TODO move to strategy_sd.py
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -409,7 +431,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
global_step += 1

train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

# 指定ステップごとにモデルを保存
@@ -472,7 +494,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
)

train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

is_main_process = accelerator.is_main_process
if is_main_process: