Skip to content

Commit

Permalink
Vae/te preferences via cards (#1912)
Browse files Browse the repository at this point in the history
Allows setting of preferred VAE and Text encoder(s) for checkpoints when selected via Checkpoint cards. No selection saved means no change to current toprow setting. 'Built in' option, if the only choice, means clear the toprow selection (therefore use vae/te built-in to checkpoint).
Also allows setting model type for checkpoints (SD1/SD2/SDXL/Flux/Unknown) (user set only, no attempt at autodetection), enabling filtering of the cards based on UI preset.
  • Loading branch information
DenOfEquity authored Sep 25, 2024
1 parent c2d290e commit 7876862
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 18 deletions.
4 changes: 4 additions & 0 deletions javascript/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ function selectCheckpoint(name) {
desiredCheckpointName = name;
gradioApp().getElementById('change_checkpoint').click();
}
var desiredVAEName = 0;
function selectVAE(vae) {
desiredVAEName = vae;
}

function currentImg2imgSourceResolution(w, h, r) {
var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] :is(img, canvas)');
Expand Down
9 changes: 9 additions & 0 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ def list_models():

re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")

def match_checkpoint_to_name(name):
name = name.split(' [')[0]

for ckptname in checkpoints_list.values():
title = ckptname.title.split(' [')[0]
if (name in title) or (title in name):
return ckptname.short_title if shared.opts.sd_checkpoint_dropdown_use_short else ckptname.name.split(' [')[0]

return name

def get_closet_checkpoint_match(search_string):
if not search_string:
Expand Down
11 changes: 10 additions & 1 deletion modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ def read_user_metadata(self, item, use_cache=True):
desc = metadata.get("description", None)
if desc is not None:
item["description"] = desc
vae = metadata.get("vae", None)
if vae is not None:
item["vae"] = vae
version = metadata.get("sd_version_str", None)
if version is not None:
item["sd_version_str"] = version

item["user_metadata"] = metadata

Expand Down Expand Up @@ -257,7 +263,7 @@ def create_item_html(
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''

onclick = item.get("onclick", None)
if onclick is None:
if onclick is None: # this path is 'Textual Inversion' and 'Lora'
# Don't quote prompt/neg_prompt since they are stored as js strings already.
onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});"
onclick = onclick_js_tpl.format(
Expand All @@ -269,6 +275,9 @@ def create_item_html(
}
)
onclick = html.escape(onclick)
else: # this path is 'Checkpoints'
vae = item.get("vae", [])
onclick = html.escape(f"selectVAE({vae});") + onclick

btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]})
btn_metadata = ""
Expand Down
47 changes: 35 additions & 12 deletions modules/ui_extra_networks_checkpoints_user_metadata.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,65 @@
import gradio as gr

from modules import ui_extra_networks_user_metadata, sd_vae, shared
from modules.ui_common import create_refresh_button
from modules.ui_components import ToolButton
from modules_forge import main_entry

refresh_symbol = '\U0001f504' # 🔄

class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
def __init__(self, ui, tabname, page):
super().__init__(ui, tabname, page)

self.select_vae = None
self.sd_version = 'Unknown'

def save_user_metadata(self, name, desc, notes, vae):
def save_user_metadata(self, name, desc, notes, vae, sd_version):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["notes"] = notes
user_metadata["vae"] = vae
user_metadata["sd_version_str"] = 'SdVersion.' + sd_version

self.write_user_metadata(name, user_metadata)

def update_vae(self, name):
if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
sd_vae.reload_vae_weights()

def put_values_into_components(self, name):
user_metadata = self.get_user_metadata(name)
values = super().put_values_into_components(name)

vae = user_metadata.get('vae', None)

version = user_metadata.get('sd_version_str', '')
if version == '':
version = 'Unknown'
else:
version = version.replace('SdVersion.', '')

return [
*values[0:5],
user_metadata.get('vae', ''),
vae,
version,
]

def create_editor(self):
def create_editor(self): #happens before main_entry.modules_list is filled
modules_list = ['Built in']
if main_entry.module_list == {}:
_, modules = main_entry.refresh_models()
modules_list += list(modules)
else:
modules_list += list(main_entry.module_list.keys())

def refreshModules ():
return gr.update(choices=['Built in'] + list(main_entry.module_list.keys()))

self.create_default_editor_elems()

self.sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True)

with gr.Row():
self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
self.select_vae = gr.Dropdown(choices=modules_list, value=None, label="Preferred VAE / Text encoder(s)", elem_id="checpoint_edit_user_metadata_preferred_vae", multiselect=True)
self.refresh = ToolButton(refresh_symbol)

self.refresh.click(fn=refreshModules, outputs=self.select_vae, show_progress='hidden')

self.edit_notes = gr.TextArea(label='Notes', lines=4)

Expand All @@ -49,6 +72,7 @@ def create_editor(self):
self.html_preview,
self.edit_notes,
self.select_vae,
self.sd_version,
]

self.button_edit\
Expand All @@ -59,8 +83,7 @@ def create_editor(self):
self.edit_description,
self.edit_notes,
self.select_vae,
self.sd_version,
]

self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])

13 changes: 8 additions & 5 deletions modules/ui_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,18 @@ def add_functionality(self, demo):
show_progress=False,
)

def button_set_checkpoint_change(value, dummy):
return value.split(' [')[0], opts.dumpjson()
def button_set_checkpoint_change(model, vae, dummy):
if 'Built in' in vae:
vae.remove('Built in')
model = sd_models.match_checkpoint_to_name(model)
return model, vae, opts.dumpjson()

button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
fn=button_set_checkpoint_change,
js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
inputs=[main_entry.ui_checkpoint, self.dummy_component],
outputs=[main_entry.ui_checkpoint, self.text_settings],
js="function(c, v, n){ var ckpt = desiredCheckpointName; var vae = desiredVAEName; if (vae == 0) vae = v; desiredCheckpointName = null; desiredVAEName = 0; return [ckpt, vae, null]; }",
inputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.dummy_component],
outputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.text_settings],
)

component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
Expand Down

0 comments on commit 7876862

Please sign in to comment.