From 7eae07b757227f6a4c41872ae79e7ad6c572aaa6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=8B=E7=A1=95?=
Date: Fri, 15 Nov 2024 17:24:07 +0800
Subject: [PATCH] Fix issues for saving checkpointing steps

---
 examples/dreambooth/train_dreambooth_flux.py | 41 +++++++++----------
 .../dreambooth/train_dreambooth_lora_flux.py | 17 ++++----
 2 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 51d4768c76d64..fa4cb09c58311 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1191,33 +1191,32 @@ def save_model_hook(models, weights, output_dir):
             weights.pop()
 
     def load_model_hook(models, input_dir):
-        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
-            for _ in range(len(models)):
-                # pop models so that they are not loaded again
-                model = models.pop()
-
-                # load diffusers style into model
-                if isinstance(unwrap_model(model), FluxTransformer2DModel):
-                    load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer")
-                    model.register_to_config(**load_model.config)
-
+        for _ in range(len(models)):
+            # pop models so that they are not loaded again
+            model = models.pop()
+
+            # load diffusers style into model
+            if isinstance(unwrap_model(model), FluxTransformer2DModel):
+                load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+            elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
+                try:
+                    load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
+                    model(**load_model.config)
                     model.load_state_dict(load_model.state_dict())
-                elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
+                except Exception:
                     try:
-                        load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
+                        load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2")
                         model(**load_model.config)
                         model.load_state_dict(load_model.state_dict())
                     except Exception:
-                        try:
-                            load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2")
-                            model(**load_model.config)
-                            model.load_state_dict(load_model.state_dict())
-                        except Exception:
-                            raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
-                else:
-                    raise ValueError(f"Unsupported model found: {type(model)=}")
+                        raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
+            else:
+                raise ValueError(f"Unsupported model found: {type(model)=}")
 
-            del load_model
+        del load_model
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 3474001fec8ee..67b3bce8602de 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1263,16 +1263,15 @@ def load_model_hook(models, input_dir):
         transformer_ = None
         text_encoder_one_ = None
 
-        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
-            while len(models) > 0:
-                model = models.pop()
+        while len(models) > 0:
+            model = models.pop()
 
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_ = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                    text_encoder_one_ = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+            if isinstance(model, type(unwrap_model(transformer))):
+                transformer_ = model
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                text_encoder_one_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
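
Note: both hunks drop the `if not accelerator.distributed_type == DistributedType.DEEPSPEED:` guard, so the load hook body now runs on resume under every distributed backend rather than being skipped under DeepSpeed. For context, below is a minimal, self-contained sketch of how these pre-hooks plug into Accelerate's checkpointing; the `torch.nn.Linear` model and the checkpoint directory name are hypothetical stand-ins for the scripts' actual Flux models and output paths, and the weight loading inside the hook is elided.

import torch
from accelerate import Accelerator

accelerator = Accelerator()
# Stand-in for the Flux transformer; any prepared model is tracked by Accelerate.
model = accelerator.prepare(torch.nn.Linear(4, 4))

def save_model_hook(models, weights, output_dir):
    # Runs inside accelerator.save_state(); popping a weight tells Accelerate
    # that we serialize that model ourselves instead of letting it do so.
    while len(weights) > 0:
        weights.pop()

def load_model_hook(models, input_dir):
    # Runs inside accelerator.load_state(); after this patch the body executes
    # unconditionally, so popped models are always restored by this hook
    # instead of the hook being bypassed when training under DeepSpeed.
    while len(models) > 0:
        model = models.pop()
        # ... load the previously saved weights from input_dir into `model` ...

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("ckpt-step-100")  # triggers save_model_hook
accelerator.load_state("ckpt-step-100")  # triggers load_model_hook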