prompt editing timeline has separate range for first pass and hires-fix pass #12457

Merged · 5 commits · Aug 24, 2023
16 changes: 10 additions & 6 deletions modules/processing.py
100755 → 100644
@@ -394,12 +394,14 @@ def setup_prompts(self):
         self.main_prompt = self.all_prompts[0]
         self.main_negative_prompt = self.all_negative_prompts[0]
 
-    def cached_params(self, required_prompts, steps, extra_network_data):
+    def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
         """Returns parameters that invalidate the cond cache if changed"""
 
         return (
             required_prompts,
             steps,
+            hires_steps,
+            use_old_scheduling,
             opts.CLIP_stop_at_last_layers,
             shared.sd_model.sd_checkpoint_info,
             extra_network_data,
@@ -409,7 +411,7 @@ def cached_params(self, required_prompts, steps, extra_network_data):
             self.height,
         )
 
-    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
+    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
         """
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
         using a cache to store the result if the same arguments have been used before.
@@ -422,7 +424,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
         caches is a list with items described above.
         """
 
-        cached_params = self.cached_params(required_prompts, steps, extra_network_data)
+        cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
 
         for cache in caches:
             if cache[0] is not None and cached_params == cache[0]:
@@ -431,7 +433,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
         cache = caches[0]
 
         with devices.autocast():
-            cache[1] = function(shared.sd_model, required_prompts, steps)
+            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
 
         cache[0] = cached_params
         return cache[1]
@@ -443,6 +445,8 @@ def setup_conds(self):
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
         total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
         self.step_multiplier = total_steps // self.steps
+        self.firstpass_steps = total_steps
+
         self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
         self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
 
@@ -1279,8 +1283,8 @@ def calculate_hr_conds(self):
         steps = self.hr_second_pass_steps or self.steps
         total_steps = sampler_config.total_steps(steps) if sampler_config else steps
 
-        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
-        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
 
     def setup_conds(self):
         if self.is_hr_pass:
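The processing.py changes above widen the cond-cache key: `cached_params` now includes `hires_steps` and `use_old_scheduling`, and `calculate_hr_conds` passes the first-pass step count together with the hires step count, so first-pass and hires-pass conditionings no longer collide in the cache. Below is a minimal, self-contained sketch of the tuple-keyed cache idea (hypothetical helper and key values, not the webui code):

```python
# Hypothetical illustration of a tuple-keyed cache: an entry is reused only when
# every keyed parameter matches, so including hires_steps/use_old_scheduling in
# the key forces a recompute when the hires pass asks for a different timeline.
def get_with_caching(cache, key, compute):
    if cache and cache[0] == key:
        return cache[1]              # identical parameters: reuse cached conds
    cache[:] = [key, compute()]      # anything differs: recompute and store
    return cache[1]

cache = []
first_key = ("a [b:1.5] c", 20, None, False)  # (prompt, steps, hires_steps, use_old_scheduling)
hires_key = ("a [b:1.5] c", 20, 10, False)

print(get_with_caching(cache, first_key, lambda: "first-pass conds"))
print(get_with_caching(cache, hires_key, lambda: "hires conds"))  # key differs, so recomputed
```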
41 changes: 31 additions & 10 deletions modules/prompt_parser.py
@@ -26,7 +26,7 @@
 %import common.SIGNED_NUMBER -> NUMBER
 """)
 
-def get_learned_conditioning_prompt_schedules(prompts, steps):
+def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
     """
     >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
     >>> g("test")
@@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
     [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
     >>> g("[fe|||]male")
     [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
+    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
+    >>> g("a [b:.5] c")
+    [[10, 'a b c']]
+    >>> g("a [b:1.5] c")
+    [[5, 'a  c'], [10, 'a b c']]
     """
 
+    if hires_steps is None or use_old_scheduling:
+        int_offset = 0
+        flt_offset = 0
+        steps = base_steps
+    else:
+        int_offset = base_steps
+        flt_offset = 1.0
+        steps = hires_steps
+
     def collect_steps(steps, tree):
         res = [steps]
 
         class CollectSteps(lark.Visitor):
             def scheduled(self, tree):
-                tree.children[-2] = float(tree.children[-2])
-                if tree.children[-2] < 1:
-                    tree.children[-2] *= steps
-                tree.children[-2] = min(steps, int(tree.children[-2]))
-                res.append(tree.children[-2])
+                s = tree.children[-2]
+                v = float(s)
+                if use_old_scheduling:
+                    v = v*steps if v<1 else v
+                else:
+                    if "." in s:
+                        v = (v - flt_offset) * steps
+                    else:
+                        v = (v - int_offset)
+                tree.children[-2] = min(steps, int(v))
+                if tree.children[-2] >= 1:
+                    res.append(tree.children[-2])
 
             def alternate(self, tree):
                 res.extend(range(1, steps+1))
@@ -134,7 +155,7 @@ def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, c



-def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
     """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
     and the sampling step at which this condition is to be replaced by the next one.
 
@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
     """
     res = []
 
-    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
+    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
     cache = {}
 
     for prompt, prompt_schedule in zip(prompts, prompt_schedules):
@@ -229,7 +250,7 @@ def __init__(self, shape, batch):
         self.batch: List[List[ComposableScheduledPromptConditioning]] = batch


-def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
     """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
     For each prompt, the list is obtained by splitting the prompt using the AND separator.
 
@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
 
     res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
 
-    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
 
     res = []
     for indexes in res_indexes:
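With the new signature, `get_learned_conditioning_prompt_schedules` takes the first-pass step count plus an optional hires step count. When `use_old_scheduling` is off, a value containing a decimal point addresses the combined timeline as a fraction (0 to 1 falls in the first pass, 1 to 2 in the hires pass), while a plain integer is an absolute step that can land inside the hires pass after `base_steps` is subtracted. A small usage sketch, assuming it runs inside the webui environment where `modules.prompt_parser` and lark are importable; expected behaviour is described in comments rather than asserted:

```python
from modules import prompt_parser

base_steps, hires_steps = 10, 10

for prompt in ["a [b:.5] c", "a [b:1.5] c", "a [b:15] c"]:
    first = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], base_steps)[0]
    hires = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], base_steps, hires_steps)[0]
    print(f"{prompt!r}\n  first pass: {first}\n  hires pass: {hires}")

# Under the new scheduling (matching the doctests added above):
#   ".5"  switches halfway through the first pass, so the hires pass only ever sees "a b c"
#   "1.5" switches halfway through the hires pass: (1.5 - 1.0) * hires_steps = step 5
#   "15"  is an absolute step and lands at 15 - base_steps = step 5 of the hires pass
```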
1 change: 1 addition & 0 deletions modules/shared_options.py
@@ -195,6 +195,7 @@
     "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
     "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
+    "use_old_scheduling": OptionInfo(False, "Use old prompt where first pass and hires both used the same timeline, and < 1 meant relative and >= 1 meant absolute"),
 }))
 
 options_templates.update(options_section(('interrogate', "Interrogate"), {
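The added `use_old_scheduling` option (off by default) restores the legacy behaviour described in its label: both passes shared one timeline, a value below 1 was a fraction of the step count, and a value of 1 or more was an absolute step. A small comparison sketch, under the same webui-environment assumption as the previous example:

```python
from modules import prompt_parser

base_steps, hires_steps = 10, 10
prompt = "a [b:1.5] c"

new = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], base_steps, hires_steps)[0]
old = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], base_steps, hires_steps, use_old_scheduling=True)[0]

print("new:", new)  # 1.5 is read against the hires pass: the switch happens at step 5 of 10
print("old:", old)  # 1.5 >= 1 is an absolute step on the shared timeline: the switch happens at step 1
```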