Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#17 from uservar/dev3
Browse files Browse the repository at this point in the history
Fix v-prediction model detection
  • Loading branch information
uservar authored Nov 25, 2022
2 parents d3fc11c + d6552c0 commit 8b5297f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def load_model(checkpoint_info=None):

# check if the model uses v-prediction (the 768x768 model and x4 upscaler do)
# see https://arxiv.org/abs/2202.00512
shared.opts.v_sampling = sd_config.get("model.params.parameterization") == "v"
shared.v_prediction = sd_config.model.params.get("parameterization") == "v"

if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
Expand Down
2 changes: 1 addition & 1 deletion modules/sd_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def __getattr__(self, item):

class KDiffusionSampler:
def __init__(self, funcname, sd_model):
wrapper = k_diffusion.external.CompVisVDenoiser if shared.opts.v_sampling else k_diffusion.external.CompVisDenoiser
wrapper = k_diffusion.external.CompVisVDenoiser if shared.v_prediction else k_diffusion.external.CompVisDenoiser
self.model_wrap = wrapper(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
Expand Down
5 changes: 3 additions & 2 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None

v_prediction = False

def reload_hypernetworks():
global hypernetworks

Expand Down Expand Up @@ -380,7 +382,7 @@ def options_section(section_identifier, options_dict):
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
'quicksettings': OptionInfo("sd_model_checkpoint, v_sampling", "Quicksettings list"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))

Expand All @@ -393,7 +395,6 @@ def options_section(section_identifier, options_dict):
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
"v_sampling": OptionInfo(False, "Use v-prediction"),
}))

options_templates.update(options_section((None, "Hidden options"), {
Expand Down

0 comments on commit 8b5297f

Please sign in to comment.