From 1867dcd3262875d48f9494e4ee897f5201ee5365 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Tue, 9 Jul 2024 07:52:42 +0000 Subject: [PATCH 01/10] Distrifusion Support source --- colossalai/inference/config.py | 17 + colossalai/inference/core/diffusion_engine.py | 2 +- .../modeling/{models => layers}/diffusion.py | 0 .../inference/modeling/layers/distrifusion.py | 591 ++++++++++++++++++ .../inference/modeling/models/pixart_alpha.py | 2 +- .../modeling/models/stablediffusion3.py | 2 +- .../inference/modeling/policy/pixart_alpha.py | 45 +- .../modeling/policy/stablediffusion3.py | 50 +- examples/inference/stable_diffusion/README.md | 6 + .../stable_diffusion/benchmark_sd3.py | 1 + .../stable_diffusion/compute_metric.py | 79 +++ .../stable_diffusion/sd3_generation.py | 14 +- 12 files changed, 799 insertions(+), 10 deletions(-) rename colossalai/inference/modeling/{models => layers}/diffusion.py (100%) create mode 100644 colossalai/inference/modeling/layers/distrifusion.py create mode 100644 examples/inference/stable_diffusion/README.md create mode 100644 examples/inference/stable_diffusion/benchmark_sd3.py create mode 100644 examples/inference/stable_diffusion/compute_metric.py diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 1beb86874826..0a4722e755c3 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -245,6 +245,13 @@ class InferenceConfig(RPC_PARAM): start_token_size: int = 4 generated_token_size: int = 512 + # Acceleration for Diffusion Model(PipeFusion or Distrifusion) + # use_patched_parallelism : bool = False + patched_parallelism_size: int = 1 # for distrifusion + # use_pipefusion : bool = False + pipeFusion_m_size: int = 1 # for pipefusion + pipeFusion_n_size: int = 1 # for pipefusion + def __post_init__(self): self.max_context_len_to_capture = self.max_input_len + self.max_output_len self._verify_config() @@ -288,6 +295,14 @@ def _verify_config(self) -> None: # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit. 
self.start_token_size = self.block_size + # check Distrifusion + # (TODO@lry897575) need more detailed check + if self.patched_parallelism_size > 1: + # self.use_patched_parallelism = True + self.tp_size = ( + self.patched_parallelism_size + ) # this is not a real tp, because some annoying check, so we have to set this to patched_parallelism_size + # check prompt template if self.prompt_template is None: return @@ -324,6 +339,7 @@ def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig": use_cuda_kernel=self.use_cuda_kernel, use_spec_dec=self.use_spec_dec, use_flash_attn=use_flash_attn, + patched_parallelism_size=self.patched_parallelism_size, ) return model_inference_config @@ -396,6 +412,7 @@ class ModelShardInferenceConfig: use_cuda_kernel: bool = False use_spec_dec: bool = False use_flash_attn: bool = False + patched_parallelism_size: int = 1 # for diffusion model, Distrifusion Technique @dataclass diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py index 75b9889bf28d..8bed508cba55 100644 --- a/colossalai/inference/core/diffusion_engine.py +++ b/colossalai/inference/core/diffusion_engine.py @@ -11,7 +11,7 @@ from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig -from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.struct import DiffusionSequence from colossalai.inference.utils import get_model_size, get_model_type diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/layers/diffusion.py similarity index 100% rename from colossalai/inference/modeling/models/diffusion.py rename to colossalai/inference/modeling/layers/diffusion.py diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py new file mode 100644 index 000000000000..8055dc252d82 --- /dev/null +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -0,0 +1,591 @@ +# Code refer and adapted from: +# https://github.com/PipeFusion/PipeFusion +# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from diffusers.models import attention_processor +from diffusers.models.attention import Attention +from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed +from diffusers.models.transformers.transformer_2d import Transformer2DModel +from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel +from torch import nn +from torch.distributed import ProcessGroup + +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.logging import get_dist_logger +from colossalai.shardformer.layer.parallel_module import ParallelModule +from colossalai.utils import get_current_device + +try: + from flash_attn import flash_attn_func + + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + + +logger = get_dist_logger(__name__) + + +# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py +def 
PixArtAlphaTransformer2DModel_forward( + self: Transformer2DModel, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, +): + assert hasattr( + self, "patched_parallel_size" + ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`" + + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + assert self.is_input_patches == True + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs + ) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. 
Output + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + class_labels=class_labels, + embedded_timestep=embedded_timestep, + height=height // self.patched_parallel_size, + width=width, + ) + + # enable Distrifusion Optimization + if hasattr(self, "patched_parallel_size") is None: + from torch import distributed as dist + + if getattr(self, "output_buffer", None): + self.output_buffer = torch.empty_like(output) + if getattr(self, "buffer_list", None) is None: + self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] + output = output.contiguous() + dist.all_gather(self.buffer_list, output, async_op=False) + torch.cat(self.buffer_list, dim=2, out=self.output_buffer) + output = self.output_buffer + + return (output,) + + +# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py +def SD3Transformer2DModel_forward( + self: SD3Transformer2DModel, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +) -> Union[torch.FloatTensor]: + + assert hasattr( + self, "patched_parallel_size" + ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`" + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size // self.patched_parallel_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + # enable Distrifusion Optimization + if hasattr(self, "patched_parallel_size"): + from torch import distributed as dist + + if getattr(self, "output_buffer", None) is None: + self.output_buffer = torch.empty_like(output) + if getattr(self, "buffer_list", None) is None: + self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] + output = output.contiguous() + dist.all_gather(self.buffer_list, output, async_op=False) + torch.cat(self.buffer_list, dim=2, out=self.output_buffer) + output = self.output_buffer + + return (output,) + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py +class DistrifusionPatchEmbed(ParallelModule): + def __init__( + self, + module: PatchEmbed, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.module = module + self.rank = dist.get_rank(group=process_group) + self.patched_parallelism_size = 
model_shard_infer_config.patched_parallelism_size + + @staticmethod + def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs): + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + distrifusion_embed = DistrifusionPatchEmbed( + module, process_group, model_shard_infer_config=model_shard_infer_config + ) + return distrifusion_embed + + def forward(self, latent): + module = self.module + if module.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size + + latent = module.proj(latent) + if module.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if module.layer_norm: + latent = module.norm(latent) + if module.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if module.pos_embed_max_size: + pos_embed = module.cropped_pos_embed(height, width) + else: + if module.height != height or module.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=module.pos_embed.shape[-1], + grid_size=(height, width), + base_size=module.base_size, + interpolation_scale=module.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = module.pos_embed + + b, c, h = pos_embed.shape + pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank] + + return (latent + pos_embed).to(latent.dtype) + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion +class DistrifusionConv2D(ParallelModule): + + def __init__( + self, + module: nn.Conv2d, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.module = module + self.rank = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + + @staticmethod + def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs): + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config) + return distrifusion_conv + + def sliced_forward(self, x: torch.Tensor) -> torch.Tensor: + + b, c, h, w = x.shape + + stride = self.module.stride[0] + padding = self.module.padding[0] + + output_h = x.shape[2] // stride // self.patched_parallelism_size + idx = dist.get_rank() + h_begin = output_h * idx * stride - padding + h_end = output_h * (idx + 1) * stride + padding + final_padding = [padding, padding, 0, 0] + if h_begin < 0: + h_begin = 0 + final_padding[2] = padding + if h_end > h: + h_end = h + final_padding[3] = padding + sliced_input = x[:, :, h_begin:h_end, :] + padded_input = F.pad(sliced_input, final_padding, mode="constant") + return F.conv2d( + padded_input, + self.module.weight, + self.module.bias, + stride=stride, + padding="valid", + ) + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + output = self.sliced_forward(input) + return output + + +# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py +class Distrifusion_FusedAttention(ParallelModule): + + def __init__( + self, + module: attention_processor.Attention, + process_group: Union[ProcessGroup, 
List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.counter = 0 + self.module = module + self.buffer_list = None + self.kv_buffer_idx = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + self.handle = None + self.process_group = process_group + self.warm_step = 5 # for warmup + + @staticmethod + def from_native_module( + module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + return Distrifusion_FusedAttention( + module=module, process_group=process_group, model_shard_infer_config=model_shard_infer_config + ) + + def _forward( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2) + + if self.patched_parallelism_size == 1: + full_kv = kv + else: + if self.buffer_list is None: # buffer not created + full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1) + elif self.counter <= self.warm_step: + # logger.info(f"warmup: {self.counter}") + dist.all_gather( + self.buffer_list, + kv, + group=self.process_group, + async_op=False, + ) + full_kv = torch.cat(self.buffer_list, dim=1) + else: + # logger.info(f"use old kv to infer: {self.counter}") + new_buffer_list = [buffer for buffer in self.buffer_list] + new_buffer_list[self.kv_buffer_idx] = kv + full_kv = torch.cat(new_buffer_list, dim=1) + assert self.handle is None, "we should maintain the kv of last step" + self.handle = dist.all_gather(new_buffer_list, kv, group=self.process_group, async_op=True) + + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + + # `context` projections. 
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) # NOTE(@lry89757) for torch >= 2.2, flash attn has been already integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/ + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + + if self.handle is not None: + self.handle.wait() + self.handle = None + + b, l, c = hidden_states.shape + if self.patched_parallelism_size > 1 and self.buffer_list is None: + + kv_shape = (b, l, self.module.to_k.out_features * 2) + self.buffer_list = [ + torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) + for _ in range(self.patched_parallelism_size) + ] + + attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored." 
+ ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + output = self._forward( + self.module, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + self.counter += 1 + + return output + + +class DistriSelfAttention(ParallelModule): + def __init__( + self, + module: Attention, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.counter = 0 + self.module = module + self.buffer_list = None + self.kv_buffer_idx = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + self.handle = None + self.process_group = process_group + self.warm_step = 5 # for warmup + + @staticmethod + def from_native_module( + module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + return DistriSelfAttention( + module=module, process_group=process_group, model_shard_infer_config=model_shard_infer_config + ) + + def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): + attn = self.module + assert isinstance(attn, Attention) + + residual = hidden_states + + batch_size, sequence_length, _ = hidden_states.shape + + query = attn.to_q(hidden_states) + + encoder_hidden_states = hidden_states + k = self.module.to_k(encoder_hidden_states) + v = self.module.to_v(encoder_hidden_states) + kv = torch.cat([k, v], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2) + + if self.patched_parallelism_size == 1: + full_kv = kv + else: + if self.buffer_list is None: # buffer not created + full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1) + elif self.counter <= self.warm_step: + # logger.info(f"warmup: {self.counter}") + dist.all_gather( + self.buffer_list, + kv, + group=self.process_group, + async_op=False, + ) + full_kv = torch.cat(self.buffer_list, dim=1) + else: + # logger.info(f"use old kv to infer: {self.counter}") + new_buffer_list = [buffer for buffer in self.buffer_list] + new_buffer_list[self.kv_buffer_idx] = kv + full_kv = torch.cat(new_buffer_list, dim=1) + assert self.handle is None, "we should maintain the kv of last step" + self.handle = dist.all_gather(new_buffer_list, kv, group=self.process_group, async_op=True) + + if HAS_FLASH_ATTN: + # flash attn + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype) + else: + # naive attn + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + 
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + *args, + **kwargs, + ) -> torch.FloatTensor: + + # async preallocates memo buffer + if self.handle is not None: + self.handle.wait() + self.handle = None + + b, l, c = hidden_states.shape + if self.patched_parallelism_size > 1 and self.buffer_list is None: + + kv_shape = (b, l, self.module.to_k.out_features * 2) + self.buffer_list = [ + torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) + for _ in range(self.patched_parallelism_size) + ] + + # logger.info(f"{self.counter}th step") + + output = self._forward(hidden_states, scale=scale) + + self.counter += 1 + return output diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py index d5774946e365..cc2bee5efd4d 100644 --- a/colossalai/inference/modeling/models/pixart_alpha.py +++ b/colossalai/inference/modeling/models/pixart_alpha.py @@ -14,7 +14,7 @@ from colossalai.logging import get_dist_logger -from .diffusion import DiffusionPipe +from ..layers.diffusion import DiffusionPipe logger = get_dist_logger(__name__) diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py index d1c63a6dc665..b123164039c8 100644 --- a/colossalai/inference/modeling/models/stablediffusion3.py +++ b/colossalai/inference/modeling/models/stablediffusion3.py @@ -4,7 +4,7 @@ import torch from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps -from .diffusion import DiffusionPipe +from ..layers.diffusion import DiffusionPipe # TODO(@lry89757) temporarily image, please support more return output diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py index 356056ba73e7..a2b3d651249a 100644 --- a/colossalai/inference/modeling/policy/pixart_alpha.py +++ b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -1,9 +1,16 @@ +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.transformers.transformer_2d import Transformer2DModel from torch import nn from colossalai.inference.config import RPC_PARAM -from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.distrifusion import ( + DistrifusionConv2D, + DistriSelfAttention, + PixArtAlphaTransformer2DModel_forward, +) from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward -from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class PixArtAlphaInferPolicy(Policy, RPC_PARAM): @@ -12,9 +19,43 @@ def __init__(self) -> None: def module_policy(self): policy = {} + + if 
self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: + + # policy[DiffusionPipe] = ModulePolicyDescription( + # attribute_replacement={"patched_parallel_size": self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size} + # ) + + policy[Transformer2DModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="pos_embed.conv", + target_module=DistrifusionConv2D, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + ], + attribute_replacement={ + "patched_parallel_size": self.shard_config.extra_kwargs[ + "model_shard_infer_config" + ].patched_parallelism_size + }, + method_replacement={"forward": PixArtAlphaTransformer2DModel_forward}, + ) + + policy[BasicTransformerBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn1", + target_module=DistriSelfAttention, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ) + ] + ) + self.append_or_create_method_replacement( description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe ) + return policy def preprocess(self) -> nn.Module: diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py index c9877f7dcae6..9a641853b09b 100644 --- a/colossalai/inference/modeling/policy/stablediffusion3.py +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -1,9 +1,17 @@ +from diffusers.models.attention import JointTransformerBlock +from diffusers.models.transformers import SD3Transformer2DModel from torch import nn from colossalai.inference.config import RPC_PARAM -from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.distrifusion import ( + Distrifusion_FusedAttention, + DistrifusionConv2D, + DistrifusionPatchEmbed, + SD3Transformer2DModel_forward, +) from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward -from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class StableDiffusion3InferPolicy(Policy, RPC_PARAM): @@ -12,6 +20,44 @@ def __init__(self) -> None: def module_policy(self): policy = {} + + if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: + + # policy[DiffusionPipe] = ModulePolicyDescription( + # attribute_replacement={"patched_parallel_size": self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size} + # ) + + policy[SD3Transformer2DModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="pos_embed.proj", + target_module=DistrifusionConv2D, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + SubModuleReplacementDescription( + suffix="pos_embed", + target_module=DistrifusionPatchEmbed, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + ], + attribute_replacement={ + "patched_parallel_size": self.shard_config.extra_kwargs[ + "model_shard_infer_config" + ].patched_parallelism_size + }, + method_replacement={"forward": SD3Transformer2DModel_forward}, + ) + + 
policy[JointTransformerBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn", + target_module=Distrifusion_FusedAttention, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ) + ] + ) + self.append_or_create_method_replacement( description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe ) diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md new file mode 100644 index 000000000000..b69fd1cb1c66 --- /dev/null +++ b/examples/inference/stable_diffusion/README.md @@ -0,0 +1,6 @@ +``` +|- sd3_generation.py: an example of how to use Colossalai Inference Engine to generate result by loading Diffusion Model. +|- compute_metric.py: compare the quality of images w/o some acceleration method like Distrifusion +|- benchmark.sh: benchmark the performance of our InferenceEngine +``` +note: compute_metric.py need some dependencies which need `pip install`. diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py new file mode 100644 index 000000000000..464090415c47 --- /dev/null +++ b/examples/inference/stable_diffusion/benchmark_sd3.py @@ -0,0 +1 @@ +# TODO diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py new file mode 100644 index 000000000000..b1328d14b0a9 --- /dev/null +++ b/examples/inference/stable_diffusion/compute_metric.py @@ -0,0 +1,79 @@ +import argparse +import os + +import numpy as np +import torch +from cleanfid import fid +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio +from torchvision.transforms import Resize +from tqdm import tqdm + + +def read_image(path: str): + """ + input: path + output: tensor (C, H, W) + """ + img = np.asarray(Image.open(path)) + if len(img.shape) == 2: + img = np.repeat(img[:, :, None], 3, axis=2) + img = torch.from_numpy(img).permute(2, 0, 1) + return img + + +class MultiImageDataset(Dataset): + def __init__(self, root0, root1, is_gt=False): + super().__init__() + self.root0 = root0 + self.root1 = root1 + file_names0 = os.listdir(root0) + file_names1 = os.listdir(root1) + + self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")]) + self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")]) + self.is_gt = is_gt + assert len(self.image_names0) == len(self.image_names1) + + def __len__(self): + return len(self.image_names0) + + def __getitem__(self, idx): + img0 = read_image(os.path.join(self.root0, self.image_names0[idx])) + if self.is_gt: + # resize to 1024 x 1024 + img0 = Resize((1024, 1024))(img0) + img1 = read_image(os.path.join(self.root1, self.image_names1[idx])) + + batch_list = [img0, img1] + return batch_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--is_gt", action="store_true") + parser.add_argument("--input_root0", type=str, required=True) + parser.add_argument("--input_root1", type=str, required=True) + args = parser.parse_args() + + psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda") + lpips = 
LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda") + + dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt) + dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) + + progress_bar = tqdm(dataloader) + with torch.inference_mode(): + for i, batch in enumerate(progress_bar): + batch = [img.to("cuda") / 255 for img in batch] + batch_size = batch[0].shape[0] + psnr.update(batch[0], batch[1]) + lpips.update(batch[0], batch[1]) + fid_score = fid.compute_fid(args.input_root0, args.input_root1) + + print("PSNR:", psnr.compute().item()) + print("LPIPS:", lpips.compute().item()) + print("FID:", fid_score) diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py index fe989eed7c2d..daeabeefeaf7 100644 --- a/examples/inference/stable_diffusion/sd3_generation.py +++ b/examples/inference/stable_diffusion/sd3_generation.py @@ -1,7 +1,9 @@ import argparse from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline -from torch import bfloat16, float16, float32 +from torch import bfloat16 +from torch import distributed as dist +from torch import float16, float32 import colossalai from colossalai.cluster import DistCoordinator @@ -43,6 +45,7 @@ def infer(args): max_batch_size=args.max_batch_size, tp_size=args.tp_size, use_cuda_kernel=args.use_cuda_kernel, + patched_parallelism_size=dist.get_world_size(), ) engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True) @@ -51,12 +54,17 @@ def infer(args): # ============================== coordinator.print_on_master(f"Generating...") out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0] - out.save("cat.jpg") + if dist.get_rank() == 0: + out.save(f"cat_parallel_size{dist.get_world_size()}.jpg") coordinator.print_on_master(out) # colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH -# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 + +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 if __name__ == "__main__": From 899617ab4bbf466c4078075682c529403e1e926a Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Fri, 12 Jul 2024 07:50:02 +0000 Subject: [PATCH 02/10] comp comm overlap optimization --- .../inference/modeling/layers/distrifusion.py | 27 ++++++++++++------- .../inference/modeling/policy/pixart_alpha.py | 6 ++++- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py index 8055dc252d82..03015ab4110c 100644 --- 
a/colossalai/inference/modeling/layers/distrifusion.py +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -1,6 +1,6 @@ # Code refer and adapted from: -# https://github.com/PipeFusion/PipeFusion # https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers +# https://github.com/PipeFusion/PipeFusion import inspect from typing import Any, Dict, List, Optional, Tuple, Union @@ -234,7 +234,7 @@ def forward(self, latent): return (latent + pos_embed).to(latent.dtype) -# Code adapted from: https://github.com/PipeFusion/PipeFusion +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py class DistrifusionConv2D(ParallelModule): def __init__( @@ -455,12 +455,14 @@ def forward( return output +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py class DistriSelfAttention(ParallelModule): def __init__( self, module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], model_shard_infer_config: ModelShardInferenceConfig = None, + async_nccl_stream: torch.cuda.Stream = None, ): super().__init__() self.counter = 0 @@ -470,6 +472,7 @@ def __init__( self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size self.handle = None self.process_group = process_group + self.async_nccl_stream = async_nccl_stream self.warm_step = 5 # for warmup @staticmethod @@ -477,8 +480,12 @@ def from_native_module( module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs ) -> ParallelModule: model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + async_nccl_stream = kwargs.get("async_nccl_stream", None) return DistriSelfAttention( - module=module, process_group=process_group, model_shard_infer_config=model_shard_infer_config + module=module, + process_group=process_group, + model_shard_infer_config=model_shard_infer_config, + async_nccl_stream=async_nccl_stream, ) def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): @@ -512,11 +519,15 @@ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): full_kv = torch.cat(self.buffer_list, dim=1) else: # logger.info(f"use old kv to infer: {self.counter}") - new_buffer_list = [buffer for buffer in self.buffer_list] - new_buffer_list[self.kv_buffer_idx] = kv - full_kv = torch.cat(new_buffer_list, dim=1) + self.buffer_list[self.kv_buffer_idx].copy_(kv) + full_kv = torch.cat(self.buffer_list, dim=1) assert self.handle is None, "we should maintain the kv of last step" - self.handle = dist.all_gather(new_buffer_list, kv, group=self.process_group, async_op=True) + with torch.cuda.stream( + self.async_nccl_stream + ): # NOTE(@LRY89757) implementation async op of torch.distributed.all_gather to ensure efficient overlap of comp and comm + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=False) + self.handle = torch.cuda.Event() + self.handle.record(self.async_nccl_stream) if HAS_FLASH_ATTN: # flash attn @@ -583,8 +594,6 @@ def forward( for _ in range(self.patched_parallelism_size) ] - # logger.info(f"{self.counter}th step") - output = self._forward(hidden_states, scale=scale) self.counter += 1 diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py index a2b3d651249a..c0bbb7a1863b 100644 --- a/colossalai/inference/modeling/policy/pixart_alpha.py +++ b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -1,3 +1,4 @@ 
+import torch from diffusers.models.attention import BasicTransformerBlock from diffusers.models.transformers.transformer_2d import Transformer2DModel from torch import nn @@ -47,7 +48,10 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attn1", target_module=DistriSelfAttention, - kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + "async_nccl_stream": torch.cuda.Stream(), + }, ) ] ) From 1049ad10090a8cc58926fd7a8ed7409eee7beaa4 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Fri, 12 Jul 2024 07:51:26 +0000 Subject: [PATCH 03/10] sd3 benchmark --- .../stable_diffusion/benchmark_sd3.py | 178 +++++++++++++++++- 1 file changed, 177 insertions(+), 1 deletion(-) diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py index 464090415c47..8cf7615c76d6 100644 --- a/examples/inference/stable_diffusion/benchmark_sd3.py +++ b/examples/inference/stable_diffusion/benchmark_sd3.py @@ -1 +1,177 @@ -# TODO +import argparse +import time +from contextlib import nullcontext +from typing import List, Union + +import torch +import torch.distributed as dist +from diffusers import DiffusionPipeline + +import colossalai +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +GIGABYTE = 1024**3 +MEGABYTE = 1024 * 1024 + +_DTYPE_MAPPING = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, +} + + +def benchmark_colossalai(rank, world_size, port, args): + if isinstance(args.width, int): + width_list = [args.width] + else: + width_list = args.width + + if isinstance(args.height, int): + height_list = [args.height] + else: + height_list = args.height + + assert len(width_list) == len(height_list) + + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + from colossalai.cluster.dist_coordinator import DistCoordinator + + coordinator = DistCoordinator() + + inference_config = InferenceConfig( + dtype=args.dtype, + patched_parallelism_size=args.patched_parallel_size, + ) + engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False) + + # warmup + for i in range(args.n_warm_up_steps): + engine.generate( + prompts=["hello world"], + generation_config=DiffusionGenerationConfig( + num_inference_steps=args.num_inference_steps, height=1024, width=1024 + ), + ) + + ctx = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) + if args.profile + else nullcontext() + ) + + for h, w in zip(height_list, width_list): + with ctx as prof: + start = time.perf_counter() + for i in range(args.n_repeat_times): + engine.generate( + prompts=["hello world"], + generation_config=DiffusionGenerationConfig( + num_inference_steps=args.num_inference_steps, height=h, width=w + ), + ) + end = time.perf_counter() + coordinator.print_on_master( + f"[ColossalAI]avg generation time for h({h})xw({w}) is {(end - start) / args.n_repeat_times:.2f}s" + ) + if args.profile: + file = 
f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_warmup{args.n_warm_up_steps}_repeat{args.n_repeat_times}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}_rank_{dist.get_rank()}.json" + prof.export_chrome_trace(file) + + +def benchmark_diffusers(args): + if isinstance(args.width, int): + width_list = [args.width] + else: + width_list = args.width + + if isinstance(args.height, int): + height_list = [args.height] + else: + height_list = args.height + + assert len(width_list) == len(height_list) + + model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda") + + # warmup + for i in range(args.n_warm_up_steps): + model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=1024, width=1024) + + ctx = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) + if args.profile + else nullcontext() + ) + + for h, w in zip(height_list, width_list): + with ctx as prof: + start = time.perf_counter() + for i in range(args.n_repeat_times): + model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w) + end = time.perf_counter() + print(f"[ColossalAI]avg generation time for h({h})xw({w}) is {(end - start) / args.n_repeat_times:.2f}s") + if args.profile: + file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_warmup{args.n_warm_up_steps}_repeat{args.n_repeat_times}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}.json" + prof.export_chrome_trace(file) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def benchmark(args): + if args.mode == "colossalai": + spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args) + elif args.mode == "diffusers": + benchmark_diffusers(args) + + +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers + +# enable profiler +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size") + parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="num of inference steps") + parser.add_argument("-H", "--height", type=Union[int, List[int]], default=[1024, 2048, 3840], help="Height List") + parser.add_argument("-w", "--width", type=Union[int, List[int]], 
default=[1024, 2048, 3840], help="Width List") + parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"]) + parser.add_argument("--n_warm_up_steps", type=int, default=3, help="warm up times") + parser.add_argument("--n_repeat_times", type=int, default=5, help="repeat times") + parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler") + parser.add_argument( + "-m", + "--model", + default="stabilityai/stable-diffusion-3-medium-diffusers", + help="the type of model", + # choices=["stabilityai/stable-diffusion-3-medium-diffusers", "PixArt-alpha/PixArt-XL-2-1024-MS"], + ) + parser.add_argument( + "--mode", + default="colossalai", + # choices=["colossalai", "diffusers"], + help="decide which inference framework to run", + ) + args = parser.parse_args() + benchmark(args) From 86dcc7158cd03c04a937e81cdd0da4293cc740b1 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Tue, 16 Jul 2024 07:58:20 +0000 Subject: [PATCH 04/10] pixart distrifusion bug fix --- .../inference/modeling/layers/distrifusion.py | 82 ++++++++++++++----- .../inference/modeling/policy/pixart_alpha.py | 12 ++- .../stable_diffusion/benchmark_sd3.py | 29 +++++-- 3 files changed, 93 insertions(+), 30 deletions(-) diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py index 03015ab4110c..f93f8bd545c0 100644 --- a/colossalai/inference/modeling/layers/distrifusion.py +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -11,7 +11,7 @@ from diffusers.models import attention_processor from diffusers.models.attention import Attention from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed -from diffusers.models.transformers.transformer_2d import Transformer2DModel +from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel from torch import nn from torch.distributed import ProcessGroup @@ -34,7 +34,7 @@ # adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py def PixArtAlphaTransformer2DModel_forward( - self: Transformer2DModel, + self: PixArtTransformer2DModel, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, @@ -76,11 +76,20 @@ def PixArtAlphaTransformer2DModel_forward( encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 1. Input - assert self.is_input_patches == True - height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size - hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, added_cond_kwargs + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, ) + hidden_states = self.pos_embed(hidden_states) + + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) # 2. 
Blocks for block in self.transformer_blocks: @@ -95,20 +104,41 @@ def PixArtAlphaTransformer2DModel_forward( ) # 3. Output - output = self._get_output_for_patched_inputs( - hidden_states=hidden_states, - timestep=timestep, - class_labels=class_labels, - embedded_timestep=embedded_timestep, - height=height // self.patched_parallel_size, - width=width, + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk( + 2, dim=1 + ) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=( + -1, + height // self.patched_parallel_size, + width, + self.config.patch_size, + self.config.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height // self.patched_parallel_size * self.config.patch_size, + width * self.config.patch_size, + ) ) # enable Distrifusion Optimization - if hasattr(self, "patched_parallel_size") is None: + if hasattr(self, "patched_parallel_size"): from torch import distributed as dist - if getattr(self, "output_buffer", None): + if getattr(self, "output_buffer", None) is None: self.output_buffer = torch.empty_like(output) if getattr(self, "buffer_list", None) is None: self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] @@ -473,7 +503,9 @@ def __init__( self.handle = None self.process_group = process_group self.async_nccl_stream = async_nccl_stream - self.warm_step = 5 # for warmup + self.launch_event = torch.cuda.Event() + self.sync_event = torch.cuda.Event() + self.warm_step = 3 # for warmup @staticmethod def from_native_module( @@ -481,9 +513,10 @@ def from_native_module( ) -> ParallelModule: model_shard_infer_config = kwargs.get("model_shard_infer_config", None) async_nccl_stream = kwargs.get("async_nccl_stream", None) + proc_group = kwargs.get("async_nccl_group", process_group) return DistriSelfAttention( module=module, - process_group=process_group, + process_group=proc_group, model_shard_infer_config=model_shard_infer_config, async_nccl_stream=async_nccl_stream, ) @@ -525,9 +558,15 @@ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): with torch.cuda.stream( self.async_nccl_stream ): # NOTE(@LRY89757) implementation async op of torch.distributed.all_gather to ensure efficient overlap of comp and comm - self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=False) - self.handle = torch.cuda.Event() - self.handle.record(self.async_nccl_stream) + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) + self.launch_event.record(self.async_nccl_stream) + if self.handle is not None: + self.handle.wait() + self.handle = None + self.sync_event.record(self.async_nccl_stream) + + if self.handle is not None: + self.launch_event.wait() if HAS_FLASH_ATTN: # flash attn @@ -582,7 +621,8 @@ def forward( # async preallocates memo buffer if self.handle is not None: - self.handle.wait() + self.sync_event.wait() + # torch.cuda.default_stream().wait_stream(self.async_nccl_stream) self.handle = None b, l, c = hidden_states.shape diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py 
b/colossalai/inference/modeling/policy/pixart_alpha.py index c0bbb7a1863b..01922bbbc18f 100644 --- a/colossalai/inference/modeling/policy/pixart_alpha.py +++ b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -1,12 +1,13 @@ import torch from diffusers.models.attention import BasicTransformerBlock -from diffusers.models.transformers.transformer_2d import Transformer2DModel +from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel from torch import nn from colossalai.inference.config import RPC_PARAM from colossalai.inference.modeling.layers.diffusion import DiffusionPipe from colossalai.inference.modeling.layers.distrifusion import ( DistrifusionConv2D, + DistrifusionPatchEmbed, DistriSelfAttention, PixArtAlphaTransformer2DModel_forward, ) @@ -27,13 +28,18 @@ def module_policy(self): # attribute_replacement={"patched_parallel_size": self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size} # ) - policy[Transformer2DModel] = ModulePolicyDescription( + policy[PixArtTransformer2DModel] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="pos_embed.conv", + suffix="pos_embed.proj", target_module=DistrifusionConv2D, kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, ), + SubModuleReplacementDescription( + suffix="pos_embed", + target_module=DistrifusionPatchEmbed, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), ], attribute_replacement={ "patched_parallel_size": self.shard_config.extra_kwargs[ diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py index 8cf7615c76d6..f9e3228484d2 100644 --- a/examples/inference/stable_diffusion/benchmark_sd3.py +++ b/examples/inference/stable_diffusion/benchmark_sd3.py @@ -22,6 +22,11 @@ } +def log_generation_time(log_message, log_file): + with open(log_file, "a") as f: + f.write(log_message) + + def benchmark_colossalai(rank, world_size, port, args): if isinstance(args.width, int): width_list = [args.width] @@ -80,9 +85,13 @@ def benchmark_colossalai(rank, world_size, port, args): ), ) end = time.perf_counter() - coordinator.print_on_master( - f"[ColossalAI]avg generation time for h({h})xw({w}) is {(end - start) / args.n_repeat_times:.2f}s" - ) + log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {(end - start) / args.n_repeat_times:.2f}s" + coordinator.print_on_master(log_msg) + if args.log: + log_file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}.log" + if dist.get_rank() == 0: + log_generation_time(log_message=log_msg, log_file=log_file) + if args.profile: file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_warmup{args.n_warm_up_steps}_repeat{args.n_repeat_times}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}_rank_{dist.get_rank()}.json" prof.export_chrome_trace(file) @@ -127,7 +136,12 @@ def benchmark_diffusers(args): for i in range(args.n_repeat_times): model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w) end = time.perf_counter() - print(f"[ColossalAI]avg generation time for h({h})xw({w}) is 
{(end - start) / args.n_repeat_times:.2f}s") + log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {(end - start) / args.n_repeat_times:.2f}s" + print(log_msg) + if args.log: + log_file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}.log" + log_generation_time(log_message=log_msg, log_file=log_file) + if args.profile: file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_warmup{args.n_warm_up_steps}_repeat{args.n_repeat_times}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}.json" prof.export_chrome_trace(file) @@ -142,8 +156,10 @@ def benchmark(args): benchmark_diffusers(args) -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 4 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 4 --mode colossalai --log +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 8 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 8 --mode colossalai --log +# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log # enable profiler # CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 @@ -160,6 +176,7 @@ def benchmark(args): parser.add_argument("--n_warm_up_steps", type=int, default=3, help="warm up times") parser.add_argument("--n_repeat_times", type=int, default=5, help="repeat times") parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler") + parser.add_argument("--log", default=False, action="store_true", help="enable torch profiler") parser.add_argument( "-m", "--model", From 7868216c0716748de67216652c79e5a72bec642f Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Wed, 17 Jul 2024 07:24:08 +0000 Subject: [PATCH 05/10] sd3 bug fix and benchmark --- colossalai/inference/config.py | 2 +- .../inference/modeling/layers/distrifusion.py | 36 ++-- .../inference/modeling/policy/pixart_alpha.py | 4 - .../modeling/policy/stablediffusion3.py | 10 +- examples/inference/stable_diffusion/README.md | 2 +- .../stable_diffusion/benchmark_sd3.py | 190 ++++++++---------- .../stable_diffusion/run_benchmark.sh | 43 ++++ 7 files changed, 149 insertions(+), 138 deletions(-) create mode 100644 examples/inference/stable_diffusion/run_benchmark.sh diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 0a4722e755c3..2864d958eaee 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py 
@@ -296,7 +296,7 @@ def _verify_config(self) -> None: self.start_token_size = self.block_size # check Distrifusion - # (TODO@lry897575) need more detailed check + # TODO(@lry89757) need more detailed check if self.patched_parallelism_size > 1: # self.use_patched_parallelism = True self.tp_size = ( diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py index f93f8bd545c0..7a73d84390e8 100644 --- a/colossalai/inference/modeling/layers/distrifusion.py +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -325,6 +325,7 @@ def __init__( module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], model_shard_infer_config: ModelShardInferenceConfig = None, + async_nccl_stream: torch.cuda.Stream = None, ): super().__init__() self.counter = 0 @@ -334,6 +335,7 @@ def __init__( self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size self.handle = None self.process_group = process_group + self.async_nccl_stream = async_nccl_stream self.warm_step = 5 # for warmup @staticmethod @@ -341,8 +343,12 @@ def from_native_module( module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs ) -> ParallelModule: model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + async_nccl_stream = kwargs.get("async_nccl_stream", None) return Distrifusion_FusedAttention( - module=module, process_group=process_group, model_shard_infer_config=model_shard_infer_config + module=module, + process_group=process_group, + model_shard_infer_config=model_shard_infer_config, + async_nccl_stream=async_nccl_stream, ) def _forward( @@ -390,11 +396,11 @@ def _forward( full_kv = torch.cat(self.buffer_list, dim=1) else: # logger.info(f"use old kv to infer: {self.counter}") - new_buffer_list = [buffer for buffer in self.buffer_list] - new_buffer_list[self.kv_buffer_idx] = kv - full_kv = torch.cat(new_buffer_list, dim=1) + self.buffer_list[self.kv_buffer_idx].copy_(kv) + full_kv = torch.cat(self.buffer_list, dim=1) assert self.handle is None, "we should maintain the kv of last step" - self.handle = dist.all_gather(new_buffer_list, kv, group=self.process_group, async_op=True) + with torch.cuda.stream(self.async_nccl_stream): # NOTE(@lry89757) make nccl kernels' list more neat + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) @@ -503,8 +509,6 @@ def __init__( self.handle = None self.process_group = process_group self.async_nccl_stream = async_nccl_stream - self.launch_event = torch.cuda.Event() - self.sync_event = torch.cuda.Event() self.warm_step = 3 # for warmup @staticmethod @@ -513,10 +517,9 @@ def from_native_module( ) -> ParallelModule: model_shard_infer_config = kwargs.get("model_shard_infer_config", None) async_nccl_stream = kwargs.get("async_nccl_stream", None) - proc_group = kwargs.get("async_nccl_group", process_group) return DistriSelfAttention( module=module, - process_group=proc_group, + process_group=process_group, model_shard_infer_config=model_shard_infer_config, async_nccl_stream=async_nccl_stream, ) @@ -555,18 +558,8 @@ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): self.buffer_list[self.kv_buffer_idx].copy_(kv) full_kv = torch.cat(self.buffer_list, dim=1) assert self.handle is None, "we should maintain the kv of last step" - with torch.cuda.stream( - self.async_nccl_stream - ): # 
NOTE(@LRY89757) implementation async op of torch.distributed.all_gather to ensure efficient overlap of comp and comm + with torch.cuda.stream(self.async_nccl_stream): # NOTE(@lry89757) make nccl kernels' list more neat self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) - self.launch_event.record(self.async_nccl_stream) - if self.handle is not None: - self.handle.wait() - self.handle = None - self.sync_event.record(self.async_nccl_stream) - - if self.handle is not None: - self.launch_event.wait() if HAS_FLASH_ATTN: # flash attn @@ -621,8 +614,7 @@ def forward( # async preallocates memo buffer if self.handle is not None: - self.sync_event.wait() - # torch.cuda.default_stream().wait_stream(self.async_nccl_stream) + self.handle.wait() self.handle = None b, l, c = hidden_states.shape diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py index 01922bbbc18f..0ff2f08401d9 100644 --- a/colossalai/inference/modeling/policy/pixart_alpha.py +++ b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -24,10 +24,6 @@ def module_policy(self): if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: - # policy[DiffusionPipe] = ModulePolicyDescription( - # attribute_replacement={"patched_parallel_size": self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size} - # ) - policy[PixArtTransformer2DModel] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py index 9a641853b09b..1a9e3c41eda3 100644 --- a/colossalai/inference/modeling/policy/stablediffusion3.py +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -1,3 +1,4 @@ +import torch from diffusers.models.attention import JointTransformerBlock from diffusers.models.transformers import SD3Transformer2DModel from torch import nn @@ -23,10 +24,6 @@ def module_policy(self): if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: - # policy[DiffusionPipe] = ModulePolicyDescription( - # attribute_replacement={"patched_parallel_size": self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size} - # ) - policy[SD3Transformer2DModel] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( @@ -53,7 +50,10 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attn", target_module=Distrifusion_FusedAttention, - kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + "async_nccl_stream": torch.cuda.Stream(), + }, ) ] ) diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md index b69fd1cb1c66..0f1afc805bec 100644 --- a/examples/inference/stable_diffusion/README.md +++ b/examples/inference/stable_diffusion/README.md @@ -1,6 +1,6 @@ ``` |- sd3_generation.py: an example of how to use Colossalai Inference Engine to generate result by loading Diffusion Model. 
|- compute_metric.py: compare the quality of images w/o some acceleration method like Distrifusion -|- benchmark.sh: benchmark the performance of our InferenceEngine +|- run_benchmark.sh: benchmark the performance of our InferenceEngine ``` note: compute_metric.py need some dependencies which need `pip install`. diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py index f9e3228484d2..ce8cd0fb2e93 100644 --- a/examples/inference/stable_diffusion/benchmark_sd3.py +++ b/examples/inference/stable_diffusion/benchmark_sd3.py @@ -1,7 +1,7 @@ import argparse +import json import time from contextlib import nullcontext -from typing import List, Union import torch import torch.distributed as dist @@ -22,37 +22,14 @@ } -def log_generation_time(log_message, log_file): +def log_generation_time(log_data, log_file): with open(log_file, "a") as f: - f.write(log_message) + json.dump(log_data, f, indent=2) + f.write("\n") -def benchmark_colossalai(rank, world_size, port, args): - if isinstance(args.width, int): - width_list = [args.width] - else: - width_list = args.width - - if isinstance(args.height, int): - height_list = [args.height] - else: - height_list = args.height - - assert len(width_list) == len(height_list) - - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - from colossalai.cluster.dist_coordinator import DistCoordinator - - coordinator = DistCoordinator() - - inference_config = InferenceConfig( - dtype=args.dtype, - patched_parallelism_size=args.patched_parallel_size, - ) - engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False) - - # warmup - for i in range(args.n_warm_up_steps): +def warmup(engine, args): + for _ in range(args.n_warm_up_steps): engine.generate( prompts=["hello world"], generation_config=DiffusionGenerationConfig( @@ -60,7 +37,9 @@ def benchmark_colossalai(rank, world_size, port, args): ), ) - ctx = ( + +def profile_context(args): + return ( torch.profiler.profile( record_shapes=True, with_stack=True, @@ -74,10 +53,51 @@ def benchmark_colossalai(rank, world_size, port, args): else nullcontext() ) - for h, w in zip(height_list, width_list): - with ctx as prof: + +def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None): + log_data = { + "mode": mode, + "model": model_name, + "batch_size": args.batch_size, + "patched_parallel_size": args.patched_parallel_size, + "num_inference_steps": args.num_inference_steps, + "height": h, + "width": w, + "dtype": args.dtype, + "profile": args.profile, + "n_warm_up_steps": args.n_warm_up_steps, + "n_repeat_times": args.n_repeat_times, + "avg_generation_time": avg_time, + "log_message": log_msg, + } + + if args.log: + log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json" + log_generation_time(log_data=log_data, log_file=log_file) + + if args.profile: + file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json" + prof.export_chrome_trace(file) + + +def benchmark_colossalai(rank, world_size, port, args): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + from colossalai.cluster.dist_coordinator import DistCoordinator + + coordinator = DistCoordinator() + + inference_config = InferenceConfig( + dtype=args.dtype, + patched_parallelism_size=args.patched_parallel_size, + ) + engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False) + + 
warmup(engine, args) + + for h, w in zip(args.height, args.width): + with profile_context(args) as prof: start = time.perf_counter() - for i in range(args.n_repeat_times): + for _ in range(args.n_repeat_times): engine.generate( prompts=["hello world"], generation_config=DiffusionGenerationConfig( @@ -85,66 +105,33 @@ def benchmark_colossalai(rank, world_size, port, args): ), ) end = time.perf_counter() - log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {(end - start) / args.n_repeat_times:.2f}s" + + avg_time = (end - start) / args.n_repeat_times + log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s" coordinator.print_on_master(log_msg) - if args.log: - log_file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}.log" - if dist.get_rank() == 0: - log_generation_time(log_message=log_msg, log_file=log_file) - if args.profile: - file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_warmup{args.n_warm_up_steps}_repeat{args.n_repeat_times}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}_rank_{dist.get_rank()}.json" - prof.export_chrome_trace(file) + if dist.get_rank() == 0: + log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof) def benchmark_diffusers(args): - if isinstance(args.width, int): - width_list = [args.width] - else: - width_list = args.width - - if isinstance(args.height, int): - height_list = [args.height] - else: - height_list = args.height - - assert len(width_list) == len(height_list) - model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda") - # warmup - for i in range(args.n_warm_up_steps): + for _ in range(args.n_warm_up_steps): model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=1024, width=1024) - ctx = ( - torch.profiler.profile( - record_shapes=True, - with_stack=True, - with_modules=True, - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - ) - if args.profile - else nullcontext() - ) - - for h, w in zip(height_list, width_list): - with ctx as prof: + for h, w in zip(args.height, args.width): + with profile_context(args) as prof: start = time.perf_counter() - for i in range(args.n_repeat_times): + for _ in range(args.n_repeat_times): model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w) end = time.perf_counter() - log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {(end - start) / args.n_repeat_times:.2f}s" + + avg_time = (end - start) / args.n_repeat_times + log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s" print(log_msg) - if args.log: - log_file = f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}.log" - log_generation_time(log_message=log_msg, log_file=log_file) - if args.profile: - file = 
f"examples/inference/stable_diffusion/benchmark_bs{args.batch_size}_pps{args.patched_parallel_size}_steps{args.num_inference_steps}_height{h}_width{w}_dtype{args.dtype}_warmup{args.n_warm_up_steps}_repeat{args.n_repeat_times}_profile{args.profile}_model{args.model.split('/')[-1]}_mode{args.mode}.json" - prof.export_chrome_trace(file) + log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof) @rerun_if_address_is_in_use() @@ -156,39 +143,32 @@ def benchmark(args): benchmark_diffusers(args) -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 4 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 4 --mode colossalai --log -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 8 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 8 --mode colossalai --log -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log +""" +# enable log +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log # enable profiler -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +""" if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size") - parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="num of inference steps") - parser.add_argument("-H", "--height", type=Union[int, List[int]], default=[1024, 2048, 3840], help="Height List") - parser.add_argument("-w", "--width", type=Union[int, List[int]], default=[1024, 2048, 3840], help="Width List") - parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"]) - parser.add_argument("--n_warm_up_steps", type=int, default=3, help="warm up times") - parser.add_argument("--n_repeat_times", type=int, 
default=5, help="repeat times") - parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler") - parser.add_argument("--log", default=False, action="store_true", help="enable torch profiler") - parser.add_argument( - "-m", - "--model", - default="stabilityai/stable-diffusion-3-medium-diffusers", - help="the type of model", - # choices=["stabilityai/stable-diffusion-3-medium-diffusers", "PixArt-alpha/PixArt-XL-2-1024-MS"], - ) + parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps") + parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024], help="Height list") + parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024], help="Width list") + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps") + parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times") + parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler") + parser.add_argument("--log", default=False, action="store_true", help="Enable logging") + parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path") parser.add_argument( - "--mode", - default="colossalai", - # choices=["colossalai", "diffusers"], - help="decide which inference framework to run", + "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode" ) args = parser.parse_args() benchmark(args) diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh new file mode 100644 index 000000000000..63e346f35c82 --- /dev/null +++ b/examples/inference/stable_diffusion/run_benchmark.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers") +# parallelism=(1 2 4 8) +parallelism=(1 2 4) +resolutions=(1024 2048 3840) +modes=("colossalai" "diffusers") + +CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +for model in "${models[@]}"; do + for p in "${parallelism[@]}"; do + for resolution in "${resolutions[@]}"; do + for mode in "${modes[@]}"; do + if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then + continue + fi + if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then + continue + fi + CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p + + cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution" + + echo "Executing: $cmd" + eval $cmd + done + done + done +done From b0dba97204cb431af60d02c60438378bd0dbb06e Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Wed, 17 Jul 2024 09:22:39 +0000 Subject: [PATCH 06/10] generation bug fix --- colossalai/inference/config.py | 6 ++---- .../inference/modeling/layers/distrifusion.py | 14 ++------------ .../inference/modeling/policy/pixart_alpha.py | 2 -- .../modeling/policy/stablediffusion3.py | 2 -- 
.../stable_diffusion/sd3_generation.py | 18 ++++++++---------- 5 files changed, 12 insertions(+), 30 deletions(-) diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 2864d958eaee..24ec642708f0 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -246,11 +246,9 @@ class InferenceConfig(RPC_PARAM): generated_token_size: int = 512 # Acceleration for Diffusion Model(PipeFusion or Distrifusion) - # use_patched_parallelism : bool = False patched_parallelism_size: int = 1 # for distrifusion - # use_pipefusion : bool = False - pipeFusion_m_size: int = 1 # for pipefusion - pipeFusion_n_size: int = 1 # for pipefusion + # pipeFusion_m_size: int = 1 # for pipefusion + # pipeFusion_n_size: int = 1 # for pipefusion def __post_init__(self): self.max_context_len_to_capture = self.max_input_len + self.max_output_len diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py index 7a73d84390e8..fcf29a3e7267 100644 --- a/colossalai/inference/modeling/layers/distrifusion.py +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -325,7 +325,6 @@ def __init__( module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], model_shard_infer_config: ModelShardInferenceConfig = None, - async_nccl_stream: torch.cuda.Stream = None, ): super().__init__() self.counter = 0 @@ -335,7 +334,6 @@ def __init__( self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size self.handle = None self.process_group = process_group - self.async_nccl_stream = async_nccl_stream self.warm_step = 5 # for warmup @staticmethod @@ -343,12 +341,10 @@ def from_native_module( module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs ) -> ParallelModule: model_shard_infer_config = kwargs.get("model_shard_infer_config", None) - async_nccl_stream = kwargs.get("async_nccl_stream", None) return Distrifusion_FusedAttention( module=module, process_group=process_group, model_shard_infer_config=model_shard_infer_config, - async_nccl_stream=async_nccl_stream, ) def _forward( @@ -399,8 +395,7 @@ def _forward( self.buffer_list[self.kv_buffer_idx].copy_(kv) full_kv = torch.cat(self.buffer_list, dim=1) assert self.handle is None, "we should maintain the kv of last step" - with torch.cuda.stream(self.async_nccl_stream): # NOTE(@lry89757) make nccl kernels' list more neat - self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) @@ -498,7 +493,6 @@ def __init__( module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], model_shard_infer_config: ModelShardInferenceConfig = None, - async_nccl_stream: torch.cuda.Stream = None, ): super().__init__() self.counter = 0 @@ -508,7 +502,6 @@ def __init__( self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size self.handle = None self.process_group = process_group - self.async_nccl_stream = async_nccl_stream self.warm_step = 3 # for warmup @staticmethod @@ -516,12 +509,10 @@ def from_native_module( module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs ) -> ParallelModule: model_shard_infer_config = kwargs.get("model_shard_infer_config", None) - async_nccl_stream = kwargs.get("async_nccl_stream", None) return 
DistriSelfAttention( module=module, process_group=process_group, model_shard_infer_config=model_shard_infer_config, - async_nccl_stream=async_nccl_stream, ) def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): @@ -558,8 +549,7 @@ def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): self.buffer_list[self.kv_buffer_idx].copy_(kv) full_kv = torch.cat(self.buffer_list, dim=1) assert self.handle is None, "we should maintain the kv of last step" - with torch.cuda.stream(self.async_nccl_stream): # NOTE(@lry89757) make nccl kernels' list more neat - self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) if HAS_FLASH_ATTN: # flash attn diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py index 0ff2f08401d9..1150b2432cc5 100644 --- a/colossalai/inference/modeling/policy/pixart_alpha.py +++ b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -1,4 +1,3 @@ -import torch from diffusers.models.attention import BasicTransformerBlock from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel from torch import nn @@ -52,7 +51,6 @@ def module_policy(self): target_module=DistriSelfAttention, kwargs={ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], - "async_nccl_stream": torch.cuda.Stream(), }, ) ] diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py index 1a9e3c41eda3..db279fb1a7dd 100644 --- a/colossalai/inference/modeling/policy/stablediffusion3.py +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -1,4 +1,3 @@ -import torch from diffusers.models.attention import JointTransformerBlock from diffusers.models.transformers import SD3Transformer2DModel from torch import nn @@ -52,7 +51,6 @@ def module_policy(self): target_module=Distrifusion_FusedAttention, kwargs={ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], - "async_nccl_stream": torch.cuda.Stream(), }, ) ] diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py index daeabeefeaf7..9e146c34b937 100644 --- a/examples/inference/stable_diffusion/sd3_generation.py +++ b/examples/inference/stable_diffusion/sd3_generation.py @@ -1,6 +1,6 @@ import argparse -from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline +from diffusers import DiffusionPipeline from torch import bfloat16 from torch import distributed as dist from torch import float16, float32 @@ -9,12 +9,9 @@ from colossalai.cluster import DistCoordinator from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig from colossalai.inference.core.engine import InferenceEngine -from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy -from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy # For Stable Diffusion 3, we'll use the following configuration -MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0] -POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0] +MODEL_CLS = DiffusionPipeline TORCH_DTYPE_MAP = { "fp16": float16, @@ -47,7 +44,7 @@ def infer(args): use_cuda_kernel=args.use_cuda_kernel, patched_parallelism_size=dist.get_world_size(), ) - engine = 
InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True) + engine = InferenceEngine(model, inference_config=inference_config, verbose=True) # ============================== # Generation @@ -60,11 +57,12 @@ def infer(args): # colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 && colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 -# CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 2 && colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 +# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 + +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 +# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 if __name__ == "__main__": From 8d23ebd0668734bd0ab745e8cc484c4439098a83 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Wed, 17 Jul 2024 09:34:56 +0000 Subject: [PATCH 07/10] naming fix --- colossalai/inference/modeling/layers/distrifusion.py | 4 ++-- colossalai/inference/modeling/policy/stablediffusion3.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py index fcf29a3e7267..238933c2eba5 100644 --- a/colossalai/inference/modeling/layers/distrifusion.py +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -318,7 +318,7 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py -class Distrifusion_FusedAttention(ParallelModule): +class DistrifusionFusedAttention(ParallelModule): def __init__( self, @@ -341,7 +341,7 @@ def from_native_module( module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs ) -> ParallelModule: model_shard_infer_config = kwargs.get("model_shard_infer_config", None) - return Distrifusion_FusedAttention( + return DistrifusionFusedAttention( module=module, process_group=process_group, model_shard_infer_config=model_shard_infer_config, diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py index db279fb1a7dd..39b764b92887 100644 --- a/colossalai/inference/modeling/policy/stablediffusion3.py +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -5,8 +5,8 @@ from colossalai.inference.config import RPC_PARAM from 
colossalai.inference.modeling.layers.diffusion import DiffusionPipe from colossalai.inference.modeling.layers.distrifusion import ( - Distrifusion_FusedAttention, DistrifusionConv2D, + DistrifusionFusedAttention, DistrifusionPatchEmbed, SD3Transformer2DModel_forward, ) @@ -48,7 +48,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="attn", - target_module=Distrifusion_FusedAttention, + target_module=DistrifusionFusedAttention, kwargs={ "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], }, From 48b8081e9a21c43780a1867451b17fefb5cf41f5 Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Thu, 18 Jul 2024 09:34:23 +0000 Subject: [PATCH 08/10] add docstring, fix counter and shape error --- colossalai/inference/README.md | 12 ++++++++++- colossalai/inference/config.py | 1 + .../inference/modeling/layers/distrifusion.py | 20 +++++++++++-------- .../stable_diffusion/benchmark_sd3.py | 13 ++++++++---- .../stable_diffusion/run_benchmark.sh | 3 +-- 5 files changed, 34 insertions(+), 15 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 0a9b5293d4a2..76813a4a3495 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -18,7 +18,7 @@ ## 📌 Introduction -ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference) +ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)

@@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below. journal={arXiv}, year={2023} } + +# Distrifusion +@InProceedings{Li_2024_CVPR, + author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song}, + title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month={June}, + year={2024}, + pages={7183-7193} +} ``` diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 24ec642708f0..072ddbcfd298 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM): enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation. start_token_size(int): The size of the start tokens, when using StreamingLLM. generated_token_size(int): The size of the generated tokens, When using StreamingLLM. + patched_parallelism_size(int): Patched Parallelism Size, When using Distrifusion """ # NOTE: arrange configs according to their importance and frequency of usage diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py index 238933c2eba5..ea97cceefac9 100644 --- a/colossalai/inference/modeling/layers/distrifusion.py +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -138,9 +138,9 @@ def PixArtAlphaTransformer2DModel_forward( if hasattr(self, "patched_parallel_size"): from torch import distributed as dist - if getattr(self, "output_buffer", None) is None: + if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape): self.output_buffer = torch.empty_like(output) - if getattr(self, "buffer_list", None) is None: + if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape): self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] output = output.contiguous() dist.all_gather(self.buffer_list, output, async_op=False) @@ -196,9 +196,9 @@ def SD3Transformer2DModel_forward( if hasattr(self, "patched_parallel_size"): from torch import distributed as dist - if getattr(self, "output_buffer", None) is None: + if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape): self.output_buffer = torch.empty_like(output) - if getattr(self, "buffer_list", None) is None: + if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape): self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] output = output.contiguous() dist.all_gather(self.buffer_list, output, async_op=False) @@ -454,14 +454,16 @@ def forward( self.handle = None b, l, c = hidden_states.shape - if self.patched_parallelism_size > 1 and self.buffer_list is None: + kv_shape = (b, l, self.module.to_k.out_features * 2) + if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape): - kv_shape = (b, l, self.module.to_k.out_features * 2) self.buffer_list = [ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) for _ in range(self.patched_parallelism_size) ] + self.counter = 0 + attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys()) 
quiet_attn_parameters = {"ip_adapter_masks"} unused_kwargs = [ @@ -608,14 +610,16 @@ def forward( self.handle = None b, l, c = hidden_states.shape - if self.patched_parallelism_size > 1 and self.buffer_list is None: + kv_shape = (b, l, self.module.to_k.out_features * 2) + if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape): - kv_shape = (b, l, self.module.to_k.out_features * 2) self.buffer_list = [ torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) for _ in range(self.patched_parallelism_size) ] + self.counter = 0 + output = self._forward(hidden_states, scale=scale) self.counter += 1 diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py index ce8cd0fb2e93..19db57c33c82 100644 --- a/examples/inference/stable_diffusion/benchmark_sd3.py +++ b/examples/inference/stable_diffusion/benchmark_sd3.py @@ -33,7 +33,7 @@ def warmup(engine, args): engine.generate( prompts=["hello world"], generation_config=DiffusionGenerationConfig( - num_inference_steps=args.num_inference_steps, height=1024, width=1024 + num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0] ), ) @@ -118,7 +118,12 @@ def benchmark_diffusers(args): model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda") for _ in range(args.n_warm_up_steps): - model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=1024, width=1024) + model( + prompt="hello world", + num_inference_steps=args.num_inference_steps, + height=args.height[0], + width=args.width[0], + ) for h, w in zip(args.height, args.width): with profile_context(args) as prof: @@ -159,8 +164,8 @@ def benchmark(args): parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size") parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps") - parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024], help="Height list") - parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024], help="Width list") + parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list") + parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list") parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps") parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times") diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh index 63e346f35c82..f3e45a335219 100644 --- a/examples/inference/stable_diffusion/run_benchmark.sh +++ b/examples/inference/stable_diffusion/run_benchmark.sh @@ -1,8 +1,7 @@ #!/bin/bash models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers") -# parallelism=(1 2 4 8) -parallelism=(1 2 4) +parallelism=(1 2 4 8) resolutions=(1024 2048 3840) modes=("colossalai" "diffusers") From b467d34c7e1121a8435845177c79429caeb9f95c Mon Sep 17 00:00:00 2001 From: Runyu Lu Date: Thu, 18 Jul 2024 09:42:57 +0000 Subject: [PATCH 09/10] add reference --- examples/inference/stable_diffusion/compute_metric.py | 1 + 1 file 
changed, 1 insertion(+)

diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py
index b1328d14b0a9..14c92501b66d 100644
--- a/examples/inference/stable_diffusion/compute_metric.py
+++ b/examples/inference/stable_diffusion/compute_metric.py
@@ -1,3 +1,4 @@
+# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
 import argparse
 import os

From 69067de5aabb85d472e98747275cbc172760a48f Mon Sep 17 00:00:00 2001
From: Runyu Lu
Date: Thu, 18 Jul 2024 09:55:38 +0000
Subject: [PATCH 10/10] readme and requirement

---
 examples/inference/stable_diffusion/README.md | 20 +++++++++++++++++--
 .../stable_diffusion/requirements.txt         |  3 +++
 2 files changed, 21 insertions(+), 2 deletions(-)
 create mode 100644 examples/inference/stable_diffusion/requirements.txt

diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md
index 0f1afc805bec..c11b9804392c 100644
--- a/examples/inference/stable_diffusion/README.md
+++ b/examples/inference/stable_diffusion/README.md
@@ -1,6 +1,22 @@
+## File Structure
 ```
 |- sd3_generation.py: an example of how to use Colossalai Inference Engine to generate result by loading Diffusion Model.
 |- compute_metric.py: compare the quality of images w/o some acceleration method like Distrifusion
-|- run_benchmark.sh: benchmark the performance of our InferenceEngine
+|- benchmark_sd3.py: benchmark the performance of our InferenceEngine
+|- run_benchmark.sh: run benchmark command
+```
+Note: `compute_metric.py` needs extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is in `examples/inference/stable_diffusion/`).
+
+## Run Inference
+
+The provided `sd3_generation.py` shows how to configure the engine, initialize it, and run inference on a given model. We've added `DiffusionPipeline` as the model class, so the script runs Stable Diffusion 3 out of the box.
+
+For a basic setting, you can run the example with:
+```bash
+colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
+```
+
+Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs:
+```bash
+colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
 ```
-note: compute_metric.py need some dependencies which need `pip install`.

diff --git a/examples/inference/stable_diffusion/requirements.txt b/examples/inference/stable_diffusion/requirements.txt
new file mode 100644
index 000000000000..c4e74162dfb5
--- /dev/null
+++ b/examples/inference/stable_diffusion/requirements.txt
@@ -0,0 +1,3 @@
+torchvision
+torchmetrics
+cleanfid
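
For readers following the `DistriSelfAttention`/`DistrifusionFusedAttention` hunks above, here is a minimal, self-contained sketch of the stale-KV overlap pattern those layers implement. This is not the shipped ColossalAI code: the class name, the free-standing helper, and the `warm_steps` default are assumptions made for illustration; only the buffering pattern mirrors the patches (warm-up steps gather KV synchronously, later steps refresh just the local KV slot, attend over the other ranks' slightly stale KV, and prefetch the next step's KV with an async `all_gather` that is waited on at the start of the next denoising step).

```python
# Illustrative sketch only; names and defaults are assumptions, not ColossalAI API.
import torch
import torch.distributed as dist


class StaleKVAllGather:
    """Keep last step's gathered KV; refresh only this rank's slot each step."""

    def __init__(self, world_size: int, rank: int, warm_steps: int = 3):
        self.world_size = world_size
        self.rank = rank
        self.warm_steps = warm_steps
        self.counter = 0
        self.buffer_list = None  # one KV buffer per rank
        self.handle = None  # async all_gather launched at the previous step

    def gather(self, kv: torch.Tensor) -> torch.Tensor:
        # Finish the communication launched at the previous denoising step.
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

        # (Re)allocate buffers when the KV shape changes, e.g. for a new resolution.
        if self.buffer_list is None or self.buffer_list[0].shape != kv.shape:
            self.buffer_list = [torch.empty_like(kv) for _ in range(self.world_size)]
            self.counter = 0

        if self.counter <= self.warm_steps:
            # Warm-up: blocking gather so every rank attends over fresh, exact KV.
            dist.all_gather(self.buffer_list, kv, async_op=False)
            full_kv = torch.cat(self.buffer_list, dim=1)
        else:
            # Steady state: overwrite only the local slot, attend over slightly
            # stale KV from other ranks, and prefetch next step's KV asynchronously.
            self.buffer_list[self.rank].copy_(kv)
            full_kv = torch.cat(self.buffer_list, dim=1)
            self.handle = dist.all_gather(self.buffer_list, kv, async_op=True)

        self.counter += 1
        return full_kv  # concatenated along the sequence dimension
```

The shape check before reallocating the buffers is what the PATCH 08 hunks add: when the benchmark switches resolution, the buffers are rebuilt and the warm-up counter is reset instead of gathering into mismatched tensors.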
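
Likewise, a rough sketch of the patch-parallel layout that the `DistrifusionConv2D`/`DistrifusionPatchEmbed` replacements and the modified transformer forwards set up: each rank denoises one horizontal slice of the latent, and the full-resolution output is reassembled with a blocking `all_gather` along the height axis. The helper below is an illustration under stated assumptions (its name and signature are invented for this sketch), not the engine's API.

```python
# Illustrative sketch only; the function name and signature are invented here.
import torch
import torch.distributed as dist


def denoise_patched(latent: torch.Tensor, denoise_fn, world_size: int, rank: int) -> torch.Tensor:
    """latent: (B, C, H, W); denoise_fn maps a (B, C, H // world_size, W) slice to a slice."""
    # Each rank works on its own horizontal slice of the latent.
    local_patch = latent.chunk(world_size, dim=2)[rank].contiguous()
    local_out = denoise_fn(local_patch).contiguous()

    # Reassemble the full-resolution result on every rank with a blocking all_gather.
    buffers = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(buffers, local_out)
    return torch.cat(buffers, dim=2)
```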