Commit

Enable distributed sample image generation on multi-GPU environment (#1061)

* Update train_util.py

Modifying to attempt to enable multi-GPU inference

* Update train_util.py

Additional VRAM checking; refactor check_vram_usage to return a string for use with accelerator.print

* Update train_network.py

* Update train_util.py (15 iterative commits)

* Update train_util.py

remove sample image debug outputs

* Update train_util.py

* Update train_util.py

* Update train_network.py

* Update train_util.py (9 iterative commits)

* Update train_network.py

* Update train_util.py

* Update train_network.py (3 iterative commits)

* Cleanup of debugging outputs

* adopt more elegant coding

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update train_util.py

Fix leftover debugging code; attempt to refactor inference into a separate function

* Refactor the generation of the distributed prompt list into generate_per_device_prompt_list() (see the sketch preceding the diff below)

* Clean up missing variables

* fix syntax error

* Update train_util.py (8 iterative commits)

* True-random sample image generation (2 commits)

Update code to reinitialize the random seed to true random if a fixed seed was set

* Simplify per-process prompt handling

* Update train_util.py (8 iterative commits)

* Update train_network.py (3 iterative commits)

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
DKnight54 and akx authored Feb 3, 2024
1 parent 7f948db commit 1567ce1
Showing 1 changed file with 115 additions and 93 deletions: library/train_util.py
@@ -19,7 +19,7 @@
Tuple,
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import gc
import glob
import math
@@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict:

return prompt_dict


def sample_images_common(
pipe_class,
accelerator: Accelerator,
@@ -4654,6 +4653,7 @@ def sample_images_common(
"""
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
"""

if steps == 0:
if not args.sample_at_first:
return
@@ -4668,13 +4668,15 @@ def sample_images_common(
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return

distributed_state = PartialState() # testing implementation of multi-GPU distributed inference

print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return

org_vae_device = vae.device # CPUにいるはず (the VAE should still be on the CPU here)
vae.to(device)
vae.to(distributed_state.device)

# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
@@ -4700,12 +4702,11 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)

schedulers: dict = {}
# schedulers: dict = {} cannot find where this is used
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization,
)
schedulers[args.sample_sampler] = default_scheduler

pipeline = pipe_class(
text_encoder=text_encoder,
@@ -4718,114 +4719,135 @@ def sample_images_common(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.to(device)

pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)

# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to match each image's creation time to its enum order. Probably only works when steps and sampler are identical.
per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes=distributed_state.num_processes, prompt_replacement=prompt_replacement)

rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
# True random sample image generation
torch.seed()
torch.cuda.seed()

with torch.no_grad():
# with accelerator.autocast():
for i, prompt_dict in enumerate(prompts):
if not accelerator.is_main_process:
continue

if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)

assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

scheduler = schedulers.get(sampler_name)
if scheduler is None:
scheduler = get_my_scheduler(
sample_sampler=sampler_name,
v_parameterization=args.v_parameterization,
)
schedulers[sampler_name] = scheduler
pipeline.scheduler = scheduler

if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)

height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
if seed is not None:
print(f"seed: {seed}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
)
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet)

image = pipeline.latents_to_image(latents)[0]

ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
)

image.save(os.path.join(save_dir, img_filename))

# wandb有効時のみログを送信 (send logs only when wandb is enabled)
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず (checked once beforehand, so this should not raise)
raise ImportError("No wandb / wandb がインストールされていないようです")

wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時 (wandb disabled)
pass

# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()

with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)

def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None):

# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to match each image's creation time to its enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [[] for i in range(num_of_processes)]
for i, prompt in enumerate(prompts):
if isinstance(prompt, str):
prompt = line_to_prompt_dict(prompt)
assert isinstance(prompt, dict)
prompt.pop("subset", None) # Clean up subset key
prompt["enum"] = i
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
if prompt_replacement is not None:
prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1])
if prompt["negative_prompt"] is not None:
prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1])
# Refactor prompt replacement to here in order to simplify sample_image_inference function.
per_process_prompts[i % num_of_processes].append(prompt)
return per_process_prompts

def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

scheduler = get_my_scheduler(
sample_sampler=sampler_name,
v_parameterization=args.v_parameterization,
)
pipeline.scheduler = scheduler

if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)

height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"\nprompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
if seed is not None:
print(f"seed: {seed}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
)
image = pipeline.latents_to_image(latents)[0]
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
)

image.save(os.path.join(save_dir, img_filename))
if seed is not None:
torch.seed()
torch.cuda.seed()
# wandb有効時のみログを送信 (send logs only when wandb is enabled)
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず (checked once beforehand, so this should not raise)
raise ImportError("No wandb / wandb がインストールされていないようです")

wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時 (wandb disabled)
pass
# endregion




# region 前処理用 (preprocessing)


