Skip to content

Commit

Permalink
extend cuda graph size for H200 (#7894)
Browse files Browse the repository at this point in the history
Co-authored-by: youkaichao <youkaichao@126.com>
  • Loading branch information
kushanam and youkaichao authored Aug 29, 2024
1 parent 6b34215 commit c334b18
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,14 @@

LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]
_NUM_WARMUP_ITERS = 2

Expand Down Expand Up @@ -660,7 +664,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
def _use_captured_graph(self, batch_size: int,
max_decode_seq_len: int) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and batch_size <= self.runner.max_batchsize_to_capture
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

def build(self) -> ModelInputForGPU:
Expand Down Expand Up @@ -846,6 +850,8 @@ def __init__(
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.max_batchsize_to_capture = _get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)

self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
Expand All @@ -863,7 +869,7 @@ def __init__(
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
Expand Down Expand Up @@ -1218,7 +1224,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
start_time = time.perf_counter()

# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()

Expand Down Expand Up @@ -1246,8 +1252,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
None
] * self.parallel_config.pipeline_parallel_size

graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs)
graph_batch_size = self.max_batchsize_to_capture
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
Expand Down Expand Up @@ -1673,3 +1678,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _get_max_graph_batch_size(max_num_seqs: int) -> int:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling _get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
if not, it means the padded size is larger than the largest size in
_BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
"""
padded_size = _get_graph_batch_size(max_num_seqs)
if padded_size in _BATCH_SIZES_TO_CAPTURE:
return padded_size
assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
return _BATCH_SIZES_TO_CAPTURE[-1]

0 comments on commit c334b18

Please sign in to comment.