[Feat] Diffusion Model (PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* PixArtAlpha support
Showing 16 changed files with 1,860 additions and 740 deletions.
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy


class BaseEngine(ABC):
    @abstractmethod
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        pass

    @abstractmethod
    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        """
        Init Model for Engine
        """

    @abstractmethod
    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        """
        Generate output for incoming requests
        """

    @abstractmethod
    def add_request(self, prompts, request_ids=None, **kwargs):
        """
        Add new request to Engine
        """

    @abstractmethod
    def step(self):
        """
        Perform one new step forward
        """

    @abstractmethod
    def _verify_args(self):
        """
        Verify the parameters and members of the class
        """

    @torch.inference_mode()
    def capture_model(self):
        """
        Use CUDA graph to capture the model
        """
        raise NotImplementedError("This method should be implemented by subclasses")

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
        **kwargs,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): The model to be sharded.
            model_policy (Policy): The policy used to shard the model, determined by the model type.
            model_shard_infer_config (ModelShardInferenceConfig, optional): The configuration for module initialization during inference. Defaults to None.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the TP process group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """
        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model
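
To make the abstract contract above concrete, here is a minimal illustrative subclass. It is not part of this commit: the `EchoEngine` class, its `nn.Identity` placeholder model, and the list-based request queue are hypothetical stand-ins, and a real engine would load a checkpoint and shard it via `self._shardformer(...)` inside `init_model`.

import torch.nn as nn

from .base_engine import BaseEngine  # the abstract class added in this commit


class EchoEngine(BaseEngine):
    """Toy engine that simply echoes prompts back; shown only to illustrate the interface."""

    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        self.inference_config = inference_config
        self._pending = []
        self.init_model(model_or_path, model_policy)
        self._verify_args()

    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        # A real engine would load weights here and shard them via self._shardformer(...).
        self.model = nn.Identity()

    def add_request(self, prompts, request_ids=None, **kwargs):
        self._pending.extend(prompts if isinstance(prompts, list) else [prompts])

    def step(self):
        # Drain all pending prompts; a real engine would run one forward pass per step.
        done, self._pending = self._pending, []
        return done

    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        if prompts is not None:
            self.add_request(prompts, request_ids)
        outputs = []
        while self._pending:
            outputs += self.step()
        return outputs

    def _verify_args(self):
        assert isinstance(self.model, nn.Module)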
@@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler

PP_AXIS, TP_AXIS = 0, 1


class DiffusionEngine(BaseEngine):
    def __init__(
        self,
        model_or_path: DiffusionPipeline | str,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Policy | type[Policy] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

        self.model_type = get_model_type(model_or_path=model_or_path)

        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.request_handler = NaiveRequestHandler()

        self.counter = count()

        self._verify_args()

    def _verify_args(self) -> None:
        assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"

    def init_model(
        self,
        model_or_path: Union[str, nn.Module, DiffusionPipeline],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard the model and/or load its weights.

        Args:
            model_or_path (Union[str, nn.Module, DiffusionPipeline]): Path to the checkpoint or a diffusers pipeline.
            model_policy (Union[Policy, Type[Policy]], optional): The policy used to replace the model. Defaults to None.
            model_shard_infer_config (ModelShardInferenceConfig, optional): The configuration for module initialization during inference. Defaults to None.
        """
        if isinstance(model_or_path, str):
            model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
            policy_map_key = model.__class__.__name__
            model = DiffusionPipe(model)
        elif isinstance(model_or_path, DiffusionPipeline):
            policy_map_key = model_or_path.__class__.__name__
            model = DiffusionPipe(model_or_path)
        else:
            self.logger.error("model_or_path supports only str or DiffusionPipeline currently!")
            raise TypeError(f"Invalid type of model_or_path: {type(model_or_path)}")

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        if self.verbose:
            self.logger.info(
                f"Before sharding, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            model_policy = model_policy_map.get(policy_map_key)

        if not isinstance(model_policy, Policy):
            try:
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")

        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            model_shard_infer_config,
            None,
            tp_group=tp_group,
        )

        self.model = self.model.to(self.device)

        if self.verbose:
            self.logger.info(
                f"After sharding, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {self.model.device}"
            )

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )

    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        generation_config: DiffusionGenerationConfig = None,
        **kwargs,
    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
        """
        Add the given prompts as requests, run the engine until all requests finish, and return the generated outputs.
        """
        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        prompts = [prompts] if isinstance(prompts, str) else prompts
        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

        with torch.inference_mode():
            if prompts is not None:
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    **gen_config_dict,
                    **kwargs,
                )

            output_reqs_list = []

            # If the user provides a generation config, replace the existing one.
            if generation_config is not None:
                self.generation_config = generation_config
                self.generation_config_dict = gen_config_dict

            while self.request_handler.check_unfinished_reqs():
                output_reqs_list += self.step()

            return output_reqs_list

    def add_request(
        self,
        prompts: Union[List[str], str],
        request_ids: Union[List[int], int] = None,
        **kwargs,
    ):
        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if not isinstance(prompts, list):
            prompts = [prompts]

        generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
        prompts_num = len(prompts)
        for i in range(prompts_num):
            if request_ids:
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
                assert len(request_ids) == prompts_num
                request_id = request_ids[i]
            else:
                request_id = next(self.counter)

            seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)

            self.request_handler.add_sequence(seq)

    def step(self) -> List[PIL.Image.Image]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Run the forward pass to get List[Image].

        Returns:
            List[PIL.Image.Image]: Images generated in one step.
        """
        input = self.request_handler.schedule()
        ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
        return ret
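
For orientation, a rough end-to-end usage sketch of the new engine follows; it is not part of the diff. The checkpoint id, the module path of `DiffusionEngine`, and the `InferenceConfig` constructor arguments shown here are assumptions; the only facts taken from the code above are that `generate()` wraps `add_request()` and `step()` and that `step()` is documented to return `PIL.Image.Image` objects.

import torch

from colossalai.inference.config import InferenceConfig  # import path taken from the file above
from colossalai.inference.core.diffusion_engine import DiffusionEngine  # assumed module path

# Minimal config; dtype is the only field set explicitly here (assumed keyword), others keep defaults.
inference_config = InferenceConfig(dtype=torch.float16)

# Illustrative checkpoint id for a PixArt-alpha pipeline; any DiffusionPipeline checkpoint should work.
engine = DiffusionEngine(
    model_or_path="PixArt-alpha/PixArt-XL-2-1024-MS",
    inference_config=inference_config,
    verbose=True,
)

# generate() queues the prompt via add_request() and drives step() until all requests finish.
images = engine.generate(prompts=["an astronaut riding a horse on the moon"])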