Merge pull request #13 from neuralmagic/awq_moe
add awq moe
ElizaWszola authored Oct 1, 2024
2 parents e0e5a74 + 79126f9 commit 3ff0ba1
Showing 5 changed files with 228 additions and 10 deletions.
1 change: 1 addition & 0 deletions tests/weight_loading/models-large.txt
@@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantize
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
14 changes: 14 additions & 0 deletions vllm/_custom_ops.py
@@ -559,6 +559,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
    return output


def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for e in range(num_experts):
        output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k,
                                                   size_n, num_bits)
    return output


def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
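As a quick sanity check (an illustrative aside, not part of the commit): the new awq_marlin_moe_repack helper calls the existing single-matrix awq_marlin_repack op once per expert, turning a stacked AWQ-packed weight into Marlin's layout. The sketch below only works through the shapes for hypothetical 4-bit sizes; the tensors are zero placeholders rather than real quantized data.

# --- illustrative aside (not part of the commit) ---
import torch

# Hypothetical sizes; AWQ packs 32 // num_bits = 8 four-bit values per int32 along N.
num_experts, size_k, size_n, num_bits = 8, 4096, 14336, 4
pack_factor = 32 // num_bits

# Per-expert AWQ qweight layout, stacked over experts: (E, K, N // pack_factor).
b_q_weight = torch.zeros(num_experts, size_k, size_n // pack_factor,
                         dtype=torch.int32)

# awq_marlin_moe_repack allocates (E, K // 16, N * (num_bits // 2)) per the code above.
marlin_shape = (num_experts, size_k // 16, size_n * (num_bits // 2))
print(tuple(b_q_weight.shape), "->", marlin_shape)  # (8, 4096, 1792) -> (8, 256, 28672)
# --- end aside ---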
204 changes: 195 additions & 9 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -1,16 +1,21 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
    marlin_permute_scales, moe_awq_to_marlin_zero_points,
    verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -35,12 +40,13 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
        self.group_size = group_size
        self.has_zp = has_zp
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits

        if weight_bits not in self.TYPE_MAP:
            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
        if self.weight_bits not in self.TYPE_MAP:
            raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
                             f"Supported num_bits = {self.TYPE_MAP.keys()}")

        self.quant_type = self.TYPE_MAP[weight_bits]
        self.quant_type = self.TYPE_MAP[self.weight_bits]

        verify_marlin_supported(self.quant_type,
                                group_size=self.group_size,
@@ -98,10 +104,12 @@ def override_quantization_method(cls, hf_quant_cfg,
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AWQMarlinLinearMethod"]:
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if (isinstance(layer, LinearBase) or
                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return AWQMarlinLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return AWQMoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
@@ -271,4 +279,182 @@ def apply(
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias)
            bias=bias)


class AWQMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: AWQMarlinConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        extra_weight_attrs.update({
            "is_transposed":
            True,
            "quant_method":
            FusedMoeWeightScaleSupported.GROUP.value,
        })

        w13_qweight = Parameter(torch.empty(num_experts,
                                            hidden_size,
                                            2 * intermediate_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
                                requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        w2_qweight = Parameter(torch.empty(num_experts,
                                           intermediate_size,
                                           hidden_size //
                                           self.quant_config.pack_factor,
                                           dtype=torch.int32),
                               requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        num_groups_w13 = hidden_size // self.quant_config.group_size
        num_groups_w2 = intermediate_size // self.quant_config.group_size

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_scales = Parameter(torch.empty(num_experts,
                                           num_groups_w13,
                                           intermediate_size * 2,
                                           dtype=params_dtype),
                               requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size,
                                          dtype=params_dtype),
                              requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        # WEIGHT_ZERO_POINT
        # Allocate 2 zero points for w1 and w3 respectively.
        w13_qzeros = Parameter(torch.empty(num_experts,
                                           num_groups_w13,
                                           2 * intermediate_size //
                                           self.quant_config.pack_factor,
                                           dtype=torch.int32),
                               requires_grad=False)
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = Parameter(torch.empty(num_experts,
                                          num_groups_w2,
                                          hidden_size //
                                          self.quant_config.pack_factor,
                                          dtype=torch.int32),
                              requires_grad=False)
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
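For orientation (an illustrative aside, not part of the commit), the shapes that create_weights allocates follow directly from pack_factor and group_size. The sizes below are hypothetical examples, assuming 4-bit AWQ (pack factor 32 // 4 = 8) and group size 128.

# --- illustrative aside (not part of the commit) ---
E, H, I, pack, group = 8, 4096, 14336, 8, 128
shapes = {
    "w13_qweight": (E, H, 2 * I // pack),          # (8, 4096, 3584), int32
    "w2_qweight": (E, I, H // pack),               # (8, 14336, 512), int32
    "w13_scales": (E, H // group, 2 * I),          # (8, 32, 28672), params_dtype
    "w2_scales": (E, I // group, H),               # (8, 112, 4096), params_dtype
    "w13_qzeros": (E, H // group, 2 * I // pack),  # (8, 32, 3584), int32
    "w2_qzeros": (E, I // group, H // pack),       # (8, 112, 512), int32
}
for name, shape in shapes.items():
    print(name, shape)
# --- end aside; the diff resumes below ---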

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = ops.awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = ops.awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # Why does this take the intermediate size for size_k?
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )

        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits)
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:

        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            fused_marlin_moe)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            num_bits=self.quant_config.weight_bits,
        )
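For context, a minimal serving sketch (not part of the commit): with this change, an AWQ-quantized MoE checkpoint routes its FusedMoE layers through AWQMoEMethod. The model name comes from the test list above; the GPU count and sampling settings are assumptions, and a CUDA build of vLLM with the Marlin kernels is required.

# --- illustrative aside (not part of the commit) ---
from vllm import LLM, SamplingParams

# Large checkpoint: adjust tensor_parallel_size to the GPUs actually available.
llm = LLM(model="casperhansen/deepseek-coder-v2-instruct-awq",
          quantization="awq_marlin",
          trust_remote_code=True,
          tensor_parallel_size=8)

outputs = llm.generate(["def quicksort(arr):"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
# --- end aside ---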
15 changes: 15 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
        device=s.device,
        dtype=s.dtype,
    )

    for e in range(num_experts):
        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
    return output
@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
    return marlin_zp


def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                                  size_n: int, num_bits: int):
    num_experts = q_zp_packed.shape[0]
    output = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for e in range(num_experts):
        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
                                              num_bits)
    return output


def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
4 changes: 3 additions & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -23,7 +23,9 @@ def get_model_architecture(
    architectures = getattr(model_config.hf_config, "architectures", [])
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]

    if (model_config.quantization is not None
            and model_config.quantization not in mixtral_supported
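(With "awq_marlin" in this list, AWQ-Marlin MoE checkpoints such as quantized Mixtral should keep the standard fused-MoE model class rather than being rerouted to the serialized quantized-Mixtral fallback that the FIXME above refers to.)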
