Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] cuda graph support and refactor non-functional api #5434

Merged
33 changes: 31 additions & 2 deletions colossalai/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

logger = logging.Logger(__name__)


_DTYPE_MAPPING = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
Expand All @@ -23,13 +22,37 @@

_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]


_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
}


@dataclass
class InputMetaData:
    """Per-step input information handed to the model alongside the token ids.

    Args:
        block_tables (torch.Tensor, optional): Block tables of the sequences. Defaults to None.
        sequence_lengths (torch.Tensor, optional): Tensor holding the sequence lengths. Defaults to None.
        fd_inter_tensor (torch.Tensor, optional): Intermediate data for flash decoding. Defaults to None.
        batch_size (int, optional): Current batch size. Defaults to 64.
        is_prompts (bool, optional): True for prefill, False for decoding. Defaults to False.
        use_cuda_graph (bool, optional): Whether this step runs through a CUDA graph. Defaults to False.
        kv_seq_len (int, optional): Key-value sequence length. Defaults to 512.
        head_dim (int, optional): Head dimension. Defaults to 32.
    """

    # Tensor inputs; each stays None until the caller provides it.
    block_tables: torch.Tensor = None
    sequence_lengths: torch.Tensor = None
    fd_inter_tensor: torch.Tensor = None

    # Scalar description of the step.
    batch_size: int = 64  # current batch size
    is_prompts: bool = False
    use_cuda_graph: bool = False
    kv_seq_len: int = 512
    head_dim: int = 32


@dataclass
class InferenceConfig:
"""The inference configuration.
Expand All @@ -55,6 +78,8 @@ class InferenceConfig:
pp_size (int): Pipeline parallel size, defaults to 1.
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
use_cuda_graph (bool): Whether to use CUDA graphs, defaults to False. If False, CUDA graphs are disabled and the model always runs in eager mode. If True, execution is hybrid: decoding steps whose batch size has a captured graph replay that graph, and all other steps run in eager mode.
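max_context_len_to_capture (int): Context length used to size the block tables when capturing CUDA graphs, defaults to `max_input_len * max_output_len`.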

"""

Expand Down Expand Up @@ -90,6 +115,10 @@ class InferenceConfig:
micro_batch_size: int = 1
micro_batch_buffer_size: int = None

# cuda_graph
use_cuda_graph: bool = False
max_context_len_to_capture: int = max_input_len * max_output_len

def __post_init__(self):
self._verify_config()

Expand Down
141 changes: 130 additions & 11 deletions colossalai/inference/core/engine.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import copy
import time
from itertools import count
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
Expand Down Expand Up @@ -81,11 +85,89 @@ def __init__(
self.logger = get_dist_logger(__name__)

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?

self.counter = count()

self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")

self.capture_model(self.k_cache, self.v_cache)

@torch.inference_mode()
def capture_model(self, k_cache: torch.Tensor, v_cache: torch.Tensor):
    """Capture CUDA graphs of the decoding step, one per supported batch size.

    For every candidate batch size that does not exceed
    ``inference_config.max_batch_size``, a dummy decoding batch is built, turned into
    model inputs with ``prepare_input`` and captured by a ``CUDAGraphRunner``. The
    runners are stored in ``self.graph_runners`` keyed by batch size, and all captures
    share ``self.graph_memory_pool``.

    Args:
        k_cache (torch.Tensor): Key cache passed to the model during capture.
        v_cache (torch.Tensor): Value cache passed to the model during capture.
    """
    assert self.use_cuda_graph, "please turn on the cuda graph"

    if self.verbose:
        self.logger.info("Colossal AI CUDA Graph Capture begin")

    t_capture_begin = time.perf_counter()

    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

    block_size = self.inference_config.block_size

    # Prepare dummy inputs. These will be reused for all batch sizes.
    max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)

    max_context_len_to_capture = self.inference_config.max_context_len_to_capture
    # Ceiling division: number of blocks needed to hold the longest captured context.
    max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
    input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long, device="cuda")
    self.graph_block_tables = np.zeros((max_batch_size, max_num_blocks), dtype=np.int32)
    block_tables = torch.from_numpy(self.graph_block_tables).cuda()
    max_num_seqs = self.inference_config.max_batch_size
    batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]

    # NOTE: Capturing the largest batch size first may help reduce the
    # memory usage of CUDA graph.
    # Every eligible batch size is captured; slicing the list down to its last
    # element would leave all smaller decoding batches without a graph.
    for batch_size in reversed(batch_size_capture_list):
        # Work on a copy so the dummy sequences never reach the live running bucket,
        # while still sharing its flash-decoding intermediate tensor.
        batch_bucket_for_capture = copy.deepcopy(self.request_handler.running_bb)
        batch_bucket_for_capture.fd_interm_tensor = self.request_handler.running_bb.fd_interm_tensor

        if self.verbose:
            self.logger.info(f"batch size {batch_size} graph capturing")

        # generate dummy input
        for i in range(batch_size):
            sequence = Sequence(
                i,
                None,
                input_tokens[i],
                block_size,
                None,
                self.tokenizer.eos_token_id,
                self.tokenizer.pad_token_id,
                self.inference_config.max_output_len,
            )
            sequence.output_token_id = [0]  # only capture the graph of decoding
            batch_bucket_for_capture.add_seq(sequence, alloc_block_table=block_tables[i])

        input_tokens_ids, output_tensor, inputmetadata = self.prepare_input(batch_bucket_for_capture)

        graph_runner = CUDAGraphRunner(self.model)
        graph_runner.capture(
            input_tokens_ids,
            output_tensor,
            inputmetadata,
            k_caches=k_cache,
            v_caches=v_cache,
            memory_pool=self.graph_memory_pool,
        )
        # Reuse the pool of the graph just captured for the following captures.
        self.graph_memory_pool = graph_runner.graph.pool()
        self.graph_runners[batch_size] = graph_runner

    t_capture_end = time.perf_counter()

    if self.verbose:
        self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")

def _verify_config(self) -> None:
"""
Verify the input config
Expand Down Expand Up @@ -278,26 +360,63 @@ def add_request(
)
self.request_handler.add_sequence(sequence)

def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
    """Build the model inputs for one inference step from a batch bucket.

    Args:
        batch (BatchBucket): The batch to run in this step.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, InputMetaData]: The token ids returned by
        ``batch.get_1D_inputs()``, a zero-filled output placeholder, and the metadata
        describing the step.
    """
    token_ids = batch.get_1D_inputs()
    seq_lens = batch.get_sequence_lengths()

    # Rows of the placeholder: the sum of all sequence lengths when prefilling,
    # the current batch size when decoding.
    if batch.is_prompts:
        num_rows = seq_lens.sum().item()
    else:
        num_rows = batch.current_batch_size
    placeholder = torch.zeros(
        (num_rows, batch.num_heads * batch.head_dim),
        dtype=batch.dtype,
        device=batch.device,
    )

    # A CUDA graph is only usable for a decoding step whose batch size has a captured graph.
    graph_available = bool(
        self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners
    )

    meta = InputMetaData(
        block_tables=batch.get_block_table_tensor(),
        sequence_lengths=seq_lens,
        fd_inter_tensor=batch.fd_inter_tensor,
        batch_size=batch.current_batch_size,
        is_prompts=batch.is_prompts,
        use_cuda_graph=graph_available,
        kv_seq_len=seq_lens.max().item(),
        head_dim=batch.head_dim,
    )

    return token_ids, placeholder, meta

def step(self) -> List[str]:
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run model to generate the next token
3. Update waiting list and running list in RequestHandler and get finished sequences.
4. Decode and return finished sequences.
2. Get the input, inputinfo and output placeholder from the batchbucket
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.

Returns:
List[str]: Decoded finished sequences generated by one step.
"""

batch = self.request_handler.schedule()

input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model

# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = self.model(
batch,
self.k_cahce,
self.v_cache,
)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

if self.inference_config.pad_input:
logits = logits[:, -1, :]
Expand Down
92 changes: 92 additions & 0 deletions colossalai/inference/graph_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Dict, List

import torch
from torch import nn

from colossalai.inference.config import InputMetaData
from colossalai.logging import get_dist_logger


class CUDAGraphRunner:
    """Capture one forward pass of a model into a CUDA graph and replay it on new inputs.

    ``capture`` records the forward pass and keeps references to the tensors it was
    recorded with. ``forward`` copies fresh inputs into those tensors, replays the
    graph and returns the captured output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None  # set by capture()
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self.logger = get_dist_logger(__name__)

    def capture(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: "InputMetaData",
        k_caches: List[torch.Tensor] = None,
        v_caches: List[torch.Tensor] = None,
        memory_pool=None,
    ) -> None:
        """Record a forward pass of the model into a new CUDA graph.

        Args:
            input_tokens_ids (torch.Tensor): Input token ids; kept as a replay input buffer.
            output_tensor (torch.Tensor): Output placeholder; kept as a replay input buffer.
            inputmetadata (InputMetaData): Step metadata. Its ``block_tables`` and
                ``sequence_lengths`` tensors are kept as replay input buffers.
            k_caches (List[torch.Tensor], optional): Key caches passed to the model.
            v_caches (List[torch.Tensor], optional): Value caches passed to the model.
            memory_pool (optional): Pool passed to ``torch.cuda.graph``. Defaults to None.

        Raises:
            AssertionError: If a graph has already been captured by this runner.
        """
        assert self.graph is None, "a CUDA graph has already been captured by this runner"

        # run kernel once to cache the kernel, avoid stream capture error
        self.model(
            input_tokens_ids,
            output_tensor,
            inputmetadata,
            k_caches,
            v_caches,
        )
        torch.cuda.synchronize()

        # Capture the graph.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states = self.model(
                input_tokens_ids,
                output_tensor,
                inputmetadata,
                k_caches,
                v_caches,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers, because replay always uses the same virtual memory space
        self.input_buffers = {
            "input_tokens_ids": input_tokens_ids,
            "output_tensor": output_tensor,
            "block_tables": inputmetadata.block_tables,
            "sequence_lengths": inputmetadata.sequence_lengths,
            "k_caches": k_caches,
            "v_caches": v_caches,
        }
        self.output_buffers = {"logits": hidden_states}

    def forward(
        self,
        input_tokens_ids: torch.Tensor,
        output_tensor: torch.Tensor,
        inputmetadata: "InputMetaData",
        k_caches: List[torch.Tensor] = None,
        v_caches: List[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Replay the captured graph on new inputs.

        Args:
            input_tokens_ids (torch.Tensor): Input token ids for this step.
            output_tensor (torch.Tensor): Output placeholder for this step.
            inputmetadata (InputMetaData): Step metadata; only ``block_tables`` and
                ``sequence_lengths`` are read.
            k_caches (List[torch.Tensor], optional): Unused; the tensors given to
                ``capture`` are the ones the graph operates on.
            v_caches (List[torch.Tensor], optional): Unused, same as ``k_caches``.

        Returns:
            torch.Tensor: The output tensor recorded at capture time. The same tensor
            object is returned on every call.

        Raises:
            RuntimeError: If ``capture`` has not been called yet.
        """
        if self.graph is None:
            raise RuntimeError("CUDAGraphRunner.forward() called before capture()")

        # Copy the input tensors to the input buffers.
        buffers = self.input_buffers
        buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True)
        buffers["output_tensor"].copy_(output_tensor, non_blocking=True)
        buffers["block_tables"].copy_(inputmetadata.block_tables, non_blocking=True)
        buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True)

        # KV caches are fixed tensors, so we don't need to copy them.

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["logits"]

    def __call__(self, *args, **kwargs):
        """Forward all arguments to ``forward``."""
        return self.forward(*args, **kwargs)
Loading