From 854c68638095dbbe072b82ae714f41528ed629ef Mon Sep 17 00:00:00 2001 From: darioft Date: Wed, 2 Nov 2022 03:19:09 -0300 Subject: [PATCH] Reduce memory usage when merging and UX improvements. I did a few tweaks to reduce memory usage by not loading all models at the same time when merging using "Add difference", by loading and unloading the models when appropriate. --- modules/extras.py | 86 +++++++++++++++++++++++++++++------------------ modules/ui.py | 4 +-- 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 8e2ab35c2cc..9d582bf8b55 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -248,7 +248,7 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): +def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name): def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -258,49 +258,69 @@ def get_difference(theta1, theta2): def add_difference(theta0, theta1_2_diff, alpha): return theta0 + (alpha * theta1_2_diff) + theta_funcs = { + "Weighted sum": (None, weighted_sum), + "Add difference": (get_difference, add_difference), + } + + theta_func1, theta_func2 = theta_funcs[interp_method] + + # Load info for A and B as they're always required. 
primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) + b_loaded = False - print(f"Loading {primary_model_info.filename}...") - primary_model = torch.load(primary_model_info.filename, map_location='cpu') - theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) + print(f"Interpolation method: {interp_method}") + print(f"Merging (Step 1/2)...") - print(f"Loading {secondary_model_info.filename}...") - secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') - theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) + if interp_method == "Add difference": - if teritary_model_info is not None: - print(f"Loading {teritary_model_info.filename}...") - teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') - theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) - else: - teritary_model = None - theta_2 = None + if tertiary_model_name != "": - theta_funcs = { - "Weighted sum": (None, weighted_sum), - "Add difference": (get_difference, add_difference), - } - theta_func1, theta_func2 = theta_funcs[interp_method] + # Load models B and C. 
+ print(f"Loading secondary model (B): {secondary_model_info.filename}...") + secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') + theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) + b_loaded = True - print(f"Merging...") + tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None) + if tertiary_model_info is not None: + print(f"Loading tertiary model (C): {tertiary_model_info.filename}...") + tertiary_model = torch.load(tertiary_model_info.filename, map_location='cpu') + theta_2 = sd_models.get_state_dict_from_checkpoint(tertiary_model) + else: + tertiary_model = None + theta_2 = None + + if theta_func1: + for key in tqdm.tqdm(theta_1.keys()): + if 'model' in key: + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = torch.zeros_like(theta_1[key]) + del theta_2, tertiary_model + else: + print(f"No model selected for C.") + return ["Select a tertiary model (C) or consider using 'Weighted sum'"] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] - if theta_func1: - for key in tqdm.tqdm(theta_1.keys()): - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - del theta_2, teritary_model + # Load model A. + print(f"Loading primary model (A): {primary_model_info.filename}...") + primary_model = torch.load(primary_model_info.filename, map_location='cpu') + theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) + + # Load model B if we haven't loaded it yet to operate with C. 
+ if b_loaded == False: + print(f"Loading secondary model (B): {secondary_model_info.filename}...") + secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') + theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) + + print(f"Merging (Step 2/2)...") for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier) - if save_as_half: theta_0[key] = theta_0[key].half() @@ -323,4 +343,4 @@ def add_difference(theta0, theta1_2_diff, alpha): sd_models.list_models() print(f"Checkpoint saved.") - return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] \ No newline at end of file diff --git a/modules/ui.py b/modules/ui.py index a94f46ea7ec..045d77b8bc3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1152,8 +1152,8 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), value=random.choice(modules.sd_models.checkpoint_tiles()), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), value=random.choice(modules.sd_models.checkpoint_tiles()), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") custom_name = gr.Textbox(label="Custom Name (Optional)") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)