[Feat] Refactor Engine and Diffusion Model(PixArtAlpha/StableDiffusion3) Support #5838

Merged
merged 9 commits into from Jul 8, 2024
Changes from 8 commits
48 changes: 47 additions & 1 deletion colossalai/inference/config.py
@@ -4,7 +4,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers.generation import GenerationConfig
@@ -393,3 +393,49 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False


@dataclass
class DiffusionGenerationConfig:
"""
Parameters for the diffusion model forward pass.
"""

prompt_2: Optional[Union[str, List[str]]] = None
prompt_3: Optional[Union[str, List[str]]] = None
height: Optional[int] = None
width: Optional[int] = None
num_inference_steps: Optional[int] = None
timesteps: Optional[List[int]] = None
guidance_scale: Optional[float] = None
negative_prompt: Optional[Union[str, List[str]]] = (
None  # NOTE(@lry89757) defaults to "" in PixArt and to None in SD3
)
negative_prompt_2: Optional[Union[str, List[str]]] = None
negative_prompt_3: Optional[Union[str, List[str]]] = None
num_images_per_prompt: Optional[int] = None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
latents: Optional[torch.FloatTensor] = None
prompt_embeds: Optional[torch.FloatTensor] = None
negative_prompt_embeds: Optional[torch.FloatTensor] = None
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
output_type: Optional[str] = None # "pil"
return_dict: Optional[bool] = None
joint_attention_kwargs: Optional[Dict[str, Any]] = None
clip_skip: Optional[int] = None
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
callback_on_step_end_tensor_inputs: Optional[List[str]] = None

def to_dict(self) -> Dict[str, Any]:
# NOTE(@lry89757) Only include fields that are not left at the default value None
result = {}
for field in fields(self):
value = getattr(self, field.name)
if value is not None:
result[field.name] = value
return result

@classmethod
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
return cls(**kwargs)
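
For reference, a minimal usage sketch of this config (the field values are arbitrary): to_dict() forwards only the fields that were explicitly set, so the underlying pipeline's own defaults still apply for everything else.

from colossalai.inference.config import DiffusionGenerationConfig

# Build a config from keyword arguments, as the engine does internally
# via from_kwargs().
gen_config = DiffusionGenerationConfig.from_kwargs(
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=7.0,
)

# Only the non-None fields survive, so pipeline defaults are preserved.
print(gen_config.to_dict())
# {'height': 1024, 'width': 1024, 'num_inference_steps': 30, 'guidance_scale': 7.0}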
86 changes: 86 additions & 0 deletions colossalai/inference/core/base_engine.py
@@ -0,0 +1,86 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy


class BaseEngine(ABC):
@abstractmethod
def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
pass

@abstractmethod
def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
"""
Initialize the model for the engine.
"""

@abstractmethod
def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
pass

@abstractmethod
def add_request(self, prompts, request_ids=None, **kwargs):
"""
Add a new request to the engine.
"""

@abstractmethod
def step(self):
"""
Perform one forward step.
"""

@abstractmethod
def _verify_args(self):
"""
Verify the parameters and members of the class.
"""

@torch.inference_mode()
def capture_model(self):
"use cuda graph to capture model"
return NotImplementedError("This method should be implemented by subclasses")

def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
**kwargs,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.

Args:
model (nn.Module): The model to be sharded.
model_policy (Policy): The policy for sharding the model, determined by the model type.
model_shard_infer_config (ModelShardInferenceConfig, optional): Configuration for module initialization during inference. Defaults to None.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

Returns:
nn.Module: The model optimized by Shardformer.
"""

shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
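
To make the abstract contract concrete, here is a hypothetical minimal subclass (illustrative only, not part of this PR); the inherited _shardformer helper is where a real engine would plug in its sharding policy:

from colossalai.inference.core.base_engine import BaseEngine


class ToyEngine(BaseEngine):
    """Illustrative stub; a real engine would shard and load weights."""

    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        self.inference_config = inference_config
        self.init_model(model_or_path, model_policy)

    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        # A real engine would build the module here and then call
        # self._shardformer(model, model_policy, model_shard_infer_config, ...)
        self.model = model_or_path

    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        return []

    def add_request(self, prompts, request_ids=None, **kwargs):
        pass

    def step(self):
        return []

    def _verify_args(self):
        pass


# BaseEngine() itself raises TypeError: it stays abstract until a subclass
# provides all of the abstract methods overridden above.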
200 changes: 200 additions & 0 deletions colossalai/inference/core/diffusion_engine.py
@@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler

PP_AXIS, TP_AXIS = 0, 1


class DiffusionEngine(BaseEngine):
def __init__(
self,
model_or_path: Union[DiffusionPipeline, str],
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision

self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

self.model_type = get_model_type(model_or_path=model_or_path)

self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

self.request_handler = NaiveRequestHandler()

self.counter = count()

self._verify_args()

def _verify_args(self) -> None:
assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"

def init_model(
self,
model_or_path: Union[str, nn.Module, DiffusionPipeline],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard the model and/or load weights.

Args:
model_or_path (Union[str, nn.Module, DiffusionPipeline]): Path to the checkpoint or the pipeline object.
model_policy (Policy): The policy used to shard the model.
model_shard_infer_config (ModelShardInferenceConfig): Configuration for module initialization during inference.
"""
if isinstance(model_or_path, str):
model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
policy_map_key = model.__class__.__name__
model = DiffusionPipe(model)
elif isinstance(model_or_path, DiffusionPipeline):
policy_map_key = model_or_path.__class__.__name__
model = DiffusionPipe(model_or_path)
else:
raise TypeError(f"model_or_path supports only str or DiffusionPipeline currently, but got {type(model_or_path)}")

torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]

self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")

if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)

if model_policy is None:
model_policy = model_policy_map.get(policy_map_key)

if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")

assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)

self.model = self.model.to(self.device)

if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)

free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)

def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
generation_config: DiffusionGenerationConfig = None,
**kwargs,
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
""" """
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

with torch.inference_mode():
if prompts is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
**gen_config_dict,
**kwargs,
)

output_reqs_list = []

# If the user provides a generation config, replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict

while self.request_handler.check_unfinished_reqs():
output_reqs_list += self.step()

return output_reqs_list

def add_request(
self,
prompts: Union[List[str], str],
request_ids: Union[List[int], int] = None,
**kwargs,
):
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]

if not isinstance(prompts, list):
prompts = [prompts]

generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
prompts_num = len(prompts)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)

seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)

self.request_handler.add_sequence(seq)

def step(self) -> List[PIL.Image.Image]:
"""
In each step, do the following:
1. Run RequestHandler.schedule() to get the batch used for inference.
2. Run the forward pass to get a List[PIL.Image.Image].
Returns:
List[PIL.Image.Image]: Images generated in one step.
"""

input_seq = self.request_handler.schedule()
ret = self.model(prompt=input_seq.prompt, **input_seq.generation_config.to_dict())
return ret
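
Finally, a hypothetical end-to-end sketch of the new engine. The model id is illustrative, the launch call and InferenceConfig arguments may differ by ColossalAI version, and per step()'s annotation the outputs are assumed to be PIL images:

import colossalai
import torch

from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.diffusion_engine import DiffusionEngine

# The engine builds a ProcessGroupMesh, so the distributed backend must be
# initialized first (e.g. run under torchrun with a single process).
colossalai.launch_from_torch()

engine = DiffusionEngine(
    model_or_path="PixArt-alpha/PixArt-XL-2-1024-MS",  # illustrative model id
    inference_config=InferenceConfig(dtype=torch.float16),
)

images = engine.generate(
    prompts="an astronaut riding a horse on the moon",
    generation_config=DiffusionGenerationConfig(num_inference_steps=20),
)
images[0].save("output.png")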