Commit: Dev (#10)
* Final implementation

* Skip the final 1 step

* fix alpha mask handling without disk cache; closes kohya-ss#1351, ref kohya-ss#1339

* update for corner cases

* Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and update _typos.toml (closes kohya-ss#1307)

* set static graph flag when using DDP; ref kohya-ss#1363

* make forward/backward paths the same; ref kohya-ss#1363

* update README

* add grad_hook after restoring state; closes kohya-ss#1344

* fix cache_latents/text_encoder_outputs to work

* show file name on error in load_image; ref kohya-ss#1385

---------

Co-authored-by: Kohya S <ykumeykume@gmail.com>
Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Co-authored-by: Yuta Hayashibe <yuta@hayashibe.jp>
4 people authored Jul 31, 2024
1 parent 7ffc83a commit 645c974
Showing 11 changed files with 196 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/typos.yml
@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4

- name: typos-action
uses: crate-ci/typos@v1.19.0
uses: crate-ci/typos@v1.21.0
18 changes: 18 additions & 0 deletions README.md
@@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!

- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
- This resolves the issue where the order of data loading from DataSet changes when resuming training.
- Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
- If `--resume` is specified, the step saved in the state is used.
- Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
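
The fast-forward behaviour described in the list above can be pictured with accelerate's batch-skipping helper. This is a conceptual sketch only; the helper name and the use of `Accelerator.skip_first_batches` are illustrative assumptions, not the scripts' actual implementation.

```python
# Conceptual sketch: fast-forward a prepared dataloader to a given optimizer step
# when resuming. The real train_network.py may implement the skip differently.
from accelerate import Accelerator

def fast_forward_dataloader(accelerator: Accelerator, train_dataloader,
                            initial_step: int, gradient_accumulation_steps: int = 1):
    # One optimizer step consumes `gradient_accumulation_steps` batches, so skip
    # that many batches for every step that has already been trained.
    batches_to_skip = initial_step * gradient_accumulation_steps
    return accelerator.skip_first_batches(train_dataloader, batches_to_skip)
```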

- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
- It seems that the model file loading is faster in the WSL environment etc.
- Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
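
For the new `--disable_mmap_load_safetensors` flag, the underlying difference can be sketched with the safetensors API: `load_file` memory-maps the file, while reading the bytes first and parsing them with `load` avoids mmap. The helper below is an illustrative sketch, not the exact code added in PR #1266.

```python
# Illustrative sketch of mmap vs. non-mmap loading with safetensors.
from safetensors.torch import load, load_file

def load_safetensors_state_dict(path: str, disable_mmap: bool = False):
    if disable_mmap:
        # Read the whole file into memory and parse it there (no memory mapping);
        # this is reportedly faster on WSL and similar environments.
        with open(path, "rb") as f:
            return load(f.read())
    # Default path: memory-mapped load.
    return load_file(path)
```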
@@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!

- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。

- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
- これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
- `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます)
- `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
- `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。

- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
- WSL 環境等でモデルファイルの読み込みが高速化されるようです。
- `sdxl_train.py``sdxl_train_network.py``sdxl_train_textual_inversion.py``sdxl_train_control_net_lllite.py` で使用可能です。
@@ -253,6 +265,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します

- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。

### Jun 23, 2024 / 2024-06-23:

- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)

- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)

### Apr 7, 2024 / 2024-04-07: v0.8.7

- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
2 changes: 2 additions & 0 deletions _typos.toml
@@ -2,6 +2,7 @@
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started

[default.extend-identifiers]
ddPn08="ddPn08"

[default.extend-words]
NIN="NIN"
@@ -27,6 +28,7 @@ rik="rik"
koo="koo"
yos="yos"
wn="wn"
hime="hime"


[files]
2 changes: 1 addition & 1 deletion library/ipex/attention.py
@@ -5,7 +5,7 @@

# pylint: disable=protected-access, missing-function-docstring, line-too-long

# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers

sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
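
The corrected comment refers to the module's attention slicing. The sketch below shows the general idea only, i.e. chunking scaled dot-product attention when the estimated intermediate would exceed a size budget; the thresholds and the exact slicing heuristics of `library/ipex/attention.py` are not reproduced here.

```python
# Simplified illustration of attention slicing (not the module's exact logic):
# run SDPA in batch chunks when the estimated attention matrix would be too large.
import torch
import torch.nn.functional as F

def sliced_sdpa(q, k, v, max_gib: float = 4.0):
    b, h, q_len, _ = q.shape
    k_len = k.shape[-2]
    # rough size of the q @ k^T attention matrix in GiB
    est_gib = b * h * q_len * k_len * q.element_size() / 1024**3
    if est_gib <= max_gib:
        return F.scaled_dot_product_attention(q, k, v)
    num_chunks = int(est_gib // max_gib) + 1
    step = max(1, -(-b // num_chunks))  # ceil division over the batch dimension
    outs = [F.scaled_dot_product_attention(q[i:i + step], k[i:i + step], v[i:i + step])
            for i in range(0, b, step)]
    return torch.cat(outs, dim=0)
```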
48 changes: 32 additions & 16 deletions library/train_util.py
@@ -657,8 +657,16 @@ def set_caching_mode(self, mode):

def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch
if epoch > self.current_epoch:
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
num_epochs = epoch - self.current_epoch
for _ in range(num_epochs):
self.current_epoch += 1
self.shuffle_buckets()
# self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
else:
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
self.current_epoch = epoch

def set_current_step(self, step):
self.current_step = step
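
A plausible reading of the new `set_current_epoch` logic: when resuming jumps several epochs ahead, the shuffle has to be replayed once per skipped epoch so that the bucket order matches an uninterrupted run. The toy below assumes, purely for illustration, an epoch-seeded shuffle; the real `shuffle_buckets` is more involved.

```python
# Toy model of the catch-up behaviour (illustration only, not the dataset class).
import random

class TinyBuckets:
    def __init__(self, items):
        self.items = list(items)
        self.current_epoch = 0

    def shuffle_buckets(self):
        random.Random(self.current_epoch).shuffle(self.items)  # assumed epoch-seeded shuffle

    def set_current_epoch(self, epoch):
        while self.current_epoch < epoch:  # replay the shuffle for every skipped epoch
            self.current_epoch += 1
            self.shuffle_buckets()

resumed = TinyBuckets(range(8))
resumed.set_current_epoch(3)            # jumps straight to epoch 3

stepwise = TinyBuckets(range(8))
for e in (1, 2, 3):
    stepwise.set_current_epoch(e)       # advances one epoch at a time

assert resumed.items == stepwise.items  # same bucket order either way
```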
@@ -1265,7 +1273,8 @@ def __getitem__(self, index):
if subset.alpha_mask:
if img.shape[2] == 4:
alpha_mask = img[:, :, 3] # [H,W]
alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1
alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0
alpha_mask = torch.FloatTensor(alpha_mask)
else:
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
else:
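
A hedged reading of why `ToTensor` was replaced here: torchvision's `ToTensor` turns a `[H, W]` uint8 array into a `[1, H, W]` tensor, while the `else` branch builds a plain `[H, W]` tensor, so the two branches disagreed in shape. Converting manually keeps the mask two-dimensional and scaled to 0..1.

```python
# Shape comparison for a [H, W] uint8 alpha channel (illustration only).
import numpy as np
import torch
from torchvision import transforms

alpha = np.full((4, 4), 255, dtype=np.uint8)                      # stand-in alpha channel

old_style = transforms.ToTensor()(alpha)                          # -> torch.Size([1, 4, 4])
new_style = torch.FloatTensor(alpha.astype(np.float32) / 255.0)   # -> torch.Size([4, 4])

print(old_style.shape, new_style.shape)
```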
@@ -2211,7 +2220,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(
npz_path,
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2229,7 +2238,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask # ndarray
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
@@ -2425,16 +2434,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
return train_dataset_group


def load_image(image_path, alpha=False):
image = Image.open(image_path)
if alpha:
if not image.mode == "RGBA":
image = image.convert("RGBA")
else:
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
return img
def load_image(image_path, alpha=False):
try:
with Image.open(image_path) as image:
if alpha:
if not image.mode == "RGBA":
image = image.convert("RGBA")
else:
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
return img
except (IOError, OSError) as e:
logger.error(f"Error loading file: {image_path}")
raise e


# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
@@ -2496,8 +2509,9 @@ def cache_batch_latents(
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
else:
alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32)
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
else:
alpha_mask = None
alpha_masks.append(alpha_mask)
@@ -5554,6 +5568,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
if epoch == 0:
self.loss_list.append(loss)
else:
while len(self.loss_list) <= step:
self.loss_list.append(0.0)
self.loss_total -= self.loss_list[step]
self.loss_list[step] = loss
self.loss_total += loss
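
A short reading of the added guard: when training resumes part-way through an epoch (for example with `--initial_step`), the first recorded `step` can lie beyond the end of `loss_list`, so the list is padded with zeros before the running total is adjusted. A stripped-down version of just that guard:

```python
# Minimal reproduction of the padding guard (illustration only).
loss_list: list = []
loss_total = 0.0

def add(step: int, loss: float) -> None:
    global loss_total
    while len(loss_list) <= step:      # fill the gap left by steps that were skipped
        loss_list.append(0.0)
    loss_total -= loss_list[step]      # replace any previous value at this index
    loss_list[step] = loss
    loss_total += loss

add(5, 0.123)                          # no IndexError even though steps 0-4 were never added
```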
12 changes: 4 additions & 8 deletions networks/control_net_lllite_for_train.py
@@ -7,8 +7,10 @@
import torch
from library import sdxl_original_unet
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

# input_blocksに適用するかどうか / if True, input_blocks are not applied
@@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)

self.cond_image = None
self.cond_emb = None

def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None

def forward(self, x):
if not self.enabled:
return super().forward(x)

if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible

# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
@@ -159,9 +157,7 @@ def forward(self, x): # , cond_image=None):
if not self.enabled:
return super().forward(x)

if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image)

cx = torch.cat([cx, self.down(x)], dim=1)
cx = self.mid(cx)
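
The deleted `cond_emb` cache is the point of the "make forward/backward paths the same" item: recomputing the conditioning embedding on every forward keeps the autograd graph identical across calls, which is what DDP's static-graph mode and gradient checkpointing expect. A reduced sketch of the pattern, with hypothetical layer names:

```python
# Reduced sketch (hypothetical module, not the real LLLite class): recompute the
# conditioning embedding on every forward instead of caching it across calls.
import torch
import torch.nn as nn

class LLLiteSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.conditioning = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.cond_image = None                  # only the raw image is stored

    def set_cond_image(self, cond_image):
        self.cond_image = cond_image            # no cached embedding any more

    def forward(self, x):
        # Caching the embedding meant a second forward (e.g. the re-run done by
        # gradient checkpointing) reused a tensor from an earlier graph; computing
        # it here makes every forward/backward path identical.
        cx = self.conditioning(self.cond_image)
        return x + cx.mean()                    # placeholder for the real mixing
```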
46 changes: 25 additions & 21 deletions sdxl_train.py
@@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)

if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
@@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor):
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)

# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
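
Moving the Text Encoder placement, the full_fp16 patch, and `resume_from_local_or_hf_if_specified` ahead of the fused-optimizer setup means the per-parameter gradient hooks are registered only after the state has been restored (the "add grad_hook after restoring state" item). The sketch below shows the general hook pattern with PyTorch's `register_post_accumulate_grad_hook`; the `apply_update` callback stands in for the fused Adafactor step from `library.adafactor_fused` and is an assumption for illustration.

```python
# Sketch of fused-backward hook registration (done only after state is restored).
import torch

def register_fused_grad_hooks(model: torch.nn.Module, apply_update) -> None:
    for param in model.parameters():
        if not param.requires_grad:
            continue

        def hook(p: torch.Tensor) -> None:
            apply_update(p)     # e.g. an optimizer update restricted to this tensor
            p.grad = None       # free the gradient immediately instead of at .step()

        # Requires PyTorch >= 2.1.
        param.register_post_accumulate_grad_hook(hook)
```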
@@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)

# For --sample_at_first
sdxl_train_util.sample_images(
3 changes: 3 additions & 0 deletions sdxl_train_control_net_lllite.py
@@ -289,6 +289,9 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter

if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
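
`_set_static_graph()` tells DistributedDataParallel that the same parameters and graph structure are used on every iteration, which permits parameters that are touched more than once per step (as the recomputed LLLite conditioning now is) without DDP raising a reused-parameter error. A small sketch of the same guard; passing `static_graph=True` when constructing DDP is the documented equivalent.

```python
# Sketch: enable DDP's static-graph mode only when the model is actually wrapped.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def maybe_set_static_graph(model: torch.nn.Module) -> torch.nn.Module:
    # accelerator.prepare() does the DDP wrapping in the training scripts.
    if isinstance(model, DDP):
        model._set_static_graph()   # alternative: DDP(model, static_graph=True)
    return model
```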
6 changes: 4 additions & 2 deletions tools/cache_latents.py
@@ -16,15 +16,15 @@
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging

from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging

logger = logging.getLogger(__name__)


def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)

# check cache latents arg
@@ -97,6 +97,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:

# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)

# mixed precisionに対応した型を用意しておき適宜castする
@@ -176,6 +177,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
5 changes: 4 additions & 1 deletion tools/cache_text_encoder_outputs.py
@@ -16,12 +16,13 @@
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)

def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)

# check cache arg
@@ -99,6 +100,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:

# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)

# mixed precisionに対応した型を用意しておき適宜castする
@@ -171,6 +173,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
