From 07d6649c9d8eb999278e81a87c692f8ac51c9e3c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=8B=E7=A1=95?=
Date: Thu, 10 Oct 2024 10:57:17 +0800
Subject: [PATCH] Improve performance and add NPU support

---
 .../text_to_image/train_text_to_image_sdxl.py | 21 ++++++++++++++-------
 src/diffusers/models/attention_processor.py   |  5 +++--
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 2ca511c857ae6..05f46721e3358 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -59,7 +59,9 @@
 logger = get_logger(__name__)
 
 if is_torch_npu_available():
+    import torch_npu
     torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
 
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -531,7 +533,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
     return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
 
 
-def compute_vae_encodings(batch, vae):
+def compute_vae_encodings(batch, accelerator, vae):
     images = batch.pop("pixel_values")
     pixel_values = torch.stack(list(images))
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -540,7 +542,7 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
-    return {"model_input": model_input.cpu()}
+    return {"model_input": accelerator.gather(model_input)}
 
 
 def generate_timestep_weights(args, num_timesteps):
@@ -910,7 +912,7 @@ def preprocess_train(examples):
         proportion_empty_prompts=args.proportion_empty_prompts,
         caption_column=args.caption_column,
     )
-    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae)
     with accelerator.main_process_first():
         from datasets.fingerprint import Hasher
 
@@ -935,7 +937,10 @@ def preprocess_train(examples):
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
@@ -1091,8 +1096,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                     target_size = (args.resolution, args.resolution)
                     add_time_ids = list(original_size + crops_coords_top_left + target_size)
-                    add_time_ids = torch.tensor([add_time_ids])
-                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
                     return add_time_ids
 
                 add_time_ids = torch.cat(
@@ -1261,7 +1265,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     )
 
                 del pipeline
-                torch.cuda.empty_cache()
+                if is_torch_npu_available():
+                    torch_npu.npu.empty_cache()
+                else:
+                    torch.cuda.empty_cache()
 
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 9f9bc5a46e10d..b49512d30eb9a 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2274,8 +2274,7 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -4277,6 +4276,7 @@ def __init__(self):
 CROSS_ATTENTION_PROCESSORS = (
     AttnProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     IPAdapterAttnProcessor,
@@ -4286,6 +4286,7 @@ def __init__(self):
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     FusedAttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,