From 9aaf14c62e16a7c74b5192a44d01a78125dab2fc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 12:09:42 -0700 Subject: [PATCH] [misc] add forward context for attention (#9029) --- tests/kernels/test_flash_attn.py | 56 +--- vllm/attention/backends/flash_attn.py | 429 ++++++++++--------------- vllm/attention/backends/flashinfer.py | 4 +- vllm/forward_context.py | 22 ++ vllm/spec_decode/draft_model_runner.py | 22 +- vllm/worker/embedding_model_runner.py | 4 +- vllm/worker/enc_dec_model_runner.py | 24 +- vllm/worker/model_runner.py | 23 +- 8 files changed, 250 insertions(+), 334 deletions(-) create mode 100644 vllm/forward_context.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 71f61c19dd951..3e9b4d9a4f8a0 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -3,9 +3,9 @@ import pytest import torch -import vllm.attention.backends.flash_attn # noqa: F401 -from tests.kernels.utils import opcheck from vllm.utils import seed_everything +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] @@ -112,10 +112,10 @@ def test_flash_attn_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query=query.unsqueeze(1), - key_cache=key_cache, - value_cache=value_cache, + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, softmax_scale=scale, causal=True, block_table=block_tables, @@ -123,25 +123,6 @@ def test_flash_attn_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ).squeeze(1) - if num_blocks <= 2048: - test_utils = ["test_faketensor", "test_schema"] - else: - test_utils = ["test_faketensor"] - - opcheck(torch.ops.vllm.flash_attn_with_kvcache, - args=tuple(), - kwargs=dict( - decode_query=query.unsqueeze(1), - key_cache=key_cache, - value_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) - ref_output = ref_paged_attn( query=query, key_cache=key_cache, @@ -213,7 +194,7 @@ def test_varlen_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = torch.ops.vllm.flash_attn_varlen_func( + output = flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, @@ -228,29 +209,6 @@ def test_varlen_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ) - if num_blocks <= 2048: - test_utils = ["test_faketensor", "test_schema"] - else: - test_utils = ["test_faketensor"] - - opcheck(torch.ops.vllm.flash_attn_varlen_func, - args=tuple(), - kwargs=dict( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) - ref_output = ref_paged_attn( query=query, key_cache=key_cache, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bb8ab1e3c8c26..bba80262e52d3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,152 +13,15 @@ compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.forward_context import get_forward_context from 
vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) -# yapf: disable -from vllm.vllm_flash_attn import ( - flash_attn_varlen_func as _flash_attn_varlen_func) -from vllm.vllm_flash_attn import ( - flash_attn_with_kvcache as _flash_attn_with_kvcache) - -# yapf: enable - - -@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) -def flash_attn_varlen_func( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Optional[List[int]] = None, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, -) -> torch.Tensor: - # custom op does not support tuple input - real_window_size: Tuple[int, int] - if window_size is None: - real_window_size = (-1, -1) - else: - assert len(window_size) == 2 - real_window_size = (window_size[0], window_size[1]) - return _flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=causal, - window_size=real_window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - block_table=block_table, - ) - - -@flash_attn_varlen_func.register_fake # type: ignore -def _( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Optional[List[int]] = None, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return torch.empty_like(q) - - -@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[]) -def flash_attn_with_kvcache( - decode_query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cache_seqlens: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - alibi_slopes: Optional[torch.Tensor] = None, - softcap: float = 0.0, -) -> torch.Tensor: - return _flash_attn_with_kvcache( - decode_query, - key_cache, - value_cache, - cache_seqlens=cache_seqlens, - block_table=block_table, - softmax_scale=softmax_scale, - causal=causal, - alibi_slopes=alibi_slopes, - softcap=softcap, - ) - - -@flash_attn_with_kvcache.register_fake # type: ignore -def _( - decode_query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cache_seqlens: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - alibi_slopes: Optional[torch.Tensor] = None, - softcap: float = 0.0, -) -> torch.Tensor: - return torch.empty_like(decode_query) - - -@torch.library.custom_op("vllm::reshape_and_cache_flash", - mutates_args=["kv_cache"]) -def reshape_and_cache_flash( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - """Inductor cannot deal with inplace operations on views. 
- See https://github.com/pytorch/pytorch/issues/131192 - and https://github.com/pytorch/pytorch/issues/130174 - This is a workaround to hide the view operation from the inductor. - """ - return torch.ops._C_cache_ops.reshape_and_cache_flash( - key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype, - k_scale, v_scale) - - -@reshape_and_cache_flash.register_fake # type: ignore -def _( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - pass +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) class FlashAttentionBackend(AttentionBackend): @@ -721,118 +584,182 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - torch.ops.vllm.reshape_and_cache_flash( - key, - value, - kv_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - k_scale, - v_scale, - ) + output = torch.ops.vllm.unified_flash_attention( + query, + key, + value, + self.num_heads, + self.head_size, + self.num_kv_heads, + kv_cache, + self.kv_cache_dtype, + k_scale, + v_scale, + self.scale, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, + ) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. 
- prefill_output = torch.ops.vllm.flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - _, num_head, head_dim = decode_query.shape - decode_query = decode_query.reshape(-1, - decode_meta.decode_query_len, - num_head, head_dim) - decode_output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query, - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, + return output + + +@torch.library.custom_op("vllm::unified_flash_attention", + mutates_args=["kv_cache"]) +def unified_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + + current_metadata = get_forward_context() + assert current_metadata is not None + assert isinstance(current_metadata, FlashAttentionMetadata) + attn_metadata: FlashAttentionMetadata = current_metadata + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. 
+ query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ) + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + prefill_output = flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, causal=True, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, ) - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) - - # Chunked prefill does not work with speculative decoding. - # Therefore, the query length for decode should be 1 in chunked prefill. - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + _, num_head, head_dim = decode_query.shape + decode_query = decode_query.reshape(-1, decode_meta.decode_query_len, + num_head, head_dim) + decode_output = flash_attn_with_kvcache( + q=decode_query, + k_cache=key_cache, + v_cache=value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ).squeeze(1) + + if prefill_output is None: + assert decode_output is not None + return decode_output.view(num_decode_tokens, hidden_size) + if decode_output is None: + assert prefill_output is not None + return prefill_output.view(num_prefill_tokens, hidden_size) + + # Chunked prefill does not work with speculative decoding. + # Therefore, the query length for decode should be 1 in chunked prefill. 
+ assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) + + +@unified_flash_attention.register_fake +def _( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(query) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 96d37b99f2013..40e804934cbdd 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -7,7 +7,7 @@ from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - import vllm.attention.backends.flash_attn # noqa + from vllm.vllm_flash_attn import flash_attn_varlen_func FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: BatchDecodeWithPagedKVCacheWrapper = None @@ -799,7 +799,7 @@ def forward( # This happens when vllm runs the profiling to # determine the number of blocks. if kv_cache.numel() == 0: - output = torch.ops.vllm.flash_attn_varlen_func( + output = flash_attn_varlen_func( q=query, k=key, v=value, diff --git a/vllm/forward_context.py b/vllm/forward_context.py new file mode 100644 index 0000000000000..777747505e14a --- /dev/null +++ b/vllm/forward_context.py @@ -0,0 +1,22 @@ +from contextlib import contextmanager +from typing import Any + +_forward_context: Any = None + + +def get_forward_context() -> Any: + """Get the current forward context.""" + return _forward_context + + +@contextmanager +def set_forward_context(context: Any): + """A context manager that stores the current forward context, + can be attention metadata, etc.""" + global _forward_context + prev_context = _forward_context + _forward_context = context + try: + yield + finally: + _forward_context = prev_context diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 71cba5dd25f6a..984747c53c6c0 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,6 +2,7 @@ import torch +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.sampler import SamplerOutput try: @@ -291,16 +292,17 @@ def execute_model( if previous_hidden_states is not None else {} # Run model - hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **kwargs, - ) + with set_forward_context(model_input.attn_metadata): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **kwargs, + ) # Compute the logits. 
logits = self.model.compute_logits(hidden_states, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 1ccf10f1a60da..1fd37eac6b851 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -6,6 +6,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalInputs @@ -119,7 +120,8 @@ def execute_model( device=self.device), } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata): + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 90dfad62e0286..59b4b8c4ddf38 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -198,17 +199,18 @@ def execute_model( } if self.has_seqlen_agnostic else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f44e5113c218d..51f65cbfcf862 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -24,6 +24,7 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture +from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -1499,7 +1500,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - graph_runner.capture(**capture_inputs) + with set_forward_context(attn_metadata): + 
graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( graph_runner) @@ -1641,15 +1643,16 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time):
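
The heart of the change is the new vllm/forward_context.py module: every model runner now wraps its forward pass in `with set_forward_context(model_input.attn_metadata):`, and the flash-attention custom op reads that metadata back via get_forward_context() instead of taking it as an op argument. What follows is a minimal, self-contained sketch of that hand-off, using only the standard library; FakeAttnMetadata and attention_kernel are illustrative stand-ins, not names from the patch.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

# Stand-in for the new vllm/forward_context.py module (same logic as above).
_forward_context: Any = None


def get_forward_context() -> Any:
    """Get the current forward context."""
    return _forward_context


@contextmanager
def set_forward_context(context: Any):
    """Store the forward context (e.g. attention metadata) for one pass."""
    global _forward_context
    prev_context = _forward_context
    _forward_context = context
    try:
        yield
    finally:
        _forward_context = prev_context


# Hypothetical stand-in for FlashAttentionMetadata, for illustration only.
@dataclass
class FakeAttnMetadata:
    num_prefill_tokens: int
    num_decode_tokens: int


def attention_kernel() -> FakeAttnMetadata:
    # Mirrors how unified_flash_attention obtains its metadata: from the
    # forward context, not from its own argument list.
    metadata = get_forward_context()
    assert metadata is not None
    return metadata


if __name__ == "__main__":
    meta = FakeAttnMetadata(num_prefill_tokens=8, num_decode_tokens=2)
    # The model runner wraps the forward pass, matching
    # `with set_forward_context(model_input.attn_metadata):` in the diff.
    with set_forward_context(meta):
        assert attention_kernel() is meta
    # The previous context (None here) is restored once the pass finishes.
    assert get_forward_context() is None

Keeping the metadata out of the op signature matters because torch custom-op schemas only accept tensors, scalars, and a few other simple types, not a Python dataclass like FlashAttentionMetadata.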
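
The backend's forward path is likewise collapsed into a single custom op, vllm::unified_flash_attention, registered through torch.library.custom_op with a register_fake companion so shape inference can run without executing the kernels. Below is a minimal sketch of that registration pattern, assuming PyTorch 2.4 or newer; the demo:: namespace and the toy op body are placeholders for illustration, not vLLM code.

from typing import Optional

import torch


# The op body: in the patch this is where reshape_and_cache_flash and the
# flash-attention kernels run; here it is a toy computation so the example
# stays self-contained.
@torch.library.custom_op("demo::unified_attention", mutates_args=["kv_cache"])
def unified_attention(query: torch.Tensor,
                      kv_cache: torch.Tensor,
                      softmax_scale: float,
                      logits_soft_cap: Optional[float] = None) -> torch.Tensor:
    # Mutate the cache so the mutates_args declaration is truthful.
    kv_cache[0, :query.shape[0]].copy_(query)
    return query * softmax_scale


@unified_attention.register_fake
def _(query: torch.Tensor,
      kv_cache: torch.Tensor,
      softmax_scale: float,
      logits_soft_cap: Optional[float] = None) -> torch.Tensor:
    # The fake (meta) impl only describes the output shape/dtype, exactly as
    # unified_flash_attention's register_fake does in the diff.
    return torch.empty_like(query)


if __name__ == "__main__":
    q = torch.randn(4, 32)
    cache = torch.zeros(2, 8, 32)
    out = torch.ops.demo.unified_attention(q, cache, 0.125)
    print(out.shape)  # torch.Size([4, 32])

Eager calls dispatch to the real body, while FakeTensor tracing uses only the registered fake, which is why it just returns an empty tensor with the query's shape.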