diff --git a/vllm/envs.py b/vllm/envs.py
index 5b8a65bd6545..595058bcbb02 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -174,6 +174,10 @@ def get_default_config_root():
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Internal flag to enable Dynamo graph capture
+    "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
+    lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 777344289958..f9c26e0c318b 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -23,6 +23,7 @@
     BatchPrefillWithPagedKVCacheWrapper = None
     FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
+import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
@@ -786,6 +787,11 @@ def load_model(self) -> None:
                     "provided. Defaulting to scaling factors of 1.0. "
                     "This may lead to less accurate results!")
 
+        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+            self.model = torch.compile(self.model,
+                                       fullgraph=True,
+                                       backend="eager")
+
     def save_sharded_state(
         self,
         path: str,
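
Note (not part of the diff): a minimal standalone sketch of what the new flag does, using a hypothetical ToyModel as a stand-in for the model loaded in load_model(). Setting VLLM_TEST_DYNAMO_GRAPH_CAPTURE=1 wraps the model in torch.compile with fullgraph=True (Dynamo errors out on any graph break) and backend="eager" (the captured graph still runs with ordinary PyTorch kernels), so only graph capture is exercised.

```python
import os

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for the model assigned to self.model in vLLM."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = ToyModel()

# Mirrors the env var parsing added to vllm/envs.py: "0" disables,
# any non-zero integer enables full-graph Dynamo capture.
if int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")):
    # fullgraph=True makes Dynamo raise on graph breaks instead of
    # silently falling back; backend="eager" skips Inductor and runs
    # the captured graph with regular eager kernels, so numerics are
    # unchanged and only capture itself is being tested.
    model = torch.compile(model, fullgraph=True, backend="eager")

out = model(torch.randn(2, 16))
```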