From 486a115b2264c8335649d5e49544f794926c44a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Wed, 13 Nov 2024 11:42:10 +0800 Subject: [PATCH] Fix issues for saving checkpointing steps --- examples/dreambooth/train_dreambooth_flux.py | 3 ++- examples/dreambooth/train_dreambooth_lora_flux.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index f60b2e9813f0a..b6126b73b6d9c 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1187,7 +1187,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"Wrong model supplied: {type(model)=}.") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() def load_model_hook(models, input_dir): for _ in range(len(models)): diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index f244a6ce55113..e6037a025cbd9 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1250,7 +1250,8 @@ def save_model_hook(models, weights, output_dir): raise ValueError(f"unexpected save model: {model.__class__}") # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() FluxPipeline.save_lora_weights( output_dir,