diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 501743c887596..c436d2b48d20f 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -50,7 +50,9 @@ steps:
   - tests/multimodal
   - tests/test_utils
   - tests/worker
+  - tests/test_lazy_torch_compile.py
   commands:
+  - python3 test_lazy_torch_compile.py
   - pytest -v -s mq_llm_engine # MQLLMEngine
   - pytest -v -s async_engine # AsyncLLMEngine
   - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index 061a077592e80..8ebd8dd2be0d5 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -1,4 +1,4 @@
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.platforms import current_platform
 
 
@@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool:
     capability = current_platform.get_device_capability()
     assert capability is not None
 
-    min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
+    min_capability = get_quantization_config(quant_method).get_min_capability()
 
     return capability.to_int() >= min_capability
diff --git a/tests/test_lazy_torch_compile.py b/tests/test_lazy_torch_compile.py
new file mode 100644
index 0000000000000..b8ac4dd93732b
--- /dev/null
+++ b/tests/test_lazy_torch_compile.py
@@ -0,0 +1,68 @@
+# Description: Test the lazy import module
+# The utility function cannot be placed in `vllm.utils`
+# this needs to be a standalone script
+
+import contextlib
+import dataclasses
+import sys
+import traceback
+from typing import Callable, Generator
+
+
+@dataclasses.dataclass
+class BlameResult:
+    found: bool = False
+    trace_stack: str = ""
+
+
+@contextlib.contextmanager
+def blame(func: Callable) -> Generator[BlameResult, None, None]:
+    """
+    Trace the function calls to find the first function that satisfies the
+    condition. The trace stack will be stored in the result.
+
+    Usage:
+
+    ```python
+    with blame(lambda: some_condition()) as result:
+        # do something
+
+    if result.found:
+        print(result.trace_stack)
+    """
+    result = BlameResult()
+
+    def _trace_calls(frame, event, arg=None):
+        nonlocal result
+        if event in ['call', 'return']:
+            # for every function call or return
+            try:
+                # Temporarily disable the trace function
+                sys.settrace(None)
+                # check condition here
+                if not result.found and func():
+                    result.found = True
+                    result.trace_stack = "".join(traceback.format_stack())
+                # Re-enable the trace function
+                sys.settrace(_trace_calls)
+            except NameError:
+                # modules are deleted during shutdown
+                pass
+        return _trace_calls
+
+    sys.settrace(_trace_calls)
+
+    yield result
+
+    sys.settrace(None)
+
+
+module_name = "torch._inductor.async_compile"
+
+with blame(lambda: module_name in sys.modules) as result:
+    import vllm  # noqa
+
+assert not result.found, (f"Module {module_name} is already imported, the"
+                          f" first import location is:\n{result.trace_stack}")
+
+print(f"Module {module_name} is not imported yet")
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 782dc6aed1b8c..41892e4dddf7e 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -19,9 +19,6 @@
 except ImportError as e:
     logger.warning("Failed to import from vllm._C with %r", e)
 
-if current_platform.is_rocm():
-    import vllm._rocm_C  # noqa: F401
-
 supports_moe_ops = False
 with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401
diff --git a/vllm/config.py b/vllm/config.py
index 3d0c616868225..7522486782cc9 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -14,7 +14,8 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
+                                                     get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
@@ -370,7 +371,7 @@ def _parse_quant_hf_config(self):
         return quant_cfg
 
     def _verify_quantization(self) -> None:
-        supported_quantization = [*QUANTIZATION_METHODS]
+        supported_quantization = QUANTIZATION_METHODS
         rocm_supported_quantization = [
             "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
             "fbgemm_fp8"
@@ -392,7 +393,8 @@ def _verify_quantization(self) -> None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for _, method in QUANTIZATION_METHODS.items():
+            for name in QUANTIZATION_METHODS:
+                method = get_quantization_config(name)
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index ff342c4f9479e..dd10c434f0752 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -1,65 +1,87 @@
-from typing import Dict, Type
+from typing import Dict, List, Type
 
-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.bitsandbytes import (
-    BitsAndBytesConfig)
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsConfig)
-from vllm.model_executor.layers.quantization.deepspeedfp import (
-    DeepSpeedFPConfig)
-from vllm.model_executor.layers.quantization.experts_int8 import (
-    ExpertsInt8Config)
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
-    GPTQMarlin24Config)
-from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
-from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
-from vllm.model_executor.layers.quantization.neuron_quant import (
-    NeuronQuantConfig)
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
-    "aqlm": AQLMConfig,
-    "awq": AWQConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
-    "fp8": Fp8Config,
-    "fbgemm_fp8": FBGEMMFp8Config,
-    "modelopt": ModelOptFp8Config,
+QUANTIZATION_METHODS: List[str] = [
+    "aqlm",
+    "awq",
+    "deepspeedfp",
+    "tpu_int8",
+    "fp8",
+    "fbgemm_fp8",
+    "modelopt",
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
-    "marlin": MarlinConfig,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "hqq": HQQMarlinConfig,
-    "experts_int8": ExpertsInt8Config,
-    "neuron_quant": NeuronQuantConfig,
-    "ipex": IPEXConfig,
-}
+    "marlin",
+    "gguf",
+    "gptq_marlin_24",
+    "gptq_marlin",
+    "awq_marlin",
+    "gptq",
+    "compressed-tensors",
+    "bitsandbytes",
+    "qqq",
+    "hqq",
+    "experts_int8",
+    "neuron_quant",
+    "ipex",
+]
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
         raise ValueError(f"Invalid quantization method: {quantization}")
-    return QUANTIZATION_METHODS[quantization]
+
+    # lazy import to avoid triggering `torch.compile` too early
+    from .aqlm import AQLMConfig
+    from .awq import AWQConfig
+    from .awq_marlin import AWQMarlinConfig
+    from .bitsandbytes import BitsAndBytesConfig
+    from .compressed_tensors.compressed_tensors import (  # noqa: E501
+        CompressedTensorsConfig)
+    from .deepspeedfp import DeepSpeedFPConfig
+    from .experts_int8 import ExpertsInt8Config
+    from .fbgemm_fp8 import FBGEMMFp8Config
+    from .fp8 import Fp8Config
+    from .gguf import GGUFConfig
+    from .gptq import GPTQConfig
+    from .gptq_marlin import GPTQMarlinConfig
+    from .gptq_marlin_24 import GPTQMarlin24Config
+    from .hqq_marlin import HQQMarlinConfig
+    from .ipex_quant import IPEXConfig
+    from .marlin import MarlinConfig
+    from .modelopt import ModelOptFp8Config
+    from .neuron_quant import NeuronQuantConfig
+    from .qqq import QQQConfig
+    from .tpu_int8 import Int8TpuConfig
+
+    method_to_config: Dict[str, Type[QuantizationConfig]] = {
+        "aqlm": AQLMConfig,
+        "awq": AWQConfig,
+        "deepspeedfp": DeepSpeedFPConfig,
+        "tpu_int8": Int8TpuConfig,
+        "fp8": Fp8Config,
+        "fbgemm_fp8": FBGEMMFp8Config,
+        "modelopt": ModelOptFp8Config,
+        # The order of gptq methods is important for config.py iteration over
+        # override_quantization_method(..)
+        "marlin": MarlinConfig,
+        "gguf": GGUFConfig,
+        "gptq_marlin_24": GPTQMarlin24Config,
+        "gptq_marlin": GPTQMarlinConfig,
+        "awq_marlin": AWQMarlinConfig,
+        "gptq": GPTQConfig,
+        "compressed-tensors": CompressedTensorsConfig,
+        "bitsandbytes": BitsAndBytesConfig,
+        "qqq": QQQConfig,
+        "hqq": HQQMarlinConfig,
+        "experts_int8": ExpertsInt8Config,
+        "neuron_quant": NeuronQuantConfig,
+        "ipex": IPEXConfig,
+    }
+
+    return method_to_config[quantization]
 
 
 __all__ = [
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 7ea2f9be2191d..5d38b4b1ef14b 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -19,8 +19,8 @@
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
-from vllm.model_executor.layers.quantization import (AWQConfig,
-                                                     QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                    InternVisionPatchModel)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 0ac81387b1bd8..531608a877f2f 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -51,9 +51,10 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.layers.quantization import (GPTQConfig,
-                                                     GPTQMarlinConfig,
-                                                     QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 9c5212ace1346..d2911ef650743 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -10,6 +10,8 @@
 import torch
 from typing_extensions import ParamSpec
 
+# import custom ops, trigger op registration
+import vllm._C  # noqa
 from vllm.logger import init_logger
 
 from .interface import DeviceCapability, Platform, PlatformEnum
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 022256996f97b..bb3a49c8b73bc 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -9,6 +9,17 @@
 
 logger = init_logger(__name__)
 
+try:
+    import vllm._C  # noqa: F401
+except ImportError as e:
+    logger.warning("Failed to import from vllm._C with %r", e)
+
+# import custom ops, trigger op registration
+try:
+    import vllm._rocm_C  # noqa: F401
+except ImportError as e:
+    logger.warning("Failed to import from vllm._rocm_C with %r", e)
+
 if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
     logger.warning("`fork` method is not supported by ROCm. "
                    "VLLM_WORKER_MULTIPROC_METHOD is overridden to"
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index dc183dbfc9b96..d5056b18fe968 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional
 
@@ -18,6 +19,14 @@ def load_general_plugins():
     processes. They should be designed in a way that they can be loaded
     multiple times without causing issues.
     """
+
+    # all processes created by vllm will load plugins,
+    # and here we can inject some common environment variables
+    # for all processes.
+
+    # see https://github.com/vllm-project/vllm/issues/10480
+    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
+
     global plugins_loaded
     if plugins_loaded:
        return