[Feat] Distrifusion Acceleration Support for Diffusion Inference #5895

Merged: 10 commits, merged Jul 30, 2024
12 changes: 11 additions & 1 deletion colossalai/inference/README.md
@@ -18,7 +18,7 @@


## 📌 Introduction
ColossalAI-Inference is a module that accelerates the inference of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate LLM inference. We also provide simple and unified APIs for ease of use. [[blog]](https://hpc-ai.com/blog/colossal-inference)
ColossalAI-Inference is a module that accelerates the inference of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate LLM inference. We also provide simple and unified APIs for ease of use. [[blog]](https://hpc-ai.com/blog/colossal-inference)

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>
@@ -310,4 +310,14 @@ If you wish to cite relevant research papers, you can find the references below.
journal={arXiv},
year={2023}
}

# Distrifusion
@InProceedings{Li_2024_CVPR,
author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month={June},
year={2024},
pages={7183-7193}
}
```
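The README change above advertises DiT diffusion support. For orientation, a minimal usage sketch follows; only `InferenceConfig` and its new `patched_parallelism_size` field come from this PR, while the `DiffusionEngine` constructor/`generate` signature and the example checkpoint are illustrative assumptions.

```python
# Hedged sketch: enabling Distrifusion-style patched parallelism for a DiT
# diffusion pipeline. Only `patched_parallelism_size` is introduced by this PR;
# the engine/generate signatures and the checkpoint below are assumptions.
from diffusers import PixArtAlphaPipeline

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.diffusion_engine import DiffusionEngine  # module path from this PR

model = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS")
config = InferenceConfig(patched_parallelism_size=2)  # split the latent into 2 patches across 2 GPUs
engine = DiffusionEngine(model, inference_config=config)  # signature assumed
images = engine.generate(prompts=["an astronaut riding a horse on the moon"])
```

In practice such a script would be launched once per device, e.g. with `torchrun --nproc_per_node 2`, so that each rank owns one image patch.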
16 changes: 16 additions & 0 deletions colossalai/inference/config.py
@@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM; the relevant algorithm is described in the paper at https://arxiv.org/pdf/2309.17453.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, when using StreamingLLM.
patched_parallelism_size(int): The degree of patched parallelism, used when Distrifusion is enabled.
"""

# NOTE: arrange configs according to their importance and frequency of usage
@@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512

# Acceleration for diffusion models (PipeFusion or Distrifusion)
patched_parallelism_size: int = 1  # for Distrifusion
# pipeFusion_m_size: int = 1  # for PipeFusion
# pipeFusion_n_size: int = 1  # for PipeFusion

def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
@@ -288,6 +294,14 @@ def _verify_config(self) -> None:
# Thereafter, we swap out tokens in units of blocks, always swapping out the second block when the generated tokens exceed the limit.
self.start_token_size = self.block_size

# check Distrifusion
# TODO(@lry89757) need more detailed checks
if self.patched_parallelism_size > 1:
# self.use_patched_parallelism = True
self.tp_size = self.patched_parallelism_size  # not real tensor parallelism; reused only to satisfy existing config checks

# check prompt template
if self.prompt_template is None:
return
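A quick illustration of the side effect implemented in the Distrifusion check above (a sketch that assumes the remaining `InferenceConfig` defaults pass validation in your environment):

```python
from colossalai.inference.config import InferenceConfig

# When patched parallelism is requested, _verify_config() co-opts tp_size so
# existing world-size checks pass; this is not real tensor parallelism.
cfg = InferenceConfig(patched_parallelism_size=2)
assert cfg.tp_size == 2
```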
@@ -324,6 +338,7 @@ def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config
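And the flag is carried into the per-shard config, so downstream sharding code can read it (a sketch, with the same caveat about defaults as above):

```python
from colossalai.inference.config import InferenceConfig

cfg = InferenceConfig(patched_parallelism_size=2)
shard_cfg = cfg.to_model_shard_inference_config()
# The Distrifusion degree travels with ModelShardInferenceConfig.
assert shard_cfg.patched_parallelism_size == 2
```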

@@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
patched_parallelism_size: int = 1  # for diffusion models using the Distrifusion technique


@dataclass
2 changes: 1 addition & 1 deletion colossalai/inference/core/diffusion_engine.py
@@ -11,7 +11,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type