Commit d0fff2c
rollback PR kohya-ss#1000; unwrap model for sample_images
ngitnenlim committed Dec 13, 2023
1 parent 20cb51c commit d0fff2c
Showing 2 changed files with 20 additions and 7 deletions.
12 changes: 10 additions & 2 deletions library/train_util.py
@@ -19,7 +19,7 @@
Tuple,
Union,
)
-from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
import gc
import glob
import math
@@ -2899,6 +2899,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
)
+parser.add_argument(
+    "--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP",
+)
+parser.add_argument(
+    "--static_graph", action="store_true", help="enable static_graph for DDP",
+)
parser.add_argument(
"--clip_skip",
type=int,
@@ -3861,8 +3867,10 @@ def prepare_accelerator(args: argparse.Namespace):
wandb.login(key=args.wandb_api_key)

kwargs_handlers = (
-None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
+InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
+DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph)
)
+kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
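Note: for context, a minimal sketch (not part of the commit) of how the handlers assembled in prepare_accelerator reach accelerate — Accelerator takes them through its kwargs_handlers argument, and None entries are filtered out first. The timeout value and the two DDP flags below are illustrative stand-ins for --ddp_timeout, --gradient_as_bucket_view, and --static_graph.

    import datetime
    from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

    ddp_timeout_minutes = 30  # stand-in for args.ddp_timeout; None/0 would skip the handler
    kwargs_handlers = [
        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=ddp_timeout_minutes)) if ddp_timeout_minutes else None,
        DistributedDataParallelKwargs(gradient_as_bucket_view=True, static_graph=True),  # stand-ins for the new flags
    ]
    kwargs_handlers = [h for h in kwargs_handlers if h is not None]  # same filtering as the diff above

    accelerator = Accelerator(kwargs_handlers=kwargs_handlers)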
15 changes: 10 additions & 5 deletions sdxl_train.py
@@ -398,6 +398,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
+# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
+text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
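Note: the three added lines freeze the last transformer block and final_layer_norm of text_encoder1 because, as the comment says, SDXL uses the penultimate hidden state of this encoder, so those parameters never receive gradients (unused parameters are also a common source of DDP errors). A minimal sketch of the same pattern, assuming the transformers CLIPTextModel layout; the checkpoint name is only illustrative, the training script loads its own encoders.

    from transformers import CLIPTextModel

    # Illustrative checkpoint only.
    text_encoder1 = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Freeze the parts that only feed the (unused) last hidden state.
    text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
    text_encoder1.text_model.final_layer_norm.requires_grad_(False)

    trainable = sum(p.numel() for p in text_encoder1.parameters() if p.requires_grad)
    print(f"trainable parameters after freezing: {trainable}")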
@@ -460,7 +463,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# For --sample_at_first
sdxl_train_util.sample_images(
-accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
+accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2],
+[accelerator.unwrap_model(text_encoder1), accelerator.unwrap_model(text_encoder2)],
+accelerator.unwrap_model(unet)
)

loss_recorder = train_util.LossRecorder()
@@ -608,8 +613,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.device,
vae,
[tokenizer1, tokenizer2],
-[text_encoder1, text_encoder2],
-unet,
+[accelerator.unwrap_model(text_encoder1), accelerator.unwrap_model(text_encoder2)],
+accelerator.unwrap_model(unet),
)

# save the model every specified number of steps / 指定ステップごとにモデルを保存
@@ -690,8 +695,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.device,
vae,
[tokenizer1, tokenizer2],
-[text_encoder1, text_encoder2],
-unet,
+[accelerator.unwrap_model(text_encoder1), accelerator.unwrap_model(text_encoder2)],
+accelerator.unwrap_model(unet),
)

is_main_process = accelerator.is_main_process
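Note: the sample_images hunks exist because accelerator.prepare may wrap each trained module in torch's DistributedDataParallel, which hides the original module's attributes behind .module; accelerator.unwrap_model returns the underlying module, so sample_images sees the same object whether or not DDP wrapping happened. A minimal sketch of the pattern with a toy model:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(4, 4))  # may now be a DistributedDataParallel wrapper

    inner = accelerator.unwrap_model(model)  # always the plain torch.nn.Linear underneath
    assert isinstance(inner, torch.nn.Linear)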
