Skip to content

Commit

Permalink
Merge pull request #1139 from kohya-ss/deep-speed
Browse files Browse the repository at this point in the history
Deep speed
  • Loading branch information
kohya-ss authored Mar 26, 2024
2 parents 9c4492b + a2b8531 commit ea05e3f
Show file tree
Hide file tree
Showing 12 changed files with 288 additions and 59 deletions.
40 changes: 30 additions & 10 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from tqdm import tqdm

import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
Expand Down Expand Up @@ -42,6 +44,7 @@
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)

cache_latents = args.cache_latents
Expand Down Expand Up @@ -108,6 +111,7 @@ def train(args):

# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
Expand Down Expand Up @@ -158,7 +162,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
Expand Down Expand Up @@ -191,7 +195,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)

for m in training_models:
m.requires_grad_(True)
Expand Down Expand Up @@ -246,13 +250,23 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
Expand Down Expand Up @@ -311,13 +325,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]

Expand Down Expand Up @@ -477,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
Expand All @@ -492,6 +507,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)

return parser

Expand Down
139 changes: 139 additions & 0 deletions library/deepspeed_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_deepspeed_arguments(parser: argparse.ArgumentParser):
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
parser.add_argument(
"--offload_optimizer_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
)


def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
return

# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
args.max_data_loader_n_workers = 1


def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
return None

try:
import deepspeed
except ImportError as e:
logger.error(
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
)
exit(1)

deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device,
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device,
offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag,
zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
logger.info("[DeepSpeed] full fp16 enable.")
else:
logger.info(
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
)

if args.offload_optimizer_device is not None:
logger.info("[DeepSpeed] start to manually build cpu_adam.")
deepspeed.ops.op_builder.CPUAdamBuilder().load()
logger.info("[DeepSpeed] building cpu_adam done.")

return deepspeed_plugin


# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
# remove None from models
models = {k: v for k, v in models.items() if v is not None}

class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()

for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))

def get_models(self):
return self.models

ds_model = DeepSpeedWrapper(**models)
return ds_model
1 change: 0 additions & 1 deletion library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
Expand Down
11 changes: 8 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging

setup_logging()
Expand Down Expand Up @@ -4095,6 +4096,10 @@ def load_tokenizer(args: argparse.Namespace):


def prepare_accelerator(args: argparse.Namespace):
"""
this function also prepares deepspeed plugin
"""

if args.logging_dir is None:
logging_dir = None
else:
Expand Down Expand Up @@ -4140,13 +4145,16 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
print("accelerator device:", accelerator.device)
return accelerator
Expand Down Expand Up @@ -4217,7 +4225,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une


def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
Expand All @@ -4228,7 +4235,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)

# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
Expand All @@ -4237,7 +4243,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()

return text_encoder, vae, unet, load_stable_diffusion_format


Expand Down
38 changes: 29 additions & 9 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
from library import deepspeed_utils, sdxl_model_util

import library.train_util as train_util

Expand Down Expand Up @@ -97,6 +98,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)

assert (
Expand Down Expand Up @@ -398,18 +400,33 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)

optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoder1 if train_text_encoder1 else None,
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
Expand All @@ -424,6 +441,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resumeする
Expand Down Expand Up @@ -744,6 +763,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
Expand Down
Loading

0 comments on commit ea05e3f

Please sign in to comment.