From 64916a35b2378c4a8cdf3e9efeef8b8ab7ccb41c Mon Sep 17 00:00:00 2001
From: Zovjsra <4703michael@gmail.com>
Date: Tue, 16 Apr 2024 16:40:08 +0800
Subject: [PATCH] add disable_mmap to args

---
 library/sdxl_model_util.py | 16 +++++++++++-----
 library/sdxl_train_util.py | 10 ++++++++--
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py
index f03f1bae5..e6fcb1f9c 100644
--- a/library/sdxl_model_util.py
+++ b/library/sdxl_model_util.py
@@ -1,4 +1,5 @@
 import torch
+import safetensors
 from accelerate import init_empty_weights
 from accelerate.utils.modeling import set_module_tensor_to_device
 from safetensors.torch import load_file, save_file
@@ -163,17 +164,22 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
     raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))


-def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
+def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
     # model_version is reserved for future use
     # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching

     # Load the state dict
     if model_util.is_safetensors(ckpt_path):
         checkpoint = None
-        try:
-            state_dict = load_file(ckpt_path, device=map_location)
-        except:
-            state_dict = load_file(ckpt_path)  # prevent device invalid Error
+        if disable_mmap:
+            # read the whole file into memory and deserialize from bytes (tensors are loaded to CPU)
+            with open(ckpt_path, "rb") as f:
+                state_dict = safetensors.torch.load(f.read())
+        else:
+            try:
+                state_dict = load_file(ckpt_path, device=map_location)
+            except:
+                state_dict = load_file(ckpt_path)  # prevent device invalid Error
         epoch = None
         global_step = None
     else:
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index a29013e34..106c5b455 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 weight_dtype,
                 accelerator.device if args.lowram else "cpu",
                 model_dtype,
+                args.disable_mmap_load_safetensors,
             )

             # work on low-ram device
@@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):


 def _load_target_model(
-    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
+    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
 ):
     # model_dtype only work with full fp16/bf16
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
@@ -75,7 +76,7 @@ def _load_target_model(
             unet,
             logit_scale,
             ckpt_info,
-        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
     else:
         # Diffusers model is loaded to CPU
         from diffusers import StableDiffusionXLPipeline
@@ -332,6 +333,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
         action="store_true",
         help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
     )
+    parser.add_argument(
+        "--disable_mmap_load_safetensors",
+        action="store_true",
+        help="disable mmap when loading safetensors files / safetensorsファイルの読み込みにmmapを使用しない",
+    )


 def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
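
Note for reviewers: a minimal standalone sketch of the load behavior this patch wires in. The helper name `load_state_dict_example` and the sample checkpoint path are illustrative only; `safetensors.torch.load` (bytes in, CPU tensors out) and `safetensors.torch.load_file` are the actual library calls used above.

    import safetensors.torch
    from safetensors.torch import load_file

    def load_state_dict_example(ckpt_path, device="cpu", disable_mmap=False):
        if disable_mmap:
            # read the file into memory and deserialize from bytes;
            # tensors land on CPU, so move them to `device` afterwards
            with open(ckpt_path, "rb") as f:
                return safetensors.torch.load(f.read())
        try:
            # default path: mmap-based load directly onto the target device
            return load_file(ckpt_path, device=device)
        except Exception:
            return load_file(ckpt_path)  # fall back to CPU on an invalid device string

    state_dict = load_state_dict_example("sd_xl_base_1.0.safetensors", disable_mmap=True)

The in-memory path trades higher peak RAM for skipping mmap, which can be slow on some filesystems; the mmap-based load_file path remains the default.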