diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
index 8a86066df910e..93ee724ae45c0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -5,6 +5,9 @@
 from pydantic import BaseModel, Field
 from torch.nn import Module
 
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    FUSED_LAYER_NAME_MAPPING)
+
 
 class CompressionFormat(Enum):
     dense = "dense"
@@ -87,13 +90,6 @@ def is_activation_quantization_format(format: str) -> bool:
     return format in _ACTIVATION_QUANTIZATION_FORMATS
 
 
-# fused_name: List[shard_name]
-_FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
-
 def should_ignore_layer(layer_name: Optional[str],
                         ignore: Iterable[str]) -> bool:
     if layer_name is None:
@@ -107,8 +103,8 @@ def should_ignore_layer(layer_name: Optional[str],
     # in the safetensors checkpoint. So, we convert the name
     # from the fused version to unfused + check to make sure that
     # each shard of the fused layer has the same scheme.
-    if proj_name in _FUSED_LAYER_NAME_MAPPING:
-        shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
 
         # Convert fused_name --> [shard_names]
         shard_names = [
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 6b329231ec3af..5e8d1f1947421 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -11,6 +11,8 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
@@ -18,14 +20,6 @@
 logger = init_logger(__name__)
 
-# Note: this is a hack. We should update each model to register the
-# stacked params and get it from there instead in a future PR.
-# fused_name: List[shard_name]
-_FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
 
 class FBGEMMFp8Config(QuantizationConfig):
     """Config class for FBGEMM Fp8."""
 
@@ -62,37 +56,10 @@ def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
         input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
         return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
 
-    def _is_layer_skipped(self, prefix: str) -> bool:
-        # prefix: model.layers.0.self_attn.q_proj
-        # proj_name: q_proj
-        proj_name = prefix.split(".")[-1]
-        if proj_name in _FUSED_LAYER_NAME_MAPPING:
-            shard_prefixes = [
-                prefix.replace(proj_name, shard_proj_name)
-                for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
-            ]
-
-            is_skipped = None
-            for shard_prefix in shard_prefixes:
-                is_shard_skipped = shard_prefix in self.ignore_list
-
-                if is_skipped is None:
-                    is_skipped = is_shard_skipped
-                elif is_shard_skipped != is_skipped:
-                    raise ValueError(
-                        f"Detected some but not all shards of {prefix} "
-                        "are quantized. All shards of fused layers "
-                        "to have the same precision.")
-        else:
-            is_skipped = prefix in self.ignore_list
-
-        assert is_skipped is not None
-        return is_skipped
-
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if self._is_layer_skipped(prefix):
+            if is_layer_skipped(prefix, self.ignore_list):
                 return UnquantizedLinearMethod()
             return FBGEMMFp8LinearMethod(self)
         return None
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b2a1b0a9534e8..3a4f2a49a3497 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -8,12 +8,15 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   fused_moe)
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
     cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
@@ -33,6 +36,7 @@ def __init__(
         self,
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
@@ -42,6 +46,7 @@ def __init__(
             raise ValueError(
                 f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
+        self.ignored_layers = ignored_layers or []
 
     @classmethod
     def get_name(cls) -> str:
@@ -64,14 +69,18 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_fp8_serialized = ("fp8" in quant_method)
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
-                   activation_scheme=activation_scheme)
+                   activation_scheme=activation_scheme,
+                   ignored_layers=ignored_layers)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
         if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 7abe919f859ca..2ba6a9a810ec0 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -1,10 +1,48 @@
 """This file is used for /tests and /benchmarks"""
+from typing import List
+
 import numpy
 import torch
 
 SUPPORTED_NUM_BITS = [4, 8]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
+# Note: this is a hack. We should update each model to register the
+# stacked params and get it from there instead in a future PR.
+# fused_name: List[shard_name]
+FUSED_LAYER_NAME_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"]
+}
+
+
+def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision.")
+    else:
+        is_skipped = prefix in ignored_layers
+
+    assert is_skipped is not None
+    return is_skipped
+
 
 def get_pack_factor(num_bits):
     assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
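
For reference, a minimal usage sketch (not part of the diff) of the shared helper introduced in quant_utils.py above. The import path matches the diff; the layer prefixes and ignore list are hypothetical examples, not taken from any real checkpoint.

# Illustrative only: exercises is_layer_skipped as added in quant_utils.py.
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# All shards of the fused qkv_proj appear in the ignore list, so the fused
# module is treated as skipped and gets an UnquantizedLinearMethod.
assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored)

# An unfused layer is matched against the ignore list directly.
assert not is_layer_skipped("model.layers.0.mlp.down_proj", ignored)

# Listing only some shards of a fused layer would raise a ValueError instead
# of silently mixing precisions within one fused module.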