Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse LoRA from prompt #198

Merged
merged 12 commits into from
Dec 10, 2023
35 changes: 21 additions & 14 deletions ai_diffusion/attention_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,40 @@


def select_current_parenthesis_block(
text: str, cursor_pos: int, open_bracket: str, close_bracket: str
text: str, cursor_pos: int, open_brackets: list[str], close_brackets: list[str]
) -> Tuple[int, int] | None:
"""Select the current parenthesis block that the cursor points to."""
# Ensure cursor position is within valid range
cursor_pos = max(0, min(cursor_pos, len(text)))

# Find the nearest '(' before the cursor
start = text.rfind(open_bracket, 0, cursor_pos)
start = -1
for open_bracket in open_brackets:
start = max(start, text.rfind(open_bracket, 0, cursor_pos))

# If '(' is found, find the corresponding ')' after the cursor
end = -1
if start != -1:
open_parens = 1
for i in range(start + 1, len(text)):
if text[i] == open_bracket:
if text[i] in open_brackets:
open_parens += 1
elif text[i] == close_bracket:
elif text[i] in close_brackets:
open_parens -= 1
if open_parens == 0:
end = i
break

# Return the indices only if both '(' and ')' are found
if start != -1 and end > cursor_pos:
return (start, end + 1)
if start != -1 and end >= cursor_pos:
return start, end + 1
else:
return None


def select_current_word(text: str, cursor_pos: int) -> Tuple[int, int]:
"""Select the word the cursor points to."""
delimiters = r".,\/!?%^*;:{}=`~() " + "\t\r\n"
delimiters = r".,\/!?%^*;:{}=`~()<> " + "\t\r\n"
start = end = cursor_pos

# seek backward to find beginning
Expand All @@ -50,9 +52,9 @@ def select_current_word(text: str, cursor_pos: int) -> Tuple[int, int]:

def select_on_cursor_pos(text: str, cursor_pos: int) -> Tuple[int, int]:
"""Return a range in the text based on the cursor_position."""
return select_current_parenthesis_block(text, cursor_pos, "(", ")") or select_current_word(
text, cursor_pos
)
return select_current_parenthesis_block(
text, cursor_pos, ["(", "<"], [")", ">"]
) or select_current_word(text, cursor_pos)


class ExprNode:
Expand All @@ -77,7 +79,7 @@ def parse_expr(expression: str) -> List[ExprNode]:
"""

def parse_segment(segment):
match = re.match(r"^[([{<](.*?):([\d.]+)[\]})>]$", segment)
match = re.match(r"^[([{<](.*?):(-?[\d.]+)[\]})>]$", segment)
if match:
inner_expr = match.group(1)
number = float(match.group(2))
Expand All @@ -88,7 +90,7 @@ def parse_segment(segment):
segments = []
stack = []
start = 0
bracket_pairs = {"(": ")"}
bracket_pairs = {"(": ")", "<": ">"}

for i, char in enumerate(expression):
if char in bracket_pairs:
Expand Down Expand Up @@ -127,18 +129,23 @@ def edit_attention(text: str, positive: bool) -> str:
weight = segments[0].weight
open_bracket = text[0]
close_bracket = text[-1]
elif text[0] == "<":
attention_string = text[1:-1]
weight = 1.0
open_bracket = "<"
close_bracket = ">"
else:
attention_string = text
weight = 1.0
open_bracket = "("
close_bracket = ")"

weight = weight + 0.1 * (1 if positive else -1)
weight = max(weight, 0.0)
weight = max(weight, -2.0)
weight = min(weight, 2.0)

return (
attention_string
if weight == 1.0
if weight == 1.0 and open_bracket == "("
else f"{open_bracket}{attention_string}:{weight:.1f}{close_bracket}"
)
57 changes: 48 additions & 9 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations
import math
import re
from itertools import chain
from pathlib import Path
from typing import Any, List, NamedTuple, Optional

from .image import Bounds, Extent, Image, Mask
Expand All @@ -11,6 +14,9 @@
from .util import client_logger as log


_pattern_lora = re.compile(r"\s*<lora:([^:<>]+)(?::(-?[^:<>]*))?>\s*", re.IGNORECASE)


class ScaledExtent(NamedTuple):
initial: Extent # resolution for initial generation
expanded: Extent # resolution for high res pass
Expand Down Expand Up @@ -193,6 +199,33 @@ def _sampler_params(
return params


def _parse_loras(client_loras: list[str], prompt: str) -> list[dict[str, str | float]]:
loras = []
for match in _pattern_lora.findall(prompt):
lora_name = ""

for client_lora in client_loras:
lora_filename = Path(client_lora).stem
if match[0].lower() == lora_filename.lower():
lora_name = client_lora

if not lora_name:
error = f"LoRA not found : {match[0]}"
log.warning(error)
raise Exception(error)

lora_strength = match[1] if match[1] != "" else 1.0
try:
lora_strength = float(lora_strength)
except ValueError:
error = f"Invalid LoRA strength for {match[0]} : {lora_strength}"
log.warning(error)
raise Exception(error)

loras.append(dict(name=lora_name, strength=lora_strength))
return loras


def _apply_strength(strength: float, steps: int, min_steps: int = 0) -> tuple[int, int]:
start_at_step = round(steps * (1 - strength))

Expand All @@ -203,7 +236,13 @@ def _apply_strength(strength: float, steps: int, min_steps: int = 0) -> tuple[in
return steps, start_at_step


def load_model_with_lora(w: ComfyWorkflow, comfy: Client, style: Style, is_live=False):
def load_model_with_lora(
w: ComfyWorkflow,
comfy: Client,
style: Style,
prompt: str,
is_live=False,
):
checkpoint = style.sd_checkpoint
if checkpoint not in comfy.checkpoints:
checkpoint = next(iter(comfy.checkpoints.keys()))
Expand All @@ -217,9 +256,9 @@ def load_model_with_lora(w: ComfyWorkflow, comfy: Client, style: Style, is_live=
else:
log.warning(f"Style VAE {style.vae} not found, using default VAE from checkpoint")

for lora in style.loras:
for lora in chain(style.loras, _parse_loras(comfy.lora_models, prompt)):
if lora["name"] not in comfy.lora_models:
log.warning(f"Style LoRA {lora['name']} not found, skipping")
log.warning(f"LoRA {lora['name']} not found, skipping")
continue
model, clip = w.load_lora(model, clip, lora["name"], lora["strength"], lora["strength"])

Expand Down Expand Up @@ -312,7 +351,7 @@ def merge_prompt(prompt: str, style_prompt: str):
def apply_conditioning(
cond: Conditioning, w: ComfyWorkflow, comfy: Client, model: Output, clip: Output, style: Style
):
prompt = merge_prompt(cond.prompt, style.style_prompt)
prompt = merge_prompt(_pattern_lora.sub("", cond.prompt), style.style_prompt)
if cond.area:
prompt = merge_prompt("", style.style_prompt)
positive = w.clip_text_encode(clip, prompt)
Expand Down Expand Up @@ -423,7 +462,7 @@ def generate(
batch = 1 if live.is_active else batch

w = ComfyWorkflow(comfy.nodes_inputs)
model, clip, vae = load_model_with_lora(w, comfy, style, is_live=live.is_active)
model, clip, vae = load_model_with_lora(w, comfy, style, cond.prompt, is_live=live.is_active)
latent = w.empty_latent_image(extent.initial.width, extent.initial.height, batch)
model, positive, negative = apply_conditioning(cond, w, comfy, model, clip, style)
out_latent = w.ksampler_advanced(model, positive, negative, latent, **sampler_params)
Expand All @@ -444,7 +483,7 @@ def inpaint(comfy: Client, style: Style, image: Image, mask: Mask, cond: Conditi
expanded_bounds = Bounds(*mask.bounds.offset, *region_expanded)

w = ComfyWorkflow(comfy.nodes_inputs)
model, clip, vae = load_model_with_lora(w, comfy, style)
model, clip, vae = load_model_with_lora(w, comfy, style, cond.prompt)
in_image = w.load_image(scaled_image)
in_mask = w.load_mask(scaled_mask)
cropped_mask = w.load_mask(mask.to_image())
Expand Down Expand Up @@ -522,7 +561,7 @@ def refine(
sampler_params = _sampler_params(style, live=live, strength=strength)

w = ComfyWorkflow(comfy.nodes_inputs)
model, clip, vae = load_model_with_lora(w, comfy, style, is_live=live.is_active)
model, clip, vae = load_model_with_lora(w, comfy, style, cond.prompt, is_live=live.is_active)
in_image = w.load_image(image)
if extent.is_incompatible:
in_image = w.scale_image(in_image, extent.expanded)
Expand Down Expand Up @@ -555,7 +594,7 @@ def refine_region(
sampler_params = _sampler_params(style, strength=strength, live=live)

w = ComfyWorkflow(comfy.nodes_inputs)
model, clip, vae = load_model_with_lora(w, comfy, style, is_live=live.is_active)
model, clip, vae = load_model_with_lora(w, comfy, style, cond.prompt, is_live=live.is_active)
in_image = w.load_image(image)
in_mask = w.load_mask(mask_image)
if extent.requires_downscale:
Expand Down Expand Up @@ -649,7 +688,7 @@ def upscale_tiled(

w = ComfyWorkflow(comfy.nodes_inputs)
img = w.load_image(image)
checkpoint, clip, vae = load_model_with_lora(w, comfy, style)
checkpoint, clip, vae = load_model_with_lora(w, comfy, style, cond.prompt)
upscale_model = w.load_upscale_model(model)
if sd_ver.has_controlnet_blur:
cond.control.append(Control(ControlMode.blur, img))
Expand Down
11 changes: 9 additions & 2 deletions tests/test_attention_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def test_upper_bound(self):
assert edit_attention("(bar:1.95)", positive=True) == "(bar:2.0)"

def test_lower_bound(self):
assert edit_attention("(bar:0.0)", positive=False) == "(bar:0.0)"
assert edit_attention("(bar:0.01)", positive=False) == "(bar:0.0)"
assert edit_attention("(bar:-1.95)", positive=False) == "(bar:-2.0)"
assert edit_attention("(bar:-2.0)", positive=False) == "(bar:-2.0)"

def test_single_digit(self):
assert edit_attention("(bar:0)", positive=True) == "(bar:0.1)"
Expand All @@ -41,6 +41,12 @@ def test_invalid_weight(self):
def test_no_weight(self):
assert edit_attention("(foo)", positive=True) == "((foo):1.1)"

def test_angle_bracket(self):
assert edit_attention("<bar:1.0>", positive=True) == "<bar:1.1>"
assert edit_attention("<foo:bar:1.0>", positive=True) == "<foo:bar:1.1>"
assert edit_attention("<foo:bar:1.1>", positive=False) == "<foo:bar:1.0>"
assert edit_attention("<foo:bar:0.0>", positive=False) == "<foo:bar:-0.1>"


class TestSelectOnCursorPos:
def test_word_selection(self):
Expand All @@ -52,3 +58,4 @@ def test_word_selection(self):
def test_range_selection(self):
assert select_on_cursor_pos("(foo:1.3), bar, baz", 1) == (0, 9)
assert select_on_cursor_pos("foo, (bar:1.1), baz", 6) == (5, 14)
assert select_on_cursor_pos("foo, (bar:1.1) <bar:baz:1.0>", 16) == (15, 28)
34 changes: 34 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,40 @@ def test_merge_prompt():
assert workflow.merge_prompt("", "b {prompt} c") == "b c"


def test_parse_lora(comfy):
client_loras = [
"/path/to/Lora-One.safetensors",
"Lora-two.safetensors",
]

assert workflow._parse_loras(client_loras, "a ship") == []
assert workflow._parse_loras(client_loras, "a ship <lora:lora-one>") == [
{"name": client_loras[0], "strength": 1.0}
]
assert workflow._parse_loras(client_loras, "a ship <lora:LoRA-one>") == [
{"name": client_loras[0], "strength": 1.0}
]
assert workflow._parse_loras(client_loras, "a ship <lora:lora-one:0.0>") == [
{"name": client_loras[0], "strength": 0.0}
]
assert workflow._parse_loras(client_loras, "a ship <lora:lora-two:0.5>") == [
{"name": client_loras[1], "strength": 0.5}
]
assert workflow._parse_loras(client_loras, "a ship <lora:lora-two:-1.0>") == [
{"name": client_loras[1], "strength": -1.0}
]

try:
workflow._parse_loras(client_loras, "a ship <lora:lora-three>")
except Exception as e:
assert str(e).startswith("LoRA not found")

try:
workflow._parse_loras(client_loras, "a ship <lora:lora-one:test-invalid-str>")
except Exception as e:
assert str(e).startswith("Invalid LoRA strength")


@pytest.mark.parametrize("extent", [Extent(256, 256), Extent(800, 800), Extent(512, 1024)])
def test_generate(qtapp, comfy, temp_settings, extent):
temp_settings.batch_size = 1
Expand Down