Commit

Improve performance and make it suitable for NPU
蒋硕 committed Oct 10, 2024
1 parent 01337da commit 07d6649
Showing 2 changed files with 17 additions and 9 deletions.
21 changes: 14 additions & 7 deletions examples/text_to_image/train_text_to_image_sdxl.py
@@ -59,7 +59,9 @@
 
 logger = get_logger(__name__)
 if is_torch_npu_available():
     import torch_npu
+    torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
 
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -531,7 +533,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
     return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
 
 
-def compute_vae_encodings(batch, vae):
+def compute_vae_encodings(batch, accelerator, vae):
     images = batch.pop("pixel_values")
     pixel_values = torch.stack(list(images))
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -540,7 +542,7 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
-    return {"model_input": model_input.cpu()}
+    return {"model_input": accelerator.gather(model_input)}
 
 
 def generate_timestep_weights(args, num_timesteps):
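
The return value switches from a CPU copy to a cross-process gather. A minimal sketch, not from the repository and with made-up shapes, of what Accelerator.gather does with the VAE latents:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
# Each process encodes its own shard of images into latents on its device...
local_latents = torch.randn(2, 4, 64, 64, device=accelerator.device)
# ...and gather concatenates them along dim 0, so every rank ends up with the
# latents from all processes (shape: 2 * num_processes, 4, 64, 64).
all_latents = accelerator.gather(local_latents)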
@@ -910,7 +912,7 @@ def preprocess_train(examples):
         proportion_empty_prompts=args.proportion_empty_prompts,
         caption_column=args.caption_column,
     )
-    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae)
     with accelerator.main_process_first():
         from datasets.fingerprint import Hasher
 
@@ -935,7 +937,10 @@ def preprocess_train(examples):
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
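
The NPU/CUDA branch above reappears later in the script; a sketch of factoring it into a single helper (the helper name is ours, not the commit's):

import gc

import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu


def free_memory():
    # Drop Python references first, then ask the active backend to release
    # its cached device memory back to the allocator.
    gc.collect()
    if is_torch_npu_available():
        torch_npu.npu.empty_cache()
    else:
        torch.cuda.empty_cache()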
@@ -1091,8 +1096,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                     target_size = (args.resolution, args.resolution)
                     add_time_ids = list(original_size + crops_coords_top_left + target_size)
-                    add_time_ids = torch.tensor([add_time_ids])
-                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
                     return add_time_ids
 
                 add_time_ids = torch.cat(
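
For illustration only (placeholder sizes, not repository code): building the tensor directly on the target device with the target dtype skips the intermediate CPU tensor and the extra .to() copy that the two deleted lines performed.

import torch

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = torch.tensor(
    [list(original_size + crops_coords_top_left + target_size)],
    device="cuda" if torch.cuda.is_available() else "cpu",
    dtype=torch.float16,
)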
@@ -1261,7 +1265,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
                         )
 
                 del pipeline
-                torch.cuda.empty_cache()
+                if is_torch_npu_available():
+                    torch_npu.npu.empty_cache()
+                else:
+                    torch.cuda.empty_cache()
 
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
5 changes: 3 additions & 2 deletions src/diffusers/models/attention_processor.py
@@ -2274,8 +2274,7 @@ def __call__(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -4277,6 +4276,7 @@ def __init__(self):
 CROSS_ATTENTION_PROCESSORS = (
     AttnProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     IPAdapterAttnProcessor,
@@ -4286,6 +4286,7 @@ def __init__(self):
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     FusedAttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
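
With AttnProcessorNPU listed in both the processor tuple and the AttentionProcessor union, it can be assigned like any other processor. A usage sketch (model id chosen for illustration; AttnProcessorNPU needs an Ascend torch_npu build at runtime):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessorNPU

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
# Route attention through the Ascend fused-attention path.
unet.set_attn_processor(AttnProcessorNPU())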
