[bugfix] Fix bug in LoRA checkpoint saving step
J石页 committed Dec 23, 2024
1 parent 3f3bbcb commit 4807390
Showing 1 changed file with 1 addition and 24 deletions.
25 changes: 1 addition & 24 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1328,7 +1328,7 @@ def load_model_hook(models, input_dir):
         text_encoder_two_ = None
 
         while len(models) > 0:
-            model = models.pop()
+            model = models.pop( )
 
             if isinstance(model, type(unwrap_model(transformer))):
                 transformer_ = model
@@ -1339,29 +1339,6 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
-
-        transformer_state_dict = {
-            f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
-        }
-        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
-        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
-        if incompatible_keys is not None:
-            # check only for unexpected keys
-            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-            if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
-        if args.train_text_encoder:
-            # Do we need to call `scale_lora_layers()` here?
-            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
-
-            _set_state_dict_into_text_encoder(
-                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
-            )
-
         # Make sure the trainable params are in float32. This is again needed since the base models
         # are in `weight_dtype`. More details:
         # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
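For reference, the deleted block re-applied the LoRA weights from input_dir by hand: it collected the "transformer."-prefixed entries of the pipeline's LoRA state dict, stripped that prefix so the keys matched the transformer module's own parameter names, converted the result to PEFT's layout with convert_unet_state_dict_to_peft, and loaded it via set_peft_model_state_dict. Below is a minimal, self-contained sketch of just the key-remapping step; the sample keys and values are hypothetical placeholders standing in for real SD3 transformer checkpoint entries.

    # Sketch of the prefix-stripping step from the removed block.
    # The keys below are hypothetical, not real checkpoint keys.
    lora_state_dict = {
        "transformer.attn.to_q.lora_A.weight": "tensor-a",  # transformer entry
        "text_encoder.q_proj.lora_A.weight": "tensor-b",    # text-encoder entry
    }

    # Keep only transformer entries and drop the "transformer." prefix so the
    # keys line up with the transformer module's parameter names.
    transformer_state_dict = {
        k.replace("transformer.", ""): v
        for k, v in lora_state_dict.items()
        if k.startswith("transformer.")
    }

    print(transformer_state_dict)
    # {'attn.to_q.lora_A.weight': 'tensor-a'}

After this commit, that manual reload is gone: the models popped from the models list in load_model_hook are used directly.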
