
Hack for multi GPU inference
TikiTDO committed Feb 28, 2023
1 parent 0cc0ee1 commit 2619a99
Showing 7 changed files with 184 additions and 41 deletions.
177 changes: 151 additions & 26 deletions modules/processing.py
@@ -3,8 +3,10 @@
import os
import sys
import warnings
import traceback

import torch
import torch.multiprocessing as multiprocessing
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
@@ -22,7 +24,12 @@
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_hijack as sd_hijack
import modules.sd_vae as sd_vae
from modules import modelloader
import modules.codeformer_model as codeformer
import modules.gfpgan_model as gfpgan

import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
@@ -34,6 +41,106 @@
opt_C = 4
opt_f = 8

gpu_tasks = []
gpu_count = torch.cuda.device_count()


def init_processors():
    # Each task will just dispatch the same parameters to its assigned GPU
    class GpuTask:
        def __init__(self, gpu_id):
            self.gpu_id = gpu_id
            self.queue_to_processor = multiprocessing.Queue()
            self.queue_from_processor = multiprocessing.Queue()
            self.process = multiprocessing.Process(target=start_processor, args=(self.gpu_id, self.queue_to_processor, self.queue_from_processor))
            self.process.start()
            self.is_busy = False

        def task(self, p):
            self.is_busy = True
            self.queue_to_processor.put(p)

        def get_result(self):
            result = self.queue_from_processor.get(True)
            self.is_busy = False
            return result

        def __del__(self):
            if self.process and self.process.is_alive():
                self.queue_to_processor.put(None)
                self.process.join()

            gpu_tasks.remove(self)

    multiprocessing.set_start_method('spawn')

    for i in range(gpu_count):
        gpu_tasks.append(GpuTask(i))


def start_processor(gpu_id, input_queue: multiprocessing.Queue, output_queue: multiprocessing.Queue, override_settings: dict = {}, override_settings_restore_afterwards: dict = {}):
    """Start the processing loop, which listens for new image requests on the input queue and puts the resulting images onto the output queue."""
    stored_opts = {k: opts.data[k] for k in override_settings.keys()}

    with torch.cuda.device(gpu_id):
        # sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        # sd_models.update_model_for_current_device()

        modelloader.cleanup_models()
        modules.sd_models.setup_model()
        codeformer.setup_model(cmd_opts.codeformer_models_path)
        gfpgan.setup_model(cmd_opts.gfpgan_models_path)

        modelloader.list_builtin_upscalers()
        modules.scripts.load_scripts()
        modelloader.load_upscalers()

        modules.sd_models.load_model()
        sd_hijack.model_hijack.hijack(shared.sd_model)

        while True:
            p = input_queue.get(True)

            print(f"Processor {gpu_id} got new task {p}\n")

            if p is None:
                break
            else:
                try:
                    if p is None:
                        break
                    else:
                        print("Starting processing\n")
                        result = process_images_inner(p)
                        print(f"Processor {gpu_id} finished processing {len(result.images)} images\n")
                        output_queue.put(result)

                except Exception as e:
                    print(f"Exception in processor {gpu_id}: {e}\n")
                    # Print backtrace
                    traceback.print_tb(e.__traceback__)
                    print(f"Sigh\n")

                    output_queue.put(None)

                finally:
                    print(f"Finally\n")

                    # restore opts to original state
                    if override_settings_restore_afterwards:
                        for k, v in stored_opts.items():
                            setattr(opts, k, v)

                            if k == 'sd_model_checkpoint':
                                sd_models.reload_model_weights()

                            if k == 'sd_vae':
                                sd_vae.reload_vae_weights()


def setup_color_correction(image):
    logging.info("Calibrating color correction.")
@@ -469,35 +576,48 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()


def process_images(p: StableDiffusionProcessing) -> Processed:
    """Dispatches the processing of the images to the GPUs and waits for the results"""

    first_result = None

    if len(gpu_tasks) == 0:
        init_processors()

    # Each iteration will pick a GPU, and wait for the last task on that GPU to finish before starting the next one
    # This works best when the GPUs are identical
    for i in range(p.n_iter):
        task = gpu_tasks[i % gpu_count]

        # If the task is still busy, block until it's done
        current_result = None
        if task.is_busy:
            current_result = task.get_result()
            if current_result is None:
                # No result for some reason, so skip to the next task
                continue
            else:
                # Load the images into the first result
                if first_result is None:
                    first_result = current_result
                else:
                    first_result.images.extend(current_result.images)

        # Scripts don't get pickled, so we need to set them to None before sending the parameters to the GPU
        p.scripts = None
        task.task(p)

    # Now wait for all the remaining tasks to finish
    for i in range(gpu_count):
        task = gpu_tasks[i]
        if task.is_busy:
            current_result = task.get_result()
            if first_result is None:
                first_result = current_result
            else:
                first_result.images.extend(current_result.images)

    return first_result


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -572,6 +692,8 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
        cache[0] = (required_prompts, steps)
        return cache[1]

    extra_network_data = None

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

@@ -580,10 +702,13 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

        state.skipped = False
        state.interrupted = False

        if state.job_count == -1:
            state.job_count = p.n_iter
            state.job_count = 1

        for n in range(p.n_iter):
        for n in [1]:
            p.iteration = n

            if state.skipped:
@@ -709,7 +834,7 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks:
    if not p.disable_extra_networks and extra_network_data is not None:
        extra_networks.deactivate(p, extra_network_data)

    devices.torch_gc()
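As an aside, the dispatch pattern above can be exercised on its own. Below is a minimal, standalone sketch (not part of the commit) of the same idea: one spawned worker process per CUDA device, fed through a pair of queues, with jobs handed out round-robin. The worker function and the dummy job payload are hypothetical stand-ins for start_processor / process_images_inner, and the sketch assumes at least one CUDA device is present.

import torch
import torch.multiprocessing as multiprocessing


def worker(gpu_id, inbox, outbox):
    # One worker per GPU: block on the inbox, run the job on this device,
    # push the result, and exit when the None sentinel arrives.
    with torch.cuda.device(gpu_id):
        while True:
            job = inbox.get(True)
            if job is None:
                break
            # Stand-in for process_images_inner(p): burn a little GPU time.
            result = (torch.ones(job, device=f"cuda:{gpu_id}") * gpu_id).sum().item()
            outbox.put(result)


if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')  # CUDA requires spawn for child processes
    gpu_count = torch.cuda.device_count()  # assumed to be >= 1 here
    workers = []
    for gpu_id in range(gpu_count):
        inbox, outbox = multiprocessing.Queue(), multiprocessing.Queue()
        proc = multiprocessing.Process(target=worker, args=(gpu_id, inbox, outbox))
        proc.start()
        workers.append({"proc": proc, "inbox": inbox, "outbox": outbox, "busy": False})

    results = []
    for i in range(6):  # round-robin the jobs, as process_images() above does with p.n_iter
        w = workers[i % gpu_count]
        if w["busy"]:
            results.append(w["outbox"].get(True))  # wait for the previous job on this GPU
        w["inbox"].put(1024)  # the job payload; here just a tensor size
        w["busy"] = True

    for w in workers:  # drain the last job on each GPU, then shut the workers down
        if w["busy"]:
            results.append(w["outbox"].get(True))
        w["inbox"].put(None)
        w["proc"].join()

    print(results)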
2 changes: 2 additions & 0 deletions modules/scripts.py
@@ -367,6 +367,8 @@ def run(self, p, *args):
        if script_index == 0:
            return None



        script = self.selectable_scripts[script_index-1]

        if script is None:
12 changes: 10 additions & 2 deletions modules/sd_models.py
@@ -156,7 +156,7 @@ def model_hash(filename):

def select_checkpoint():
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info
@@ -339,7 +339,7 @@ def load_model_wrapper(model_type):
    if not os.path.exists(path):
        if not os.path.exists(midas_path):
            mkdir(midas_path)

        print(f"Downloading midas model weights for {model_type} to {path}")
        request.urlretrieve(midas_urls[model_type], path)
        print(f"{model_type} downloaded")
@@ -363,6 +363,13 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

def update_model_for_current_device():
    device = torch.cuda.current_device()
    if device not in shared.sd_models_by_device:
        load_model()

    shared.sd_model = shared.sd_models_by_device[device]

def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -424,6 +431,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_

    sd_model.eval()
    shared.sd_model = sd_model
    shared.sd_models_by_device[torch.cuda.current_device()] = sd_model

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

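For reference, the per-device caching added here (and mirrored by sd_vae_approx_models and shared.sd_models_by_device below) boils down to a dict keyed by torch.cuda.current_device(). A minimal sketch of that pattern, with load_fn as a hypothetical stand-in for the real load_model() call:

import torch

models_by_device = {}  # analogous to shared.sd_models_by_device


def model_for_current_device(load_fn):
    # Look up (or lazily create) the model owned by whichever GPU is active
    # in the surrounding `with torch.cuda.device(n):` scope.
    device = torch.cuda.current_device()
    if device not in models_by_device:
        models_by_device[device] = load_fn()
    return models_by_device[device]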
16 changes: 7 additions & 9 deletions modules/sd_vae_approx.py
@@ -4,8 +4,7 @@
from torch import nn
from modules import devices, paths

sd_vae_approx_model = None

sd_vae_approx_models = {}

class VAEApprox(nn.Module):
    def __init__(self):
@@ -32,15 +31,14 @@ def forward(self, x):


def model():
    global sd_vae_approx_model

    if sd_vae_approx_model is None:
        sd_vae_approx_model = VAEApprox()
        sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
        sd_vae_approx_model.eval()
        sd_vae_approx_model.to(devices.device, devices.dtype)
    if torch.cuda.current_device() not in sd_vae_approx_models:
        sd_vae_approx_models[torch.cuda.current_device()] = VAEApprox()
        sd_vae_approx_models[torch.cuda.current_device()].load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
        sd_vae_approx_models[torch.cuda.current_device()].eval()
        sd_vae_approx_models[torch.cuda.current_device()].to(devices.device, devices.dtype)

    return sd_vae_approx_model
    return sd_vae_approx_models[torch.cuda.current_device()]


def cheap_approximation(sample):
5 changes: 4 additions & 1 deletion modules/shared.py
@@ -141,7 +141,8 @@
"scripts",
]

cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
# cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
cmd_opts.disable_extension_access = False

devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
@@ -658,6 +659,8 @@ def cast_value(self, key, value):

sd_model = None

sd_models_by_device = {}

clip_model = None

progress_print_out = sys.stdout
5 changes: 5 additions & 0 deletions webui.py
@@ -109,6 +109,11 @@ def initialize():

    try:
        modules.sd_models.load_model()

        # with torch.cuda.device(0):
        # modules.sd_models.load_model()
        # with torch.cuda.device(1):
        # modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
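The commented-out block above hints at an alternative: eagerly loading one copy of the model per GPU at startup instead of lazily inside each worker. A hedged sketch of that variant, generalized over all visible devices (it assumes the surrounding initialize() environment, where modules.sd_models is importable; not part of the commit):

import torch

import modules.sd_models

# Hypothetical eager warm-up over all visible GPUs.
for gpu_id in range(torch.cuda.device_count()):
    with torch.cuda.device(gpu_id):
        # load_model() also records the copy in shared.sd_models_by_device
        # (see the sd_models.py change above), so each device gets its own model.
        modules.sd_models.load_model()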
8 changes: 5 additions & 3 deletions webui.sh
@@ -4,6 +4,8 @@
# change the variables in webui-user.sh instead #
#################################################

export COMMANDLINE_ARGS="--xformers --no-half --precision=full --listen"

peteygao commented on Dec 15, 2023:

@TikiTDO Is there a reason you need to use full precision rather than half-width FP16 when inferring in parallel? I don't think there's any torch limitation.

Or are you simply taking advantage of the increased VRAM capacity to load the model in full precision for slightly higher quality?

TikiTDO (author) commented on Dec 15, 2023:

Honestly, this is so long ago that I can't say. I think in that case I was having some trouble with training giving me black images.

Keep in mind this isn't any sort of prod-ready code; it's just something I hacked up for my own use long ago in the ancient dinosaur days of February 2023. Now that it's Dec 2023 and a few centuries have passed, it's kinda hard to recall things that far back.

peteygao commented on Dec 15, 2023:

Yeah, that makes sense. Ironic that I came across this fork because the official stable diffusion webui still doesn't support multi-GPU parallel inference. Btw, what GPU are you running? Older (pre-Pascal) cards don't support half-width FP16, so that could explain it. Some cards support it but only at 1/64th the speed (god knows why; did NVIDIA gimp it on purpose so people have to buy their Quadro cards?), so maybe the preview just loads reeeaaaally sloooooowly 😂.

TikiTDO (author) commented on Dec 15, 2023:

I have some 3090s I scored used from a company that was closing down. When I still ran this code it really struggled with less than 24GB of VRAM. A lot of older cards legit don't have fp16 at the hardware level, so if you're seeing it work on older cards maybe there's a software layer involved?

That said, I actually don't use this code anymore. Originally, when I was just starting out with AI art, I needed a lot more tries to get the results I wanted. These days I have improved my skills and my workflows to the point where I honestly wouldn't benefit much from the speed boost I would get from parallel generation. That's probably why this fork has seen so little activity.

It's just that once you get good enough at using all the tools at your disposal in the system, the amount of time you spend generating images as opposed to thinking about what to do next starts to weigh much heavier towards the latter.


# If run from macOS, load defaults from webui-macos-env.sh
if [[ "$OSTYPE" == "darwin"* ]]; then
if [[ -f webui-macos-env.sh ]]
@@ -113,13 +115,13 @@ case "$gpu_info" in
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
printf "\n%s\n" "${delimiter}"
;;
*)
*)
;;
esac
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
fi
fi

for preq in "${GIT}" "${python_cmd}"
do
@@ -181,6 +183,6 @@
else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
printf "\n%s\n" "${delimiter}"
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi
