From 5d07df8e1d85399131e19019b40806c8f9a0e5b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Wed, 13 Nov 2024 11:46:02 +0800 Subject: [PATCH] Fix issues for saving checkpointing steps --- examples/dreambooth/train_dreambooth_lora.py | 3 ++- examples/dreambooth/train_dreambooth_lora_sd3.py | 3 ++- examples/dreambooth/train_dreambooth_lora_sdxl.py | 3 ++- examples/dreambooth/train_dreambooth_sd3.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 584f906d6b15e..b3877f7cdcdf0 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -955,7 +955,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() StableDiffusionLoraLoaderMixin.save_lora_weights( output_dir, diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index cad677861cd32..f64ff331a2100 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1302,7 +1302,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() StableDiffusion3Pipeline.save_lora_weights( output_dir, diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 6bec9a8c9b53b..4113f7d092a0a 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1235,7 +1235,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() StableDiffusionXLPipeline.save_lora_weights( output_dir, diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 4cabd413d0a7e..aea74cae5a1eb 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1192,7 +1192,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"Wrong model supplied: {type(model)=}.") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() def load_model_hook(models, input_dir): for _ in range(len(models)):