Skip to content

Commit

Permalink
[torch.compile] use empty tensor instead of None for profiling (#8875)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Sep 27, 2024
1 parent 8df2dc3 commit a9b15c6
Show file tree
Hide file tree
Showing 15 changed files with 84 additions and 32 deletions.
8 changes: 6 additions & 2 deletions tests/kernels/test_encoder_decoder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ class that Attention will automatically select when it is constructed.
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(scale, attn_backend, attn, None)
return TestResources(
scale, attn_backend, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
Expand Down Expand Up @@ -620,7 +622,9 @@ def _run_encoder_attention_test(
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
None,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)

Expand Down
6 changes: 4 additions & 2 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
Expand All @@ -373,7 +375,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

Expand All @@ -399,7 +401,7 @@ def forward(
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.

assert kv_cache is None \
assert kv_cache.numel() == 0 \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0, \
"Does not support prefix-enabled attention."
Expand Down
6 changes: 4 additions & 2 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
Expand All @@ -685,7 +687,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]

Expand Down Expand Up @@ -722,7 +724,7 @@ def forward(

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
Expand Down
6 changes: 3 additions & 3 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
Expand All @@ -770,7 +770,7 @@ def forward(
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None:
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
Expand All @@ -796,7 +796,7 @@ def forward(
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
if kv_cache.numel() == 0:
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
Expand Down
9 changes: 6 additions & 3 deletions vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
Expand All @@ -180,6 +180,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
Expand All @@ -196,7 +198,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = self.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
ipex_ops.reshape_and_cache(
Expand All @@ -212,7 +214,8 @@ def forward(

if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
Expand Down
12 changes: 7 additions & 5 deletions vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
Expand All @@ -155,8 +155,10 @@ def forward(
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
Expand All @@ -173,7 +175,7 @@ def forward(
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)

if kv_cache[0] is not None:
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
Expand Down Expand Up @@ -205,7 +207,7 @@ def forward(
output = output.permute(0, 2, 1, 3)
else:
# Decoding run.
assert kv_cache is not None
assert kv_cache[0].numel() > 0

pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
if self.megacore_mode == "batch" and batch_size % 2 != 0:
Expand Down
6 changes: 4 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
Expand All @@ -412,7 +414,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

Expand Down Expand Up @@ -449,7 +451,7 @@ def forward(
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
Expand Down
9 changes: 6 additions & 3 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
Expand All @@ -164,6 +164,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
Expand All @@ -180,7 +182,7 @@ def forward(
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(key, value, key_cache,
Expand All @@ -191,7 +193,8 @@ def forward(

if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
Expand Down
8 changes: 5 additions & 3 deletions vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def forward(
query: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
Expand Down Expand Up @@ -489,6 +489,8 @@ def forward(
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
Expand Down Expand Up @@ -522,7 +524,7 @@ def forward(
# which KV cache memory-mapping & which
# seqlen datastructures we utilize

if (attn_type != AttentionType.ENCODER and kv_cache is not None):
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
Expand Down Expand Up @@ -588,7 +590,7 @@ def forward(

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
Expand Down
8 changes: 7 additions & 1 deletion vllm/worker/embedding_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@ def execute_model(
model_executable = self.model

num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers

execute_model_kwargs = {
"input_ids":
Expand Down
8 changes: 7 additions & 1 deletion vllm/worker/enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,13 @@ def profile_run(self) -> None:

# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
Expand Down
8 changes: 7 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,13 @@ def profile_run(self) -> None:

# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
Expand Down
4 changes: 2 additions & 2 deletions vllm/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def forward(
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Expand Down Expand Up @@ -745,7 +745,7 @@ def forward(
)

# Skip this in memory profiling at initialization.
if kv_caches[0][0] is not None:
if kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
Expand Down
10 changes: 9 additions & 1 deletion vllm/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

kv_caches = [(None, None) for _ in range(num_layers)]
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [(torch.tensor([], dtype=torch.float32,
device=self.device),
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
self.model_runner._dummy_run(
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
Expand Down
8 changes: 7 additions & 1 deletion vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,13 @@ def profile_run(self) -> None:

# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
Expand Down

0 comments on commit a9b15c6

Please sign in to comment.