
Commit

Merge pull request AUTOMATIC1111#7 from DarioFT/master
Reduce memory usage when merging and UX improvements.
uservar authored Nov 21, 2022
2 parents 056ae19 + 854c686 commit 2f70d94
Showing 2 changed files with 55 additions and 35 deletions.
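
For orientation: the change reorders checkpoint loading so that, for "Add difference", models B and C are loaded and collapsed into a single difference before A is ever read, keeping at most two full state dicts in RAM instead of three. A minimal standalone sketch of that ordering (the function and parameter names below are illustrative, not the repository's API):

import torch

def load_state_dict(path):
    # Illustrative loader; the webui wraps torch.load and unpacks the checkpoint dict.
    return torch.load(path, map_location='cpu')

def merge_add_difference(path_a, path_b, path_c, multiplier, save_as_half=False):
    # Step 1/2: load B and C, fold C into B in place, then free C before touching A.
    theta_b = load_state_dict(path_b)
    theta_c = load_state_dict(path_c)
    for key in theta_b:
        if 'model' in key:
            if key in theta_c:
                theta_b[key] = theta_b[key] - theta_c[key]
            else:
                theta_b[key] = torch.zeros_like(theta_b[key])
    del theta_c

    # Step 2/2: only now load A and apply A + M * (B - C).
    theta_a = load_state_dict(path_a)
    for key in theta_a:
        if 'model' in key and key in theta_b:
            theta_a[key] = theta_a[key] + multiplier * theta_b[key]
            if save_as_half:
                theta_a[key] = theta_a[key].half()
    return theta_a

The diff below implements the same ordering inside run_modelmerger, with a b_loaded flag so B is only loaded once.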
86 changes: 53 additions & 33 deletions modules/extras.py
@@ -249,7 +249,7 @@ def run_pnginfo(image):
     return '', geninfo, info
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name):
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
 
@@ -259,49 +259,69 @@ def get_difference(theta1, theta2):
     def add_difference(theta0, theta1_2_diff, alpha):
         return theta0 + (alpha * theta1_2_diff)
 
+    theta_funcs = {
+        "Weighted sum": (None, weighted_sum),
+        "Add difference": (get_difference, add_difference),
+    }
+
+    theta_func1, theta_func2 = theta_funcs[interp_method]
+
+    # Load info for A and B as they're always required.
     primary_model_info = sd_models.checkpoints_list[primary_model_name]
     secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
-    teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
+    b_loaded = False
 
-    print(f"Loading {primary_model_info.filename}...")
-    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
-    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
+    print(f"Interpolation method: {interp_method}")
+    print(f"Merging (Step 1/2)...")
 
-    print(f"Loading {secondary_model_info.filename}...")
-    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+    if interp_method == "Add difference":
 
-    if teritary_model_info is not None:
-        print(f"Loading {teritary_model_info.filename}...")
-        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
-        theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
-    else:
-        teritary_model = None
-        theta_2 = None
+        if tertiary_model_name != "":
 
-    theta_funcs = {
-        "Weighted sum": (None, weighted_sum),
-        "Add difference": (get_difference, add_difference),
-    }
-    theta_func1, theta_func2 = theta_funcs[interp_method]
+            # Load models B and C.
+            print(f"Loading secondary model (B): {secondary_model_info.filename}...")
+            secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
+            theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+            b_loaded = True
-
-    print(f"Merging...")
+            tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
+            if tertiary_model_info is not None:
+                print(f"Loading tertiary model (C): {tertiary_model_info.filename}...")
+                tertiary_model = torch.load(tertiary_model_info.filename, map_location='cpu')
+                theta_2 = sd_models.get_state_dict_from_checkpoint(tertiary_model)
+            else:
+                tertiary_model = None
+                theta_2 = None
-
-    if theta_func1:
-        for key in tqdm.tqdm(theta_1.keys()):
-            if 'model' in key:
-                if key in theta_2:
-                    t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
-                    theta_1[key] = theta_func1(theta_1[key], t2)
-                else:
-                    theta_1[key] = torch.zeros_like(theta_1[key])
-    del theta_2, teritary_model
+            if theta_func1:
+                for key in tqdm.tqdm(theta_1.keys()):
+                    if 'model' in key:
+                        if key in theta_2:
+                            t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
+                            theta_1[key] = theta_func1(theta_1[key], t2)
+                        else:
+                            theta_1[key] = torch.zeros_like(theta_1[key])
+            del theta_2, tertiary_model
+        else:
+            print(f"No model selected for C.")
+            return ["Select a tertiary model (C) or consider using 'Weighted sum'"] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+
+    # Load model A.
+    print(f"Loading primary model (A): {primary_model_info.filename}...")
+    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
+
+    # Load model B if we haven't loaded it yet to operate with C.
+    if b_loaded == False:
+        print(f"Loading secondary model (B): {secondary_model_info.filename}...")
+        secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
+        theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+
+    print(f"Merging (Step 2/2)...")
 
     for key in tqdm.tqdm(theta_0.keys()):
         if 'model' in key and key in theta_1:
 
             theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
 
             if save_as_half:
                 theta_0[key] = theta_0[key].half()
 
@@ -324,4 +344,4 @@ def add_difference(theta0, theta1_2_diff, alpha):
     sd_models.list_models()
 
     print(f"Checkpoint saved.")
-    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
4 changes: 2 additions & 2 deletions modules/ui.py
@@ -1197,8 +1197,8 @@ def create_ui(wrap_gradio_gpu_call):
                 gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
 
                 with gr.Row():
-                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
-                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), value=random.choice(modules.sd_models.checkpoint_tiles()), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), value=random.choice(modules.sd_models.checkpoint_tiles()), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
                     tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
                     custom_name = gr.Textbox(label="Custom Name (Optional)")
                     interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
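
On the ui.py side, the A and B dropdowns now start with a randomly preselected checkpoint instead of an empty value; this assumes checkpoint_tiles() returns a non-empty list and that random is imported in the module. A defensive variant (illustrative only, not part of this diff) would fall back to no preselection:

import random
import gradio as gr

def dropdown_with_random_default(choices, **kwargs):
    # Only preselect when there is something to pick; random.choice([]) raises IndexError.
    value = random.choice(choices) if choices else None
    return gr.Dropdown(choices, value=value, **kwargs)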
