Skip to content

Commit

Permalink
Fix checkpoint save/load hooks at checkpointing steps (skip manual reload under DeepSpeed)
Browse files Browse the repository at this point in the history
  • Loading branch information
蒋硕 committed Nov 15, 2024
1 parent e78afaf commit 7eae07b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 30 deletions.
41 changes: 20 additions & 21 deletions examples/dreambooth/train_dreambooth_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,33 +1191,32 @@ def save_model_hook(models, weights, output_dir):
weights.pop()

def load_model_hook(models, input_dir):
    """Accelerate load-state pre-hook for full-finetune Flux DreamBooth training.

    Pops every tracked model off ``models`` and manually reloads its weights
    from ``input_dir`` in diffusers format, so Accelerate does not try to load
    them again itself.

    Args:
        models: List of models registered with the Accelerator; mutated in
            place (emptied) by this hook.
        input_dir: Checkpoint directory containing ``transformer``,
            ``text_encoder`` and ``text_encoder_2`` subfolders.

    Raises:
        ValueError: If a popped model is of an unsupported type, or a text
            encoder cannot be loaded from either subfolder.
    """
    # Under DeepSpeed the engine handles checkpoint loading itself, so the
    # manual per-model reload below must be skipped entirely.
    if not accelerator.distributed_type == DistributedType.DEEPSPEED:
        for _ in range(len(models)):
            # pop models so that they are not loaded again
            model = models.pop()

            # load diffusers style into model
            if isinstance(unwrap_model(model), FluxTransformer2DModel):
                load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
            elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
                # The two text encoders live in different subfolders; try CLIP
                # first, then fall back to T5 (EAFP on the load itself).
                try:
                    load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
                    # NOTE(review): calling the model with config kwargs looks
                    # unintentional — confirm whether register_to_config (or no
                    # call at all) was meant here.
                    model(**load_model.config)
                    model.load_state_dict(load_model.state_dict())
                except Exception:
                    try:
                        load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2")
                        # NOTE(review): same suspicious call as above — verify.
                        model(**load_model.config)
                        model.load_state_dict(load_model.state_dict())
                    except Exception:
                        raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
            else:
                raise ValueError(f"Unsupported model found: {type(model)=}")

            # Free the freshly-loaded reference weights before the next iteration.
            del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
Expand Down
17 changes: 8 additions & 9 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,16 +1263,15 @@ def load_model_hook(models, input_dir):
# Fragment of the LoRA script's load_model_hook: claim the wrapped models off
# Accelerate's list so it will not reload them, then read the LoRA state dict.
transformer_ = None
text_encoder_one_ = None

# Under DeepSpeed the engine restores model state itself, so the models stay
# on the list and only the LoRA state dict is read from the checkpoint here.
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
    while len(models) > 0:
        model = models.pop()

        # Match against the unwrapped (non-DDP/compiled) types to decide which
        # slot this model fills.
        if isinstance(model, type(unwrap_model(transformer))):
            transformer_ = model
        elif isinstance(model, type(unwrap_model(text_encoder_one))):
            text_encoder_one_ = model
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
Expand Down

0 comments on commit 7eae07b

Please sign in to comment.