Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom k_diffusion scheduler #10634

Merged
merged 14 commits into from
May 27, 2023
Merged
7 changes: 6 additions & 1 deletion modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename))


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, enable_k_sched, k_sched_type, sigma_min, sigma_max, rho, *args):
override_settings = create_override_settings_dict(override_settings_texts)

is_batch = mode == 5
Expand Down Expand Up @@ -155,6 +155,11 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
override_settings=override_settings,
enable_custom_k_sched=enable_k_sched,
k_sched_type=k_sched_type,
sigma_min=sigma_min,
sigma_max=sigma_max,
rho=rho
)

p.scripts = modules.scripts.scripts_img2img
Expand Down
16 changes: 14 additions & 2 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Any, Dict, List

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_samplers_kdiffusion
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
Expand Down Expand Up @@ -106,7 +106,7 @@ class StableDiffusionProcessing:
"""
The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None, enable_custom_k_sched: bool = False, k_sched_type: str = "karras", sigma_min: float=0.1, sigma_max: float=10.0, rho: float=7.0):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

Expand Down Expand Up @@ -146,6 +146,11 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
self.enable_custom_k_sched = enable_custom_k_sched
self.k_sched_type = k_sched_type
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.rho = rho
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
Expand Down Expand Up @@ -555,9 +560,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
if uses_ensd:
uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

use_custom_k_sched = p.enable_custom_k_sched and p.sampler_name in sd_samplers_kdiffusion.k_diffusion_samplers_map

generation_params = {
"Steps": p.steps,
"Sampler": p.sampler_name,
"Enable Custom Karras Schedule": use_custom_k_sched or None,
"kdiffusion Scheduler Type": p.k_sched_type if use_custom_k_sched else None,
"kdiffusion Scheduler sigma_max": p.sigma_max if use_custom_k_sched else None,
"kdiffusion Scheduler sigma_min": p.sigma_min if use_custom_k_sched else None,
"kdiffusion Scheduler rho": p.rho if use_custom_k_sched else None,
KohakuBlueleaf marked this conversation as resolved.
Show resolved Hide resolved
"CFG scale": p.cfg_scale,
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
"Seed": all_seeds[index],
Expand Down
16 changes: 16 additions & 0 deletions modules/sd_samplers_kdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}

# Lookup of k-diffusion sampler entries keyed by their `name` attribute.
k_diffusion_samplers_map = {sampler.name: sampler for sampler in samplers_data_k_diffusion}

# Scheduler type name -> k_diffusion function that is called to build the
# sigma schedule for that type.
k_diffusion_scheduler = {
    'karras': k_diffusion.sampling.get_sigmas_karras,
    'exponential': k_diffusion.sampling.get_sigmas_exponential,
    'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential,
}


class CFGDenoiser(torch.nn.Module):
"""
Expand Down Expand Up @@ -304,6 +311,15 @@ def get_sigmas(self, p, steps):

if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif p.enable_custom_k_sched:
sigmas_func = k_diffusion_scheduler[p.k_sched_type]
sigmas_kwargs = {
'sigma_min': p.sigma_min,
'sigma_max': p.sigma_max
}
if p.k_sched_type != 'exponential':
sigmas_kwargs['rho'] = p.rho
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())

Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"inpaint",
"sampler",
"checkboxes",
"kdiffusion_scheduler",
"hires_fix",
"dimensions",
"cfg",
Expand Down
7 changes: 6 additions & 1 deletion modules/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@



def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, enable_k_sched, k_sched_type, sigma_min, sigma_max, rho, *args):
override_settings = create_override_settings_dict(override_settings_texts)

p = processing.StableDiffusionProcessingTxt2Img(
Expand Down Expand Up @@ -43,6 +43,11 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
enable_custom_k_sched=enable_k_sched,
k_sched_type=k_sched_type,
sigma_min=sigma_min,
sigma_max=sigma_max,
rho=rho
)

p.scripts = modules.scripts.scripts_txt2img
Expand Down
52 changes: 52 additions & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def create_ui():
with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
t2i_enable_k_sched = gr.Checkbox(label='Custom KDiffusion Scheduler', value=False, elem_id="txt2img_enable_k_sched")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)

Expand All @@ -510,6 +511,14 @@ def create_ui():
with gr.Row():
hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])

elif category == "kdiffusion_scheduler":
with FormGroup(visible=False, elem_id="txt2img_kdiffusion_scheduler") as t2i_k_sched_options:
with FormRow(elem_id="txt2img_kdiffusion_scheduler_row1", variant="compact"):
t2i_k_sched_type = gr.Dropdown(label="Type", elem_id="t2i_k_sched_type", choices=['karras', 'exponential', 'polyexponential'], value='karras')
t2i_k_sched_sigma_min = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, label='sigma min', value=0.1, elem_id="txt2img_sigma_min")
t2i_k_sched_sigma_max = gr.Slider(minimum=5.0, maximum=50.0, step=0.1, label='sigma max', value=10.0, elem_id="txt2img_sigma_max")
t2i_k_sched_rho = gr.Slider(minimum=0.5, maximum=10.0, step=0.1, label='rho', value=7.0, elem_id="txt2img_rho")

elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"):
Expand Down Expand Up @@ -578,6 +587,11 @@ def create_ui():
hr_prompt,
hr_negative_prompt,
override_settings,
t2i_enable_k_sched,
t2i_k_sched_type,
t2i_k_sched_sigma_min,
t2i_k_sched_sigma_max,
t2i_k_sched_rho

] + custom_inputs,

Expand Down Expand Up @@ -627,6 +641,13 @@ def create_ui():
show_progress = False,
)

t2i_enable_k_sched.change(
fn=lambda x: gr_show(x),
inputs=[t2i_enable_k_sched],
outputs=[t2i_k_sched_options],
show_progress=False
)

txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
Expand Down Expand Up @@ -655,6 +676,11 @@ def create_ui():
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
(t2i_enable_k_sched, "Enable Custom Karras Schedule"),
KohakuBlueleaf marked this conversation as resolved.
Show resolved Hide resolved
(t2i_k_sched_type, "KDiffusion Scheduler Type"),
(t2i_k_sched_sigma_max, "KDiffusion Scheduler sigma_max"),
(t2i_k_sched_sigma_min, "KDiffusion Scheduler sigma_min"),
(t2i_k_sched_rho, "KDiffusion Scheduler rho"),
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
Expand Down Expand Up @@ -846,6 +872,15 @@ def copy_image(img):
with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
i2i_enable_k_sched = gr.Checkbox(label='Custom KDiffusion Scheduler', value=False, elem_id="txt2img_enable_k_sched")

elif category == "kdiffusion_scheduler":
with FormGroup(visible=False, elem_id="img2img_kdiffusion_scheduler") as i2i_k_sched_options:
with FormRow(elem_id="img2img_kdiffusion_scheduler_row1", variant="compact"):
i2i_k_sched_type = gr.Dropdown(label="Type", elem_id="t2i_k_sched_type", choices=['karras', 'exponential', 'polyexponential'], value='karras')
i2i_k_sched_sigma_min = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, label='sigma min', value=0.1, elem_id="txt2img_sigma_min")
i2i_k_sched_sigma_max = gr.Slider(minimum=5.0, maximum=50.0, step=0.1, label='sigma max', value=10.0, elem_id="txt2img_sigma_max")
i2i_k_sched_rho = gr.Slider(minimum=0.5, maximum=10.0, step=0.1, label='rho', value=7.0, elem_id="txt2img_rho")

elif category == "batch":
if not opts.dimensions_and_batch_together:
Expand Down Expand Up @@ -949,6 +984,11 @@ def select_img2img_tab(tab):
img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir,
override_settings,
i2i_enable_k_sched,
i2i_k_sched_type,
i2i_k_sched_sigma_min,
i2i_k_sched_sigma_max,
i2i_k_sched_rho
] + custom_inputs,
outputs=[
img2img_gallery,
Expand Down Expand Up @@ -1032,6 +1072,13 @@ def select_img2img_tab(tab):
outputs=[prompt, negative_prompt, styles],
)

i2i_enable_k_sched.change(
fn=lambda x: gr_show(x),
inputs=[i2i_enable_k_sched],
outputs=[i2i_k_sched_options],
show_progress=False
)

token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])

Expand All @@ -1043,6 +1090,11 @@ def select_img2img_tab(tab):
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
(i2i_enable_k_sched, "Enable Custom Karras Schedule"),
KohakuBlueleaf marked this conversation as resolved.
Show resolved Hide resolved
(i2i_k_sched_type, "KDiffusion Scheduler Type"),
(i2i_k_sched_sigma_max, "KDiffusion Scheduler sigma_max"),
(i2i_k_sched_sigma_min, "KDiffusion Scheduler sigma_min"),
(i2i_k_sched_rho, "KDiffusion Scheduler rho"),
(cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"),
(seed, "Seed"),
Expand Down
6 changes: 5 additions & 1 deletion scripts/xyz_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import modules.scripts as scripts
import gradio as gr

from modules import images, sd_samplers, processing, sd_models, sd_vae
from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
Expand Down Expand Up @@ -220,6 +220,10 @@ def __init__(self, *args, **kwargs):
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise")),
AxisOption("KDiffusion Scheduler Type", str, apply_field("k_sched_type"), choices=lambda: [x for x in sd_samplers_kdiffusion.k_diffusion_scheduler]),
AxisOption("KDiffusion Scheduler Sigma Min", float, apply_field("sigma_min")),
AxisOption("KDiffusion Scheduler Sigma Max", float, apply_field("sigma_max")),
AxisOption("KDiffusion Scheduler rho", float, apply_field("rho")),
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
Expand Down