Variable-Strength StyleAlign (#1387)
* adding efficiency

* adding variable strength

* Revert "adding efficiency"

This reverts commit 6d0ad98.

* updating with 0 and 1 cases

---------

Co-authored-by: T. Warren de Wit <tww0007@uah.edu>
twarrendewit and T. Warren de Wit authored Aug 22, 2024
1 parent d169cd5 commit f7ab23b
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
@@ -22,14 +22,15 @@ def show(self, is_img2img):
     def ui(self, *args, **kwargs):
         with gr.Accordion(open=False, label=self.title()):
             shared_attention = gr.Checkbox(label='Share attention in batch', value=False)
+            strength = gr.Slider(label='Strength', minimum=0.0, maximum=1.0, value=1.0)

-        return [shared_attention]
+        return [shared_attention, strength]

     def process_before_every_sampling(self, p, *script_args, **kwargs):
         # This will be called before every sampling.
         # If you use highres fix, this will be called twice.

-        shared_attention = script_args[0]
+        shared_attention, strength = script_args

         if not shared_attention:
             return
@@ -60,9 +61,26 @@ def attn1_proc(q, k, v, transformer_options):
                     indices = uncond_indices

                 if len(indices) > 0:
+
                     bq, bk, bv = q[indices], k[indices], v[indices]
-                    bo = aligned_attention(bq, bk, bv, transformer_options)
-                    results.append(bo)
+
+                    if strength < 0.01:
+                        # At strength = 0, use original.
+                        original_attention = sdp(bq, bk, bv, transformer_options)
+                        results.append(original_attention)
+
+                    elif strength > 0.99:
+                        # At strength 1, use aligned.
+                        aligned_attention_result = aligned_attention(bq, bk, bv, transformer_options)
+                        results.append(aligned_attention_result)
+
+                    else:
+                        # In between, blend original and aligned attention based on strength.
+                        original_attention = sdp(bq, bk, bv, transformer_options)
+                        aligned_attention_result = aligned_attention(bq, bk, bv, transformer_options)
+                        blended_attention = (1.0 - strength) * original_attention + strength * aligned_attention_result
+                        results.append(blended_attention)
+

             results = torch.cat(results, dim=0)
             return results
@@ -75,6 +93,7 @@ def attn1_proc(q, k, v, transformer_options):
         # The extra_generation_params does not influence results.
         p.extra_generation_params.update(dict(
             stylealign_enabled=shared_attention,
+            stylealign_strength=strength,
         ))

-        return
+        return
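
For reference, the new branch is a straight linear interpolation between the two attention outputs, output = (1 - strength) * original + strength * aligned, with the < 0.01 and > 0.99 guards covering the 0 and 1 cases so only one attention pass runs at the extremes. A minimal standalone sketch of that behavior (toy tensors and a hypothetical blend_attention helper, not code from this repository):

import torch

def blend_attention(original: torch.Tensor, aligned: torch.Tensor, strength: float) -> torch.Tensor:
    # Mirrors the branch added in this commit: skip the unused attention pass at the
    # extremes, otherwise linearly interpolate between the two outputs.
    if strength < 0.01:
        return original   # strength ~ 0: original attention only
    if strength > 0.99:
        return aligned    # strength ~ 1: aligned (shared) attention only
    return (1.0 - strength) * original + strength * aligned

# Toy example with fake attention outputs of shape (batch, tokens, channels).
original = torch.randn(2, 77, 320)
aligned = torch.randn(2, 77, 320)
blended = blend_attention(original, aligned, strength=0.5)
print(blended.shape)  # torch.Size([2, 77, 320])

Skipping the blend at the extremes avoids computing both sdp and aligned_attention when only one result is needed, which is presumably what the "updating with 0 and 1 cases" bullet in the commit message refers to.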
