[torch.compile] limit inductor threads and lazy import quant #10482

Merged: 14 commits, Nov 21, 2024
4 changes: 2 additions & 2 deletions tests/quantization/utils.py
@@ -1,4 +1,4 @@
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.platforms import current_platform


@@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool:
     capability = current_platform.get_device_capability()
     assert capability is not None

-    min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
+    min_capability = get_quantization_config(quant_method).get_min_capability()

     return capability.to_int() >= min_capability
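
For orientation, the change above swaps direct dict indexing for the accessor exported by the quantization package. A minimal usage sketch (illustrative only; the "fp8" argument and the print calls are assumptions, not part of this diff):

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import get_quantization_config

# Listing the supported method names stays cheap: no backend is imported yet.
print(QUANTIZATION_METHODS)

# The fp8 backend module is only imported at this point.
fp8_config_cls = get_quantization_config("fp8")
print(fp8_config_cls.get_min_capability())
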
8 changes: 5 additions & 3 deletions vllm/config.py
@@ -14,7 +14,8 @@

 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
+                                                     get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
@@ -370,7 +371,7 @@ def _parse_quant_hf_config(self):
         return quant_cfg

     def _verify_quantization(self) -> None:
-        supported_quantization = [*QUANTIZATION_METHODS]
+        supported_quantization = QUANTIZATION_METHODS
         rocm_supported_quantization = [
             "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
             "fbgemm_fp8"
@@ -392,7 +393,8 @@ def _verify_quantization(self) -> None:
         quant_method = quant_cfg.get("quant_method", "").lower()

         # Detect which checkpoint is it
-        for _, method in QUANTIZATION_METHODS.items():
+        for name in QUANTIZATION_METHODS:
+            method = get_quantization_config(name)
             quantization_override = method.override_quantization_method(
                 quant_cfg, self.quantization)
             if quantization_override:
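
The detection loop above walks QUANTIZATION_METHODS in order and, per the ordering comment in quantization/__init__.py, the first override that matches is the one that takes effect. A self-contained toy illustrating that behaviour (the classes and the detect() helper are hypothetical stand-ins, not vLLM code):

from typing import List, Optional, Type


class GPTQMarlinLike:
    """Stand-in for a marlin-capable config that upgrades plain gptq checkpoints."""

    @classmethod
    def override_quantization_method(cls, hf_cfg: dict,
                                     user_quant: Optional[str]) -> Optional[str]:
        return "gptq_marlin" if hf_cfg.get("quant_method") == "gptq" else None


class GPTQLike:
    """Stand-in for the plain gptq config."""

    @classmethod
    def override_quantization_method(cls, hf_cfg: dict,
                                     user_quant: Optional[str]) -> Optional[str]:
        return "gptq" if hf_cfg.get("quant_method") == "gptq" else None


def detect(hf_cfg: dict, candidates: List[Type]) -> Optional[str]:
    for method in candidates:  # earlier entries win
        override = method.override_quantization_method(hf_cfg, None)
        if override:
            return override
    return None


# "gptq_marlin" is listed before "gptq", so the marlin path is preferred.
print(detect({"quant_method": "gptq"}, [GPTQMarlinLike, GPTQLike]))  # gptq_marlin
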
124 changes: 73 additions & 51 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -1,65 +1,87 @@
-from typing import Dict, Type
+from typing import Dict, List, Type

-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.bitsandbytes import (
-    BitsAndBytesConfig)
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsConfig)
-from vllm.model_executor.layers.quantization.deepspeedfp import (
-    DeepSpeedFPConfig)
-from vllm.model_executor.layers.quantization.experts_int8 import (
-    ExpertsInt8Config)
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
-    GPTQMarlin24Config)
-from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
-from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
-from vllm.model_executor.layers.quantization.neuron_quant import (
-    NeuronQuantConfig)
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
-    "aqlm": AQLMConfig,
-    "awq": AWQConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
-    "fp8": Fp8Config,
-    "fbgemm_fp8": FBGEMMFp8Config,
-    "modelopt": ModelOptFp8Config,
+QUANTIZATION_METHODS: List[str] = [
+    "aqlm",
+    "awq",
+    "deepspeedfp",
+    "tpu_int8",
+    "fp8",
+    "fbgemm_fp8",
+    "modelopt",
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
-    "marlin": MarlinConfig,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "hqq": HQQMarlinConfig,
-    "experts_int8": ExpertsInt8Config,
-    "neuron_quant": NeuronQuantConfig,
-    "ipex": IPEXConfig,
-}
+    "marlin",
+    "gguf",
+    "gptq_marlin_24",
+    "gptq_marlin",
+    "awq_marlin",
+    "gptq",
+    "compressed-tensors",
+    "bitsandbytes",
+    "qqq",
+    "hqq",
+    "experts_int8",
+    "neuron_quant",
+    "ipex",
+]


 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
         raise ValueError(f"Invalid quantization method: {quantization}")
-    return QUANTIZATION_METHODS[quantization]
+
+    # lazy import to avoid triggering `torch.compile` too early
+    from .aqlm import AQLMConfig
+    from .awq import AWQConfig
+    from .awq_marlin import AWQMarlinConfig
+    from .bitsandbytes import BitsAndBytesConfig
+    from .compressed_tensors.compressed_tensors import (  # noqa: E501
+        CompressedTensorsConfig)
+    from .deepspeedfp import DeepSpeedFPConfig
+    from .experts_int8 import ExpertsInt8Config
+    from .fbgemm_fp8 import FBGEMMFp8Config
+    from .fp8 import Fp8Config
+    from .gguf import GGUFConfig
+    from .gptq import GPTQConfig
+    from .gptq_marlin import GPTQMarlinConfig
+    from .gptq_marlin_24 import GPTQMarlin24Config
+    from .hqq_marlin import HQQMarlinConfig
+    from .ipex_quant import IPEXConfig
+    from .marlin import MarlinConfig
+    from .modelopt import ModelOptFp8Config
+    from .neuron_quant import NeuronQuantConfig
+    from .qqq import QQQConfig
+    from .tpu_int8 import Int8TpuConfig
+
+    method_to_config: Dict[str, Type[QuantizationConfig]] = {
+        "aqlm": AQLMConfig,
+        "awq": AWQConfig,
+        "deepspeedfp": DeepSpeedFPConfig,
+        "tpu_int8": Int8TpuConfig,
+        "fp8": Fp8Config,
+        "fbgemm_fp8": FBGEMMFp8Config,
+        "modelopt": ModelOptFp8Config,
+        # The order of gptq methods is important for config.py iteration over
+        # override_quantization_method(..)
+        "marlin": MarlinConfig,
+        "gguf": GGUFConfig,
+        "gptq_marlin_24": GPTQMarlin24Config,
+        "gptq_marlin": GPTQMarlinConfig,
+        "awq_marlin": AWQMarlinConfig,
+        "gptq": GPTQConfig,
+        "compressed-tensors": CompressedTensorsConfig,
+        "bitsandbytes": BitsAndBytesConfig,
+        "qqq": QQQConfig,
+        "hqq": HQQMarlinConfig,
+        "experts_int8": ExpertsInt8Config,
+        "neuron_quant": NeuronQuantConfig,
+        "ipex": IPEXConfig,
+    }
+
+    return method_to_config[quantization]


 __all__ = [
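
This is the core of the change: the module now exposes only a list of names at import time and defers the backend imports, which (per the comment in the diff) can otherwise trigger `torch.compile` too early, into get_quantization_config. A generic, self-contained sketch of the same lazy-import pattern, with stdlib modules standing in for the heavy backends (none of these names are vLLM code):

BACKENDS = ["json_backend", "decimal_backend"]


def get_backend(name: str):
    if name not in BACKENDS:
        raise ValueError(f"Invalid backend: {name}")

    # Imports are deferred until a backend is actually requested,
    # so merely importing this module stays cheap.
    from decimal import Decimal
    from json import JSONDecoder

    return {"json_backend": JSONDecoder, "decimal_backend": Decimal}[name]


print(get_backend("decimal_backend")("2.5"))  # Decimal('2.5')
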
9 changes: 9 additions & 0 deletions vllm/plugins/__init__.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Optional

@@ -18,6 +19,14 @@ def load_general_plugins():
     processes. They should be designed in a way that they can be loaded
     multiple times without causing issues.
     """
+
+    # all processes created by vllm will load plugins,
+    # and here we can inject some common environment variables
+    # for all processes.
+
+    # see https://github.com/vllm-project/vllm/issues/10480
+    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
+
     global plugins_loaded
     if plugins_loaded:
         return
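
Because load_general_plugins runs in every process vLLM creates, setting TORCHINDUCTOR_COMPILE_THREADS there caps inductor's compile worker pool everywhere (see the linked issue #10480). A quick way to check the effective setting; the torch._inductor.config attribute name is an assumption about PyTorch internals, so verify it against your torch version:

import os

# The cap must be in place before inductor reads its config, which happens
# when torch._inductor.config is first imported (assumption about torch internals).
os.environ.setdefault("TORCHINDUCTOR_COMPILE_THREADS", "1")

import torch._inductor.config as inductor_config  # noqa: E402

print(inductor_config.compile_threads)  # expected to be 1 when the cap is respected
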