Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] limit inductor threads and lazy import quant #10482

Merged
merged 14 commits into from
Nov 21, 2024
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ steps:
- tests/multimodal
- tests/test_utils
- tests/worker
- tests/test_lazy_torch_compile.py
commands:
- python3 test_lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
Expand Down
4 changes: 2 additions & 2 deletions tests/quantization/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.platforms import current_platform


Expand All @@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool:
capability = current_platform.get_device_capability()
assert capability is not None

min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
min_capability = get_quantization_config(quant_method).get_min_capability()

return capability.to_int() >= min_capability
68 changes: 68 additions & 0 deletions tests/test_lazy_torch_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Description: Test the lazy import module
# The utility function cannot be placed in `vllm.utils`
# this needs to be a standalone script

import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator


@dataclasses.dataclass
class BlameResult:
found: bool = False
trace_stack: str = ""


@contextlib.contextmanager
def blame(func: Callable) -> Generator[BlameResult, None, None]:
"""
Trace the function calls to find the first function that satisfies the
condition. The trace stack will be stored in the result.

Usage:

```python
with blame(lambda: some_condition()) as result:
# do something

if result.found:
print(result.trace_stack)
"""
result = BlameResult()

def _trace_calls(frame, event, arg=None):
nonlocal result
if event in ['call', 'return']:
# for every function call or return
try:
# Temporarily disable the trace function
sys.settrace(None)
# check condition here
if not result.found and func():
result.found = True
result.trace_stack = "".join(traceback.format_stack())
# Re-enable the trace function
sys.settrace(_trace_calls)
except NameError:
# modules are deleted during shutdown
pass
return _trace_calls

sys.settrace(_trace_calls)

yield result

sys.settrace(None)


module_name = "torch._inductor.async_compile"

with blame(lambda: module_name in sys.modules) as result:
import vllm # noqa

assert not result.found, (f"Module {module_name} is already imported, the"
f" first import location is:\n{result.trace_stack}")

print(f"Module {module_name} is not imported yet")
3 changes: 0 additions & 3 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401

supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
Expand Down
8 changes: 5 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
Expand Down Expand Up @@ -370,7 +371,7 @@ def _parse_quant_hf_config(self):
return quant_cfg

def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
supported_quantization = QUANTIZATION_METHODS
rocm_supported_quantization = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8"
Expand All @@ -392,7 +393,8 @@ def _verify_quantization(self) -> None:
quant_method = quant_cfg.get("quant_method", "").lower()

# Detect which checkpoint is it
for _, method in QUANTIZATION_METHODS.items():
for name in QUANTIZATION_METHODS:
method = get_quantization_config(name)
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
Expand Down
124 changes: 73 additions & 51 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,87 @@
from typing import Dict, Type
from typing import Dict, List, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.bitsandbytes import (
BitsAndBytesConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.experts_int8 import (
ExpertsInt8Config)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
QUANTIZATION_METHODS: List[str] = [
"aqlm",
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"fbgemm_fp8",
"modelopt",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
}
"marlin",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"gptq",
"compressed-tensors",
"bitsandbytes",
"qqq",
"hqq",
"experts_int8",
"neuron_quant",
"ipex",
]


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
return QUANTIZATION_METHODS[quantization]

# lazy import to avoid triggering `torch.compile` too early
from .aqlm import AQLMConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
from .bitsandbytes import BitsAndBytesConfig
from .compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
from .fp8 import Fp8Config
from .gguf import GGUFConfig
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config
from .neuron_quant import NeuronQuantConfig
from .qqq import QQQConfig
from .tpu_int8 import Int8TpuConfig

method_to_config: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
}

return method_to_config[quantization]


__all__ = [
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import (GPTQConfig,
GPTQMarlinConfig,
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Expand Down
2 changes: 2 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C # noqa
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum
Expand Down
11 changes: 11 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@

logger = init_logger(__name__)

try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
import vllm._rocm_C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e)

if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
logger.warning("`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
Expand Down
9 changes: 9 additions & 0 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional

Expand All @@ -18,6 +19,14 @@ def load_general_plugins():
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
"""

# all processes created by vllm will load plugins,
# and here we can inject some common environment variables
# for all processes.

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'

global plugins_loaded
youkaichao marked this conversation as resolved.
Show resolved Hide resolved
if plugins_loaded:
return
Expand Down