Skip to content

Commit

Permalink
[spec decode] [4/N] Move update_flash_attn_metadata to attn backend (#…
Browse files Browse the repository at this point in the history
…7571)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
  • Loading branch information
SolitaryThinker and comaniac authored Aug 16, 2024
1 parent 855866c commit f366f63
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 33 deletions.
3 changes: 3 additions & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def copy_blocks(
) -> None:
raise NotImplementedError

def advance_step(self, num_seqs: int, num_queries: int):
raise NotImplementedError


@dataclass
class AttentionMetadata:
Expand Down
45 changes: 45 additions & 0 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,51 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self, num_seqs: int, num_queries: int):
"""
Update metadata in-place to advance one decode step.
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.

# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )

assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )

assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )

assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs

# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)


class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
Expand Down
34 changes: 1 addition & 33 deletions vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,38 +97,6 @@ def __init__(
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None

def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
num_queries):
assert isinstance(attn_metadata, FlashAttentionMetadata)

if num_seqs != num_queries:
assert num_seqs > num_queries
assert attn_metadata.use_cuda_graph

assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
assert attn_metadata.num_decode_tokens == num_seqs
assert attn_metadata.slot_mapping.shape == (num_seqs, )

assert len(attn_metadata.seq_lens) == num_seqs
assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
assert attn_metadata.max_query_len == 1
assert attn_metadata.max_prefill_seq_len == 0
assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)

assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )

assert attn_metadata.context_lens_tensor.shape == (num_queries, )

assert attn_metadata.block_tables.shape[0] == num_seqs

# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
attn_metadata.seq_lens[i] += 1
attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)

def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):

Expand Down Expand Up @@ -166,7 +134,7 @@ def _gpu_advance_step(
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
attn_metadata.advance_step(num_seqs, num_queries)

# Update GPU tensors
ops.advance_step(num_seqs=num_seqs,
Expand Down

0 comments on commit f366f63

Please sign in to comment.