Skip to content

Commit

Permalink
Fix issues for saving checkpointing steps
Browse files Browse the repository at this point in the history
  • Loading branch information
蒋硕 committed Nov 13, 2024
1 parent 486a115 commit 5d07df8
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,8 @@ def save_model_hook(models, weights, output_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()

StableDiffusionLoraLoaderMixin.save_lora_weights(
output_dir,
Expand Down
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,8 @@ def save_model_hook(models, weights, output_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()

StableDiffusion3Pipeline.save_lora_weights(
output_dir,
Expand Down
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,8 @@ def save_model_hook(models, weights, output_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()

StableDiffusionXLPipeline.save_lora_weights(
output_dir,
Expand Down
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,8 @@ def save_model_hook(models, weights, output_dir):
raise ValueError(f"Wrong model supplied: {type(model)=}.")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()

def load_model_hook(models, input_dir):
for _ in range(len(models)):
Expand Down

0 comments on commit 5d07df8

Please sign in to comment.