[torch.compile] limit inductor threads and lazy import quant (vllm-project#10482)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
youkaichao authored and mfournioux committed Nov 28, 2024
1 parent 2cfbd36 commit d866125
Showing 11 changed files with 178 additions and 64 deletions.
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -50,7 +50,9 @@ steps:
- tests/multimodal
- tests/test_utils
- tests/worker
- tests/test_lazy_torch_compile.py
commands:
- python3 test_lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
4 changes: 2 additions & 2 deletions tests/quantization/utils.py
@@ -1,4 +1,4 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.platforms import current_platform


@@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool:
capability = current_platform.get_device_capability()
assert capability is not None

min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
min_capability = get_quantization_config(quant_method).get_min_capability()

return capability.to_int() >= min_capability
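
For context, `is_quant_method_supported` is typically used as a test gate. A minimal sketch of such usage (the decorator, test name, and import path here are illustrative, not part of this commit):

```python
import pytest

from tests.quantization.utils import is_quant_method_supported


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="fp8 is not supported on this device")
def test_fp8_quantization_available():
    # Placeholder body: the skipif guard above is the point of the sketch.
    assert is_quant_method_supported("fp8")
```
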
68 changes: 68 additions & 0 deletions tests/test_lazy_torch_compile.py
@@ -0,0 +1,68 @@
# Description: Test the lazy import module
# The utility function cannot be placed in `vllm.utils`;
# this needs to be a standalone script.

import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator


@dataclasses.dataclass
class BlameResult:
found: bool = False
trace_stack: str = ""


@contextlib.contextmanager
def blame(func: Callable) -> Generator[BlameResult, None, None]:
"""
Trace the function calls to find the first function that satisfies the
condition. The trace stack will be stored in the result.
Usage:
```python
with blame(lambda: some_condition()) as result:
# do something
if result.found:
print(result.trace_stack)
"""
result = BlameResult()

def _trace_calls(frame, event, arg=None):
nonlocal result
if event in ['call', 'return']:
# for every function call or return
try:
# Temporarily disable the trace function
sys.settrace(None)
# check condition here
if not result.found and func():
result.found = True
result.trace_stack = "".join(traceback.format_stack())
# Re-enable the trace function
sys.settrace(_trace_calls)
except NameError:
# modules are deleted during shutdown
pass
return _trace_calls

sys.settrace(_trace_calls)

yield result

sys.settrace(None)


module_name = "torch._inductor.async_compile"

with blame(lambda: module_name in sys.modules) as result:
import vllm # noqa

assert not result.found, (f"Module {module_name} is already imported, the"
f" first import location is:\n{result.trace_stack}")

print(f"Module {module_name} is not imported yet")
3 changes: 0 additions & 3 deletions vllm/_custom_ops.py
@@ -19,9 +19,6 @@
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

if current_platform.is_rocm():
import vllm._rocm_C # noqa: F401

supports_moe_ops = False
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
8 changes: 5 additions & 3 deletions vllm/config.py
@@ -14,7 +14,8 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
@@ -370,7 +371,7 @@ def _parse_quant_hf_config(self):
return quant_cfg

def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
supported_quantization = QUANTIZATION_METHODS
rocm_supported_quantization = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8"
@@ -392,7 +393,8 @@ def _verify_quantization(self) -> None:
quant_method = quant_cfg.get("quant_method", "").lower()

# Detect which kind of checkpoint it is
for _, method in QUANTIZATION_METHODS.items():
for name in QUANTIZATION_METHODS:
method = get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
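
A simplified sketch of the detection loop after this change (paraphrased from the hunk above; the helper name is hypothetical, and the real `_verify_quantization` carries additional validation):

```python
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)


def detect_quantization_override(quant_cfg: dict, requested: str) -> str:
    """Illustrative helper: ask each registered config class whether it
    overrides the requested quantization method for this checkpoint."""
    for name in QUANTIZATION_METHODS:
        # The config class is imported lazily, only when the name is resolved.
        method = get_quantization_config(name)
        override = method.override_quantization_method(quant_cfg, requested)
        if override:
            return override
    return requested
```
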
124 changes: 73 additions & 51 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -1,65 +1,87 @@
from typing import Dict, Type
from typing import Dict, List, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.bitsandbytes import (
BitsAndBytesConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.experts_int8 import (
ExpertsInt8Config)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
QUANTIZATION_METHODS: List[str] = [
"aqlm",
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"fbgemm_fp8",
"modelopt",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
}
"marlin",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"gptq",
"compressed-tensors",
"bitsandbytes",
"qqq",
"hqq",
"experts_int8",
"neuron_quant",
"ipex",
]


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
return QUANTIZATION_METHODS[quantization]

# lazy import to avoid triggering `torch.compile` too early
from .aqlm import AQLMConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
from .bitsandbytes import BitsAndBytesConfig
from .compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
from .fp8 import Fp8Config
from .gguf import GGUFConfig
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config
from .neuron_quant import NeuronQuantConfig
from .qqq import QQQConfig
from .tpu_int8 import Int8TpuConfig

method_to_config: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
}

return method_to_config[quantization]


__all__ = [
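
With the registry reduced to a list of names, the heavy config modules (and the torch machinery they pull in) are only imported when a name is actually resolved. A minimal usage sketch of the new API:

```python
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)

# Importing the package above no longer imports every config class.
print(QUANTIZATION_METHODS)  # plain list of method names

# The concrete class (and its torch-heavy imports) loads only on lookup.
fp8_cls = get_quantization_config("fp8")
print(fp8_cls.get_min_capability())
```
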
4 changes: 2 additions & 2 deletions vllm/model_executor/models/internvl.py
@@ -19,8 +19,8 @@
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig,
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
7 changes: 4 additions & 3 deletions vllm/model_executor/models/qwen2_vl.py
@@ -51,9 +51,10 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import (GPTQConfig,
GPTQMarlinConfig,
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2 changes: 2 additions & 0 deletions vllm/platforms/cuda.py
@@ -10,6 +10,8 @@
import torch
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C # noqa
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum
11 changes: 11 additions & 0 deletions vllm/platforms/rocm.py
@@ -9,6 +9,17 @@

logger = init_logger(__name__)

try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
try:
import vllm._rocm_C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e)

if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
logger.warning("`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
9 changes: 9 additions & 0 deletions vllm/plugins/__init__.py
@@ -1,4 +1,5 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional

@@ -18,6 +19,14 @@ def load_general_plugins():
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
"""

# all processes created by vllm will load plugins,
# and here we can inject some common environment variables
# for all processes.

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'

global plugins_loaded
if plugins_loaded:
return
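
Inductor reads `TORCHINDUCTOR_COMPILE_THREADS` when its config module is first imported, which is why the variable is set here before any compilation can start. A hedged sketch of checking the effect in a fresh interpreter (the `compile_threads` attribute name reflects recent PyTorch releases and should be treated as an assumption):

```python
import os

# Must be set before torch._inductor is imported in this process.
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

import torch._inductor.config as inductor_config  # noqa: E402

# In recent PyTorch releases, compile_threads is derived from the env var
# at import time; with the setting above it should report 1.
print(inductor_config.compile_threads)
```
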
