diff --git a/modules/processing.py b/modules/processing.py
index 2009d3bf816..a5b534e235b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -3,8 +3,10 @@
 import os
 import sys
 import warnings
+import traceback
 
 import torch
+import torch.multiprocessing as multiprocessing
 import numpy as np
 from PIL import Image, ImageFilter, ImageOps
 import random
@@ -22,7 +24,12 @@
 import modules.images as images
 import modules.styles
 import modules.sd_models as sd_models
+import modules.sd_hijack as sd_hijack
 import modules.sd_vae as sd_vae
+from modules import modelloader
+import modules.codeformer_model as codeformer
+import modules.gfpgan_model as gfpgan
+
 import logging
 from ldm.data.util import AddMiDaS
 from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
@@ -34,6 +41,106 @@
 opt_C = 4
 opt_f = 8
 
+gpu_tasks = []
+gpu_count = torch.cuda.device_count()
+
+
+def init_processors():
+    # Each task dispatches the same parameters to its assigned GPU.
+    class GpuTask:
+        def __init__(self, gpu_id):
+            self.gpu_id = gpu_id
+            self.queue_to_processor = multiprocessing.Queue()
+            self.queue_from_processor = multiprocessing.Queue()
+            self.process = multiprocessing.Process(target=start_processor, args=(self.gpu_id, self.queue_to_processor, self.queue_from_processor))
+            self.process.start()
+            self.is_busy = False
+
+        def task(self, p):
+            self.is_busy = True
+            self.queue_to_processor.put(p)
+
+        def get_result(self):
+            # Blocks until the worker posts a result (or None on failure).
+            result = self.queue_from_processor.get(True)
+            self.is_busy = False
+            return result
+
+        def __del__(self):
+            if self.process and self.process.is_alive():
+                # None is the shutdown sentinel for the worker loop.
+                self.queue_to_processor.put(None)
+                self.process.join()
+
+            gpu_tasks.remove(self)
+
+    multiprocessing.set_start_method('spawn')
+
+    for i in range(gpu_count):
+        gpu_tasks.append(GpuTask(i))
+
+
+def start_processor(gpu_id, input_queue: multiprocessing.Queue, output_queue: multiprocessing.Queue, override_settings: dict = {}, override_settings_restore_afterwards: dict = {}):
+    """Run the processing loop: listen for new image requests on the input queue and put the resulting images onto the output queue."""
+    stored_opts = {k: opts.data[k] for k in override_settings.keys()}
+
+    with torch.cuda.device(gpu_id):
+        # sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+        # sd_models.update_model_for_current_device()
+
+        modelloader.cleanup_models()
+        modules.sd_models.setup_model()
+        codeformer.setup_model(cmd_opts.codeformer_models_path)
+        gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+
+        modelloader.list_builtin_upscalers()
+        modules.scripts.load_scripts()
+        modelloader.load_upscalers()
+
+        modules.sd_models.load_model()
+        sd_hijack.model_hijack.hijack(shared.sd_model)
+
+        while True:
+            p = input_queue.get(True)
+
+            print(f"Processor {gpu_id} got new task {p}\n")
+
+            if p is None:
+                break
+
+            try:
+                print("Starting processing\n")
+                result = process_images_inner(p)
+                print(f"Processor {gpu_id} finished processing {len(result.images)} images\n")
+                output_queue.put(result)
+
+            except Exception as e:
+                print(f"Exception in processor {gpu_id}: {e}\n")
+                traceback.print_tb(e.__traceback__)
+
+                # Always answer the dispatcher, even on failure, so it never blocks forever.
+                output_queue.put(None)
+
+            finally:
+                # restore opts to their original state
+                if override_settings_restore_afterwards:
+                    for k, v in stored_opts.items():
+                        setattr(opts, k, v)
+
+                        if k == 'sd_model_checkpoint':
+                            sd_models.reload_model_weights()
+
+                        if k == 'sd_vae':
+                            sd_vae.reload_vae_weights()
+
+
 def setup_color_correction(image):
     logging.info("Calibrating color correction.")
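The hunk above boils down to a queue-per-GPU worker pool. A minimal, self-contained sketch of the same pattern (illustrative names only; `do_work` stands in for `process_images_inner`, and at least one visible CUDA device is assumed):

```python
import torch
import torch.multiprocessing as mp

def do_work(job):                     # placeholder for process_images_inner
    return job * 2

def worker(gpu_id, inbox, outbox):
    with torch.cuda.device(gpu_id):   # pin this process's CUDA work to one GPU
        while True:
            job = inbox.get()         # blocks until a job (or the sentinel) arrives
            if job is None:           # None is the shutdown sentinel
                break
            outbox.put(do_work(job))

if __name__ == "__main__":
    mp.set_start_method('spawn')      # CUDA cannot be re-initialized in forked children
    inbox, outbox = mp.Queue(), mp.Queue()
    proc = mp.Process(target=worker, args=(0, inbox, outbox))
    proc.start()
    inbox.put(21)
    print(outbox.get())               # -> 42
    inbox.put(None)
    proc.join()
```

The `spawn` start method is also why each worker repeats the model-loading setup in `start_processor`: a spawned child begins with a fresh interpreter and inherits none of the parent's loaded models.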
@@ -469,35 +576,48 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
 
     return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
-    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
-
-    try:
-        for k, v in p.override_settings.items():
-            setattr(opts, k, v)
-
-            if k == 'sd_model_checkpoint':
-                sd_models.reload_model_weights()
-
-            if k == 'sd_vae':
-                sd_vae.reload_vae_weights()
-
-        res = process_images_inner(p)
-
-    finally:
-        # restore opts to original state
-        if p.override_settings_restore_afterwards:
-            for k, v in stored_opts.items():
-                setattr(opts, k, v)
-                if k == 'sd_model_checkpoint':
-                    sd_models.reload_model_weights()
-
-                if k == 'sd_vae':
-                    sd_vae.reload_vae_weights()
-
-    return res
+    """Dispatches image processing jobs to the GPUs and waits for the results."""
+
+    first_result = None
+
+    if len(gpu_tasks) == 0:
+        init_processors()
+
+    # Each iteration picks a GPU and waits for the previous task on that GPU
+    # to finish before starting the next one. This works best when the GPUs are identical.
+    for i in range(p.n_iter):
+        task = gpu_tasks[i % gpu_count]
+
+        # If the task is still busy, block until it is done.
+        current_result = None
+        if task.is_busy:
+            current_result = task.get_result()
+            if current_result is None:
+                # No result for some reason, so skip to the next task.
+                continue
+            else:
+                # Fold the images into the first result.
+                if first_result is None:
+                    first_result = current_result
+                else:
+                    first_result.images.extend(current_result.images)
+
+        # Scripts don't get pickled, so set them to None before sending p to the worker.
+        p.scripts = None
+        task.task(p)
+
+    # Now wait for all the remaining tasks to finish.
+    for i in range(gpu_count):
+        task = gpu_tasks[i]
+        if task.is_busy:
+            current_result = task.get_result()
+            if first_result is None:
+                first_result = current_result
+            else:
+                first_result.images.extend(current_result.images)
+
+    return first_result
 
 
 def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -572,6 +692,8 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
         cache[0] = (required_prompts, steps)
         return cache[1]
 
+    extra_network_data = None
+
     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -580,10 +702,13 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
             if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                 sd_vae_approx.model()
 
+        state.skipped = False
+        state.interrupted = False
+
        if state.job_count == -1:
-            state.job_count = p.n_iter
+            state.job_count = 1
 
-        for n in range(p.n_iter):
+        for n in [1]:
            p.iteration = n
 
            if state.skipped:
@@ -709,7 +834,7 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
         if opts.grid_save:
             images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
-    if not p.disable_extra_networks:
+    if not p.disable_extra_networks and extra_network_data is not None:
         extra_networks.deactivate(p, extra_network_data)
 
     devices.torch_gc()
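For readability, the dispatch loop added above can be restated compactly against the `GpuTask` interface it relies on (`is_busy`, `task()`, `get_result()`). This is a sketch, not the patch itself, and it differs in one deliberate way: unlike the patch, it still submits the next job even when the previous result on that slot was lost:

```python
def dispatch(jobs, workers):
    """Round-robin jobs over workers, folding every result's images into the first result."""
    merged = None

    def fold(result):
        nonlocal merged
        if merged is None:
            merged = result
        else:
            merged.images.extend(result.images)

    for i, job in enumerate(jobs):
        w = workers[i % len(workers)]
        if w.is_busy:              # the slot is still running: wait for it first
            r = w.get_result()
            if r is not None:
                fold(r)
        w.task(job)                # hand the freed slot its next job

    for w in workers:              # drain whatever is still in flight
        if w.is_busy:
            r = w.get_result()
            if r is not None:
                fold(r)

    return merged
```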
diff --git a/modules/scripts.py b/modules/scripts.py
index 24056a12f90..d511759490f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -367,6 +367,8 @@ def run(self, p, *args):
         if script_index == 0:
             return None
 
+
+
         script = self.selectable_scripts[script_index-1]
 
         if script is None:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 93959f55f32..657fb7369f1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -156,7 +156,7 @@ def model_hash(filename):
 
 def select_checkpoint():
     model_checkpoint = shared.opts.sd_model_checkpoint
-    
+
     checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
     if checkpoint_info is not None:
         return checkpoint_info
@@ -339,7 +339,7 @@ def load_model_wrapper(model_type):
         if not os.path.exists(path):
             if not os.path.exists(midas_path):
                 mkdir(midas_path)
-    
+
             print(f"Downloading midas model weights for {model_type} to {path}")
             request.urlretrieve(midas_urls[model_type], path)
             print(f"{model_type} downloaded")
@@ -363,6 +363,13 @@ def repair_config(sd_config):
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 
+def update_model_for_current_device():
+    device = torch.cuda.current_device()
+    if device not in shared.sd_models_by_device:
+        load_model()
+
+    shared.sd_model = shared.sd_models_by_device[device]
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -424,6 +431,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken
     sd_model.eval()
     shared.sd_model = sd_model
+    shared.sd_models_by_device[torch.cuda.current_device()] = sd_model
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0027343a745..a903c3f0497 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -4,8 +4,7 @@
 from torch import nn
 from modules import devices, paths
 
-sd_vae_approx_model = None
-
+sd_vae_approx_models = {}
 
 class VAEApprox(nn.Module):
     def __init__(self):
@@ -32,15 +31,14 @@ def forward(self, x):
 
 
 def model():
-    global sd_vae_approx_model
-
-    if sd_vae_approx_model is None:
-        sd_vae_approx_model = VAEApprox()
-        sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
-        sd_vae_approx_model.eval()
-        sd_vae_approx_model.to(devices.device, devices.dtype)
+    if torch.cuda.current_device() not in sd_vae_approx_models:
+        sd_vae_approx_models[torch.cuda.current_device()] = VAEApprox()
+        sd_vae_approx_models[torch.cuda.current_device()].load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
+        sd_vae_approx_models[torch.cuda.current_device()].eval()
+        sd_vae_approx_models[torch.cuda.current_device()].to(devices.device, devices.dtype)
 
-    return sd_vae_approx_model
+    return sd_vae_approx_models[torch.cuda.current_device()]
 
 
 def cheap_approximation(sample):
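The repeated `torch.cuda.current_device()` indexing above is an instance of a per-device cache. If the same pattern spreads to more modules, it could be factored into a helper along these lines (a hypothetical utility, not part of this patch):

```python
import torch

_per_device = {}  # device index -> loaded module

def get_for_current_device(load_fn):
    """Memoize one model instance per CUDA device.

    load_fn is any zero-argument callable that builds, loads, and returns
    the model; it runs at most once per device index.
    """
    dev = torch.cuda.current_device()
    if dev not in _per_device:
        _per_device[dev] = load_fn()
    return _per_device[dev]
```

With such a helper, `model()` would reduce to `get_for_current_device(...)` applied to a suitable loader function.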
diff --git a/modules/shared.py b/modules/shared.py
index 805f9cc19cf..fa491e1284c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -141,7 +141,8 @@
     "scripts",
 ]
 
-cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
+# cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
+cmd_opts.disable_extension_access = False
 
 devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
     (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
@@ -658,6 +659,8 @@ def cast_value(self, key, value):
 
 sd_model = None
 
+sd_models_by_device = {}
+
 clip_model = None
 
 progress_print_out = sys.stdout
diff --git a/webui.py b/webui.py
index 9e8b486aa73..4e32ff9ff1c 100644
--- a/webui.py
+++ b/webui.py
@@ -109,6 +109,11 @@ def initialize():
 
     try:
         modules.sd_models.load_model()
+
+        # with torch.cuda.device(0):
+        #     modules.sd_models.load_model()
+        # with torch.cuda.device(1):
+        #     modules.sd_models.load_model()
     except Exception as e:
         errors.display(e, "loading stable diffusion model")
         print("", file=sys.stderr)
diff --git a/webui.sh b/webui.sh
index 8cdad22d310..76a85be9071 100755
--- a/webui.sh
+++ b/webui.sh
@@ -4,6 +4,8 @@
 # change the variables in webui-user.sh instead
 #################################################
 
+export COMMANDLINE_ARGS="--xformers --no-half --precision=full --listen"
+
 # If run from macOS, load defaults from webui-macos-env.sh
 if [[ "$OSTYPE" == "darwin"* ]]; then
     if [[ -f webui-macos-env.sh ]]
@@ -113,13 +115,13 @@ case "$gpu_info" in
         printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
         printf "\n%s\n" "${delimiter}"
     ;;
-    *) 
+    *)
     ;;
 esac
 if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
 then
     export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
-fi 
+fi
 
 for preq in "${GIT}" "${python_cmd}"
 do
@@ -181,6 +183,6 @@
 then
     printf "\n%s\n" "${delimiter}"
     printf "Launching launch.py..."
-    printf "\n%s\n" "${delimiter}" 
+    printf "\n%s\n" "${delimiter}"
     exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
 fi
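The commented-out block in `webui.py` hints at eager per-device loading at startup. Generalized over all visible GPUs, it would look like the sketch below (untested; it assumes `load_model()` keys the loaded model by the active device, as the `sd_models.py` change above arranges):

```python
import torch
import modules.sd_models

# Warm every GPU at startup so workers never pay the checkpoint-load
# cost on their first request.
for gpu_id in range(torch.cuda.device_count()):
    with torch.cuda.device(gpu_id):
        modules.sd_models.load_model()
```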