Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SD2 loading #16078

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions modules/sd_hijack_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def hijack_ddpm_edit():
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

# Always make sure inputs to the UNet are in the correct dtype
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)

Expand All @@ -150,5 +151,6 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)


# Always make sure the timestep calculation is in the correct dtype
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
7 changes: 6 additions & 1 deletion modules/sd_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,19 @@ def is_using_v_parameterization_for_sd2(state_dict):
unet.eval()

with torch.no_grad():
unet_dtype = torch.float
original_unet_dtype = devices.dtype_unet

unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
unet.load_state_dict(unet_sd, strict=True)
unet.to(device=device, dtype=torch.float)
unet.to(device=device, dtype=unet_dtype)
devices.dtype_unet = unet_dtype

test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5

out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
devices.dtype_unet = original_unet_dtype

return out < -1

Expand Down