Extend CUDA graph size for H200 #7894

Merged 6 commits on Aug 29, 2024
vllm/worker/model_runner.py (31 additions, 7 deletions)
@@ -60,10 +60,14 @@

LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
-# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
+# All the token sizes that **can** be captured by CUDA graph.
+# They can be arbitrarily large.
+# Currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# The actual sizes to capture are determined at runtime,
+# based on the scheduler config's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]
_NUM_WARMUP_ITERS = 2
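For context on how this list is consumed: _get_graph_batch_size (its tail is visible at the bottom of this diff) rounds a requested batch size up to the nearest capturable size, so every padded value lands in _BATCH_SIZES_TO_CAPTURE. Below is a minimal, self-contained sketch of that rounding; the 1/2/4 pass-through is an assumption based on the "edge cases like 1, 2, 4" note in the new helper's docstring, and pad_batch_size is an illustrative name rather than the function in this file.

_BATCH_SIZE_ALIGNMENT = 8
# 1, 2, 4, then every multiple of 8 up to 8 * 1024 = 8192 (after this change).
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]


def pad_batch_size(batch_size: int) -> int:
    # Round up to the nearest capturable size; the alignment arithmetic
    # mirrors the tail of _get_graph_batch_size shown further down.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


assert pad_batch_size(3) == 4
assert pad_batch_size(100) == 104      # rounded up to a multiple of 8
assert pad_batch_size(8192) == 8192    # new upper bound: 8 * 1024
assert all(pad_batch_size(n) in _BATCH_SIZES_TO_CAPTURE for n in range(1, 8193))

With the list now reaching 8192, large decode batches on high-memory GPUs such as the H200 can stay on the CUDA-graph path instead of hitting the old 256 cap and falling back to regular (non-captured) execution; see _use_captured_graph in the next hunk.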

@@ -659,7 +663,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
def _use_captured_graph(self, batch_size: int,
max_decode_seq_len: int) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
-                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                and batch_size <= self.runner.max_batchsize_to_capture
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

def build(self) -> ModelInputForGPU:
@@ -845,6 +849,8 @@ def __init__(
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
+        self.max_batchsize_to_capture = _get_max_graph_batch_size(
+            self.scheduler_config.max_num_seqs)

self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -862,7 +868,7 @@ def __init__(
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = np.zeros(
-            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
+            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
@@ -1217,7 +1223,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
start_time = time.perf_counter()

# Prepare dummy inputs. These will be reused for all batch sizes.
-        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()

@@ -1245,8 +1251,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
None
] * self.parallel_config.pipeline_parallel_size

-        graph_batch_size = _get_graph_batch_size(
-            self.scheduler_config.max_num_seqs)
+        graph_batch_size = self.max_batchsize_to_capture
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
@@ -1672,3 +1677,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


+def _get_max_graph_batch_size(max_num_seqs: int) -> int:
+    """
+    max_num_seqs: Maximum number of sequences in a batch.
+    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+
+    Pad max_num_seqs if necessary by calling _get_graph_batch_size,
+    which handles some edge cases like 1, 2, 4.
+
+    If the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
+    Otherwise the padded size is larger than the largest size in
+    _BATCH_SIZES_TO_CAPTURE, so return that largest size.
+    """
+    padded_size = _get_graph_batch_size(max_num_seqs)
+    if padded_size in _BATCH_SIZES_TO_CAPTURE:
+        return padded_size
+    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+    return _BATCH_SIZES_TO_CAPTURE[-1]
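Taken together with capture_model above, the sizes actually captured are every entry of _BATCH_SIZES_TO_CAPTURE up to this helper's return value. A rough usage sketch, written as if appended at module level below the new helper; the max_num_seqs values are arbitrary examples, not defaults from this PR.

# Assumes _BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size and
# _get_max_graph_batch_size from this file are in scope.
for max_num_seqs in (64, 256, 10000):
    graph_batch_size = _get_max_graph_batch_size(max_num_seqs)
    batch_size_capture_list = [
        bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
    ]
    print(max_num_seqs, graph_batch_size, batch_size_capture_list[-1])
# 64    -> captures 1, 2, 4, 8, ..., 64
# 256   -> captures 1, 2, 4, 8, ..., 256 (same coverage as before this change)
# 10000 -> padded size exceeds the list, so it is clamped to 8192

Configurations that keep max_num_seqs at or below 256 therefore capture the same sizes as before; only deployments that raise max_num_seqs get a larger capture set, with correspondingly more graphs to capture at startup.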