Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable distributed sample image generation on multi-GPU environment #1061

Merged
Merged
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
3064a27
Update train_util.py
DKnight54 Jan 11, 2024
91ed9ee
Update train_util.py
DKnight54 Jan 11, 2024
6f5cd82
Update train_network.py
DKnight54 Jan 11, 2024
a072fa5
Update train_util.py
DKnight54 Jan 11, 2024
b91c1bc
Update train_util.py
DKnight54 Jan 11, 2024
84205f8
Update train_util.py
DKnight54 Jan 11, 2024
2db9986
Update train_util.py
DKnight54 Jan 11, 2024
e421805
Update train_util.py
DKnight54 Jan 11, 2024
bd838c0
Update train_util.py
DKnight54 Jan 11, 2024
9171cb9
Update train_util.py
DKnight54 Jan 11, 2024
4c6c5e1
Update train_util.py
DKnight54 Jan 11, 2024
7fe56f3
Update train_util.py
DKnight54 Jan 11, 2024
428ed3a
Update train_util.py
DKnight54 Jan 11, 2024
31e8dd2
Update train_util.py
DKnight54 Jan 11, 2024
3c33d00
Update train_util.py
DKnight54 Jan 11, 2024
f624f10
Update train_util.py
DKnight54 Jan 11, 2024
ccdbc82
Update train_util.py
DKnight54 Jan 11, 2024
ee10f4d
Update train_util.py
DKnight54 Jan 12, 2024
d4417d3
Update train_util.py
DKnight54 Jan 12, 2024
f3b8d4e
Update train_util.py
DKnight54 Jan 13, 2024
dd1976f
Update train_util.py
DKnight54 Jan 13, 2024
4ad0343
Update train_network.py
DKnight54 Jan 13, 2024
76c1459
Update train_util.py
DKnight54 Jan 13, 2024
03098d8
Update train_util.py
DKnight54 Jan 13, 2024
99eb53c
Update train_util.py
DKnight54 Jan 14, 2024
2da9ccf
Update train_util.py
DKnight54 Jan 14, 2024
edf6da1
Update train_util.py
DKnight54 Jan 14, 2024
6ad8361
Update train_util.py
DKnight54 Jan 14, 2024
d91f152
Update train_util.py
DKnight54 Jan 14, 2024
5f02d7d
Update train_util.py
DKnight54 Jan 14, 2024
7a4850b
Update train_util.py
DKnight54 Jan 18, 2024
da367f6
Update train_network.py
DKnight54 Jan 18, 2024
21c3dc0
Update train_util.py
DKnight54 Jan 18, 2024
c4a150e
Update train_network.py
DKnight54 Jan 18, 2024
6978b00
Update train_network.py
DKnight54 Jan 18, 2024
3ad5d3f
Update train_network.py
DKnight54 Jan 18, 2024
ba0a15f
Cleanup of debugging outputs
DKnight54 Jan 18, 2024
e53a36e
Merge branch 'kohya-ss:main' into sample_image_dev
DKnight54 Jan 18, 2024
8f9640e
Merge branch 'kohya-ss:main' into vram_testing
DKnight54 Jan 18, 2024
f089bb8
adopt more elegant coding
DKnight54 Jan 19, 2024
bcc91c1
Update train_util.py
DKnight54 Jan 19, 2024
2554296
refactor in function generate_per_device_prompt_list() generation of …
DKnight54 Jan 19, 2024
3a41c27
Clean up missing variables
DKnight54 Jan 19, 2024
92a255c
fix syntax error
DKnight54 Jan 19, 2024
e0af6cb
Update train_util.py
DKnight54 Jan 19, 2024
c368237
Update train_util.py
DKnight54 Jan 19, 2024
86d77e6
Update train_util.py
DKnight54 Jan 19, 2024
e4ca96f
Update train_util.py
DKnight54 Jan 19, 2024
0a84a16
Update train_util.py
DKnight54 Jan 19, 2024
85f3c2a
Update train_util.py
DKnight54 Jan 19, 2024
b961e2e
Update train_util.py
DKnight54 Jan 19, 2024
9bcffdf
Update train_util.py
DKnight54 Jan 19, 2024
ef11df7
true random sample image generation
DKnight54 Jan 19, 2024
f92c1f0
true random sample image generation
DKnight54 Jan 19, 2024
51abf37
Merge branch 'kohya-ss:main' into sample_image_dev
DKnight54 Jan 23, 2024
434ca3f
simplify per process prompt
DKnight54 Jan 24, 2024
6544056
Update train_util.py
DKnight54 Jan 24, 2024
6d378ef
Update train_util.py
DKnight54 Jan 24, 2024
a156e6f
Update train_util.py
DKnight54 Jan 24, 2024
fad055b
Update train_util.py
DKnight54 Jan 24, 2024
09bb026
Merge pull request #1 from DKnight54/DKnight54-patch-1
DKnight54 Jan 24, 2024
23f49d7
Update train_util.py
DKnight54 Jan 24, 2024
e3b320b
Update train_util.py
DKnight54 Jan 24, 2024
bcc7c9e
Update train_util.py
DKnight54 Jan 25, 2024
8cef895
Merge branch 'kohya-ss:main' into vram_testing
DKnight54 Jan 25, 2024
55d9fc2
Merge branch 'vram_testing' into sample_image_dev
DKnight54 Jan 25, 2024
50dcb9b
Update train_util.py
DKnight54 Jan 25, 2024
cf11f1c
Update train_network.py
DKnight54 Jan 26, 2024
846b6bd
Update train_network.py
DKnight54 Jan 26, 2024
00c3027
Update train_network.py
DKnight54 Jan 26, 2024
2488d90
Merge branch 'kohya-ss:main' into sample_image_dev
DKnight54 Jan 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 115 additions & 93 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Tuple,
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import gc
import glob
import math
Expand Down Expand Up @@ -4636,7 +4636,6 @@

return prompt_dict


def sample_images_common(
pipe_class,
accelerator: Accelerator,
Expand All @@ -4654,6 +4653,7 @@
"""
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
"""

if steps == 0:
if not args.sample_at_first:
return
Expand All @@ -4668,13 +4668,15 @@
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return

distributed_state = PartialState() #testing implementation of multi gpu distributed inference

print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return

org_vae_device = vae.device # CPUにいるはず
vae.to(device)
vae.to(distributed_state.device)

# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
Expand All @@ -4700,12 +4702,11 @@
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)

schedulers: dict = {}
# schedulers: dict = {} cannot find where this is used
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization,
)
schedulers[args.sample_sampler] = default_scheduler

pipeline = pipe_class(
text_encoder=text_encoder,
Expand All @@ -4718,114 +4719,135 @@
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.to(device)

pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)

# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)

Check warning on line 4726 in library/train_util.py

View workflow job for this annotation

GitHub Actions / build

"processess" should be "processes".
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement)

rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
# True random sample image generation
torch.seed()
torch.cuda.seed()

with torch.no_grad():
# with accelerator.autocast():
for i, prompt_dict in enumerate(prompts):
if not accelerator.is_main_process:
continue

if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)

assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

scheduler = schedulers.get(sampler_name)
if scheduler is None:
scheduler = get_my_scheduler(
sample_sampler=sampler_name,
v_parameterization=args.v_parameterization,
)
schedulers[sampler_name] = scheduler
pipeline.scheduler = scheduler

if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)

height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
if seed is not None:
print(f"seed: {seed}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
)
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet)

image = pipeline.latents_to_image(latents)[0]

ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
)

image.save(os.path.join(save_dir, img_filename))

# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")

wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass

# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()

with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)

def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None):
    """Distribute sample prompts round-robin across processes.

    Builds a list with ``num_of_processes`` elements, where element ``k`` is
    the list of prompt dicts assigned to process ``k``. Prompt ``i`` goes to
    bucket ``i % num_of_processes``, so prompts are interleaved in their
    original order.

    Args:
        prompts: Iterable of prompts. String entries are converted with
            ``line_to_prompt_dict``; every entry must end up being a dict.
        num_of_processes: Number of buckets to create (one per process).
        prompt_replacement: Optional pair ``(old, new)``. When given, ``old``
            is replaced by ``new`` in the ``"prompt"`` and
            ``"negative_prompt"`` values of every prompt that has them.

    Returns:
        A list of ``num_of_processes`` lists of prompt dicts. Each returned
        dict has its ``"subset"`` key removed and an ``"enum"`` key holding
        the prompt's position in ``prompts`` (used later to name image files).
    """
    per_process_prompts = [[] for _ in range(num_of_processes)]
    for i, prompt in enumerate(prompts):
        if isinstance(prompt, str):
            prompt = line_to_prompt_dict(prompt)
        assert isinstance(prompt, dict)

        # Work on a shallow copy so the caller's prompt dicts are not modified.
        prompt = dict(prompt)
        prompt.pop("subset", None)  # Clean up subset key
        # Enumerator based on prompt position; used later to name image files.
        prompt["enum"] = i

        # Prompt replacement is done here to keep sample_image_inference simple.
        if prompt_replacement is not None:
            old, new = prompt_replacement[0], prompt_replacement[1]
            # Both keys are optional; a missing or None value is left as-is
            # instead of raising KeyError.
            for key in ("prompt", "negative_prompt"):
                text = prompt.get(key)
                if text is not None:
                    prompt[key] = text.replace(old, new)

        per_process_prompts[i % num_of_processes].append(prompt)
    return per_process_prompts

def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None):
    """Generate one sample image for ``prompt_dict``, save it and log it to wandb.

    Args:
        accelerator: Used for ``autocast()`` during generation and to look up
            the ``"wandb"`` tracker.
        args: Parsed arguments; ``sample_sampler``, ``v_parameterization`` and
            ``output_name`` are read.
        pipeline: Generation pipeline. Its ``scheduler`` attribute is replaced
            for this prompt; it is called to produce latents which are decoded
            with ``pipeline.latents_to_image``.
        save_dir: Directory the PNG file is written to.
        prompt_dict: Prompt settings. Optional keys: ``prompt``,
            ``negative_prompt``, ``sample_steps``, ``width``, ``height``,
            ``scale``, ``seed``, ``controlnet_image``, ``sample_sampler``.
            The ``enum`` key is required and is used in the file name.
        epoch: Epoch number for the file name, or ``None`` to use ``steps``.
        steps: Global step number, used in the file name when ``epoch`` is None.
        controlnet: Optional ControlNet object forwarded to the pipeline.

    Raises:
        KeyError: If ``prompt_dict`` has no ``"enum"`` key.
    """
    assert isinstance(prompt_dict, dict)
    negative_prompt = prompt_dict.get("negative_prompt")
    sample_steps = prompt_dict.get("sample_steps", 30)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    scale = prompt_dict.get("scale", 7.5)
    seed = prompt_dict.get("seed")
    controlnet_image = prompt_dict.get("controlnet_image")
    prompt: str = prompt_dict.get("prompt", "")
    sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    pipeline.scheduler = get_my_scheduler(
        sample_sampler=sampler_name,
        v_parameterization=args.v_parameterization,
    )

    # Round first so the ControlNet hint image is resized to exactly the
    # dimensions that are passed to the pipeline.
    height = max(64, height - height % 8)  # round to divisible by 8
    width = max(64, width - width % 8)  # round to divisible by 8

    if controlnet_image is not None:
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(controlnet_image) as hint_file:
            controlnet_image = hint_file.convert("RGB")
        controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)

    print(f"\nprompt: {prompt}")
    print(f"negative_prompt: {negative_prompt}")
    print(f"height: {height}")
    print(f"width: {width}")
    print(f"sample_steps: {sample_steps}")
    print(f"scale: {scale}")
    print(f"sample_sampler: {sampler_name}")
    if seed is not None:
        print(f"seed: {seed}")

    with accelerator.autocast():
        latents = pipeline(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=sample_steps,
            guidance_scale=scale,
            negative_prompt=negative_prompt,
            controlnet=controlnet,
            controlnet_image=controlnet_image,
        )
    image = pipeline.latents_to_image(latents)[0]

    # NOTE(review): adding accelerator.wait_for_everyone() here should sync the
    # processes so images are saved in the original prompt order - unverified.
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict["enum"]
    img_filename = (
        f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    )
    image.save(os.path.join(save_dir, img_filename))

    if seed is not None:
        # Re-randomize the generators so a fixed seed on this prompt does not
        # carry over to the prompts generated after it.
        torch.seed()
        torch.cuda.seed()

    # Send the image to wandb only when wandb is enabled. Logging is
    # best-effort: any ordinary failure is ignored, but KeyboardInterrupt and
    # SystemExit are no longer swallowed.
    try:
        wandb_tracker = accelerator.get_tracker("wandb")
        try:
            import wandb
        except ImportError:  # checked once beforehand, so this should not happen
            raise ImportError("No wandb / wandb がインストールされていないようです")

        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
    except Exception:  # wandb is disabled or unavailable
        pass
# endregion




# region 前処理用


Expand Down
Loading