Skip to content

Commit

Permalink
[Kernel] Add Exllama as a backend for compressed-tensors (vllm-projec…
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson authored Oct 17, 2024
1 parent dbfa8d3 commit e312e52
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 16 deletions.
9 changes: 9 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []


def get_default_cache_root():
Expand Down Expand Up @@ -430,6 +431,14 @@ def get_default_config_root():
"VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
) == "1",

# List of quantization kernels that should be disabled, used for testing
# and performance comparisons. Currently only affects MPLinearKernel
# selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
}

# end-env-vars-definition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self,
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
if c.zero_points:
assert w_zp_param_name is not None
if c.has_g_idx:
assert w_gidx_param_name is not None
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name

Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/quantization/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import List, Optional, Type

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.exllama import (
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
Expand All @@ -13,6 +15,7 @@
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
ExllamaLinearKernel,
]


Expand Down Expand Up @@ -45,8 +48,7 @@ def choose_mp_linear_kernel(

failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
.split(","):
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue
Expand Down
140 changes: 140 additions & 0 deletions vllm/model_executor/layers/quantization/kernels/exllama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class ExllamaLinearKernel(MPLinearKernel):
SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
# In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
# currently untested so not added to the list

@classmethod
def get_min_capability(cls) -> int:
return 60

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\
"when the input features are partitioned across "\
"devices"

if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
return False, "Output features must be a multiple of the pack " \
"factor (32 / num_bits) so that we can correctly " \
"pack the zero points"

if c.act_type != torch.float16:
return False, "Exllama only supports float16 activations"

if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Quant type ({c.weight_type}) not supported by "\
"Exllama, supported types are: "\
f"{cls.SUPPORTED_QUANT_TYPES}"

if c.full_weight_shape[0] % c.group_size != 0:
return False, f"Group size ({c.group_size}) does not evenly divide"\
" the number of input features "\
f"({c.full_weight_shape[0]})"

return True, None

def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config

# For Exllama, we need to set a zero-point tensor if there is not one
if not c.zero_points:
self.w_zp_name = "qzeros"
device = getattr(layer, self.w_q_name).device
groups = c.partition_weight_shape[0] // c.group_size
out_features = c.partition_weight_shape[1]

if c.weight_type.has_bias():
# if the type has a bias we have to create a zeros tensor that
# contains the bias values repeated for each group (-1 due to
# a bug in the original GPTQ checkpoint format leading to
# exllama kernel adding 1 to the zero points during inference)
# Documentation of the bug can be found here:
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
zeros = torch.full((groups, out_features),
c.weight_type.bias - 1,
dtype=torch.int32,
device=device)
else:
raise NotImplementedError(
"A 0 zero-point is not supported by Exllama due to "
"a bug in the original GPTQ checkpoint format leading to "
"exllama kernel adding 1 to the zero points during "
"inference")
zeros = pack_quantized_values_into_int32(zeros,
c.weight_type,
packed_dim=1)
setattr(layer, self.w_zp_name,
torch.nn.Parameter(zeros, requires_grad=False))

if c.has_g_idx:

def transform_w_g_idx(x):
# Exllama wants the permutation array instead of the group
# indices
return torch.argsort(x).to(torch.int)

self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
else:
self.w_gidx_name = "g_idx"
empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
dtype=torch.int,
device=device),
requires_grad=False)
setattr(layer, self.w_gidx_name, empty_g_idx)

def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
assert self.w_gidx_name is not None
g_idx = getattr(layer, self.w_gidx_name)

permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x_cont = x.data.contiguous()
ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
return x_cont

def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x.to(dtype=c.act_type)

# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config

x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
c.weight_type.size_bits)

if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/quantization/kernels/machete.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_weights_into_int32, unpack_weights_into_int32)
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)

Expand Down Expand Up @@ -71,13 +71,13 @@ def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_weights_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_perm = x_unpacked[perm, :]
x.data = pack_weights_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = pack_quantized_values_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
self.config.weight_type)
return x
Expand Down
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
}


def pack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
def pack_quantized_values_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
Expand All @@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm)


def unpack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
def unpack_quantized_values_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
Expand Down
2 changes: 2 additions & 0 deletions vllm/scalar_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class scalar_types:
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)

# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
uint3b4 = ScalarType.uint(3, 4)
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)

Expand Down

0 comments on commit e312e52

Please sign in to comment.