Skip to content

Commit

Permalink
Fix issues for saving checkpointing steps
Browse files Browse the repository at this point in the history
  • Loading branch information
蒋硕 committed Nov 13, 2024
1 parent 6837a81 commit 486a115
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,8 @@ def save_model_hook(models, weights, output_dir):
raise ValueError(f"Wrong model supplied: {type(model)=}.")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()

def load_model_hook(models, input_dir):
for _ in range(len(models)):
Expand Down
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,8 @@ def save_model_hook(models, weights, output_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()

FluxPipeline.save_lora_weights(
output_dir,
Expand Down

0 comments on commit 486a115

Please sign in to comment.