Skip to content

Commit

Permalink
temp fix for preprocessor_revision.py (#1860)
Browse files Browse the repository at this point in the history
fix for #1420. Skip noise augmentation if model doesn't provide the method; allows Revision controlnet preprocessors to function.
  • Loading branch information
Melyns authored Sep 19, 2024
1 parent 93bcfd3 commit b20cb4b
Showing 1 changed file with 19 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

def revision_conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
revision_conditions = model_options['revision_conditions']
noise_augmentor = model.noise_augmentor

noise_augmentor = getattr(model, 'noise_augmentor', None)

noise_augment_merge = 0.0
ignore_prompt = False

Expand All @@ -18,10 +20,15 @@ def revision_conditioning_modifier(model, x, timestep, uncond, cond, cond_scale,
adm_cond = revision_condition['cond'].image_embeds
weight = revision_condition["weight"]
noise_augment = revision_condition["noise_aug"]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(x.device),
noise_level=torch.tensor([noise_level], device=x.device), seed=seed)
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight

if noise_augmentor is not None:
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(x.device),
noise_level=torch.tensor([noise_level], device=x.device), seed=seed)
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
else:
adm_out = adm_cond * weight # skip noise augmentation

weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
Expand All @@ -30,11 +37,12 @@ def revision_conditioning_modifier(model, x, timestep, uncond, cond, cond_scale,

if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
noise_augment = noise_augment_merge
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim],
noise_level=torch.tensor([noise_level], device=x.device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
if noise_augmentor is not None:
noise_augment = noise_augment_merge
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim],
noise_level=torch.tensor([noise_level], device=x.device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)

new_y = adm_out[:, :1280]
cond = copy.deepcopy(cond)
Expand All @@ -61,7 +69,7 @@ def __init__(self, name, url, filename, ignore_prompt=False):
self.do_not_need_model = True
self.ignore_prompt = ignore_prompt
self.slider_1 = PreprocessorParameter(
label="Noise Augmentation", minimum=0.0, maximum=1.0, value=0.0, visible=True)
label="Noise Augmentation", minimum=0.0, maximum=1.0, value=0.0, visible=False) # hiding the noise slider since it has no effect

def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
unit = kwargs['unit']
Expand Down

0 comments on commit b20cb4b

Please sign in to comment.