Commit
move to util
mgoin committed Jul 22, 2024
1 parent 506ed48 commit 98e40d2
Showing 3 changed files with 35 additions and 59 deletions.
31 changes: 2 additions & 29 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -12,7 +12,7 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
@@ -56,37 +56,10 @@ def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def _is_layer_skipped(self, prefix: str) -> bool:
        # prefix: model.layers.0.self_attn.q_proj
        # proj_name: q_proj
        proj_name = prefix.split(".")[-1]
        if proj_name in FUSED_LAYER_NAME_MAPPING:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = shard_prefix in self.ignore_list

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "must have the same precision.")
        else:
            is_skipped = prefix in self.ignore_list

        assert is_skipped is not None
        return is_skipped

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
            if is_layer_skipped(prefix, self.ignore_list):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None
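For context, a minimal usage sketch of the shared helper as the FBGEMM config now calls it. This assumes vLLM is installed; the ignore_list entry and layer names are hypothetical, and the unfused module names used here are assumed not to be keys of FUSED_LAYER_NAME_MAPPING.

# Minimal sketch (assumptions: vLLM installed; ignore_list entry is hypothetical).
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

ignore_list = ["model.layers.0.mlp.down_proj"]  # hypothetical unquantized module
# Skipped layer -> get_quant_method falls back to UnquantizedLinearMethod()
print(is_layer_skipped("model.layers.0.mlp.down_proj", ignore_list))     # True
# Non-skipped layer -> get_quant_method returns FBGEMMFp8LinearMethod(self)
print(is_layer_skipped("model.layers.0.self_attn.o_proj", ignore_list))  # False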
33 changes: 3 additions & 30 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -15,7 +15,7 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
@@ -68,44 +68,17 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys(config, ["ignored_layers"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def _is_layer_skipped(self, prefix: str) -> bool:
        # prefix: model.layers.0.self_attn.q_proj
        # proj_name: q_proj
        proj_name = prefix.split(".")[-1]
        if proj_name in FUSED_LAYER_NAME_MAPPING:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = shard_prefix in self.ignored_layers

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "must have the same precision.")
        else:
            is_skipped = prefix in self.ignored_layers

        assert is_skipped is not None
        return is_skipped

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
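The from_config change above switches ignored_layers from get_from_keys to get_from_keys_or, so checkpoints whose quantization config omits "ignored_layers" no longer raise. Below is a rough sketch of the assumed semantics of these two base-class helpers, written as an approximation for illustration rather than copied from vLLM's QuantizationConfig.

# Approximate behavior of the QuantizationConfig helpers (assumption, not the
# vLLM source): get_from_keys raises if no key is present, while
# get_from_keys_or falls back to a default instead.
from typing import Any, Dict, List

def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
    for key in keys:
        if key in config:
            return config[key]
    raise ValueError(f"Cannot find any of {keys} in the quantization config.")

def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
    try:
        return get_from_keys(config, keys)
    except ValueError:
        return default

# With this change, a config without "ignored_layers" yields None instead of
# raising during Fp8Config.from_config.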
30 changes: 30 additions & 0 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -1,4 +1,6 @@
"""This file is used for /tests and /benchmarks"""
from typing import List

import numpy
import torch

@@ -14,6 +16,34 @@
}


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision.")
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped


def get_pack_factor(num_bits):
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
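A short sketch of the fused-layer branch of the new helper. It assumes FUSED_LAYER_NAME_MAPPING contains an entry mapping "qkv_proj" to ["q_proj", "k_proj", "v_proj"] (the mapping's contents are not shown in this diff), and the ignore lists are made up for illustration.

# Illustrative only: how the fused-layer branch resolves shard prefixes.
# Assumes "qkv_proj" -> ["q_proj", "k_proj", "v_proj"] in FUSED_LAYER_NAME_MAPPING;
# the ignore lists below are hypothetical.
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

prefix = "model.layers.0.self_attn.qkv_proj"

# All shards ignored -> the fused layer is skipped (returns True).
all_ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
assert is_layer_skipped(prefix, all_ignored)

# Only some shards ignored -> mixed precision within one fused layer,
# so a ValueError is raised.
try:
    is_layer_skipped(prefix, ["model.layers.0.self_attn.q_proj"])
except ValueError as e:
    print(e)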
