Skip to content

Commit

Permalink
SD2 v autodetection fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Jul 6, 2024
1 parent 477869c commit 74069ad
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions modules/sd_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ def is_using_v_parameterization_for_sd2(state_dict):
with torch.no_grad():
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
unet.load_state_dict(unet_sd, strict=True)
unet.to(device=device, dtype=torch.float)
unet.to(device=device, dtype=devices.dtype_unet)

test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5

out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
with devices.autocast():
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()

return out < -1

Expand Down

0 comments on commit 74069ad

Please sign in to comment.