Commit

wip

LucasWilkinson committed Aug 15, 2024
1 parent 467848d commit 57a8011
Showing 6 changed files with 97 additions and 6 deletions.
3 changes: 1 addition & 2 deletions csrc/quantization/machete/machete_mm_launcher.cuh
@@ -49,8 +49,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
       torch::empty({M, N}, torch::TensorOptions()
                                .dtype(equivalent_scalar_type_v<EleD>)
                                .device(device));
-
-  auto const &A = args.A, &B = args.B;
+  auto const &A = args.A, &B = args.B;
   auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;

   auto layout_A = make_cute_layout<StrideA>(A, "A");
7 changes: 6 additions & 1 deletion examples/offline_inference.py
@@ -11,7 +11,12 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

 # Create an LLM.
-llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-group128-v2")
+# GPTQ = "kaitchup/Llama-2-7b-gptq-3bit"
+# marlin = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
+# machete = "TheBloke/Llama-2-7B-GPTQ"
+# machete/marlin CT = "nm-testing/tinyllama-oneshot-w4a16-group128-v2"
+#                     "nm-testing/tinyllama-oneshot-w4a16-channel-v2"
+llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-channel-v2")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
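For reference, here is a minimal sketch of how this example is typically run end to end; the prompt list and the printing loop are assumptions (they are not shown in this diff), while the LLM/SamplingParams/generate calls are standard vLLM usage:

from vllm import LLM, SamplingParams

# Placeholder prompts; any list of strings works.
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Swap in any of the commented checkpoints above to exercise a different
# quantized path (GPTQ, Marlin, or Machete).
llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-channel-v2")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}")
    print(f"Generated: {output.outputs[0].text!r}")
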
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/gptq.py
@@ -217,6 +217,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         ops.gptq_shuffle(layer.qweight, layer.g_idx,
                          self.quant_config.weight_bits)

+        print(layer.qzeros)
+        print(hex(layer.qzeros[0][0].to(torch.uint32).item()))
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
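The added prints dump the packed GPTQ zero-point tensor and the raw bit pattern of its first int32 word. A slightly richer debugging helper in the same spirit (hypothetical, not part of this commit) could unpack every field of a word, assuming the usual GPTQ packing of 32 // bits zero points per int32, least-significant field first:

import torch

def dump_packed_qzeros(qzeros: torch.Tensor, bits: int = 4, words: int = 4) -> None:
    # Assumed layout: 32 // bits zero points per int32 word, LSB field first.
    vals_per_word = 32 // bits
    mask = (1 << bits) - 1
    # Reinterpret each int32 word as an unsigned 32-bit value.
    flat = qzeros.flatten().to(torch.int64) & 0xFFFFFFFF
    for word in flat[:words].tolist():
        fields = [(word >> (bits * i)) & mask for i in range(vals_per_word)]
        print(f"0x{word:08x} -> {fields}")

For example, a word of 0x88888888 decodes to eight 4-bit zero points equal to 8.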
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -1,12 +1,9 @@
 from typing import Any, Dict, List, Optional

 import torch
-from torch.nn import Parameter

 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
-from vllm.model_executor.layers.quantization.kernels import (
-    choose_mp_linear_kernel, MPLinearLayerConfig)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kernels import (
85 changes: 85 additions & 0 deletions (new file)
@@ -0,0 +1,85 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.parameter import (ModelWeightParameter,
PackedvLLMParameter)

from .MPLinearKernel import *


class GPTQLinearKernel(MPLinearKernel):

@classmethod
def get_min_capability(cls) -> int:
return 60

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:

if c.act_type != torch.half:
return False, f"Act type {c.act_type} currently not supported by GPTQLinearKernel"

if c.zero_points:
return False, "Zero points currently not supported by GPTQLinearKernel"

if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
return False, f"Quant type ({c.weight_type}) not supported by "\
"Machete, supported types are: "\
f"{query_machete_supported_quant_types(c.zero_points)}"

if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Machete, supported group sizes are: "\
f"{MACHETE_SUPPORTED_GROUP_SIZES}"

return check_machete_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1])

# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):

def transform_w_q(x):
# TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
# everything is migrated to using weight_loader_v2
if isinstance(x, PackedvLLMParameter):
x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
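            # .t().contiguous().t() materializes a column-major copy of the
            # packed weight (stride 1 along the input dim) before prepacking.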
return ops.machete_prepack_B(x.t().contiguous().t(),
self.config.weight_type)

def transform_w_s(x):
# TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
# everything is migrated to using weight_loader_v2
if isinstance(x, ModelWeightParameter):
x = x.permute_layout(input_dim=0, output_dim=1)
return x.contiguous()

# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)

x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

output = ops.machete_gemm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_zeros=None,
b_scales=w_s,
b_group_size=c.group_size)

if bias is not None:
output.add_(bias) # In-place add

return output.reshape(out_shape)
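For context, choose_mp_linear_kernel (imported in gptq_marlin.py above) is responsible for selecting a compatible kernel. A plausible sketch of that kind of selection, using the can_implement / get_min_capability interface shown above (illustrative names only, not the actual vLLM implementation):

def choose_kernel_sketch(config, candidates):
    # Try each candidate MPLinearKernel subclass in preference order and use
    # the first whose can_implement() accepts the layer config; a real selector
    # would also compare get_min_capability() against the GPU's capability.
    failures = []
    for kernel_cls in candidates:
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        failures.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("No compatible kernel:\n  " + "\n  ".join(failures))
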
2 changes: 2 additions & 0 deletions vllm/scalar_type.py
@@ -27,6 +27,8 @@ class scalar_types:
     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)

     # "gptq" types
+    uint2b2 = ScalarType.uint(2, 2)
+    uint3b4 = ScalarType.uint(3, 4)
     uint4b8 = ScalarType.uint(4, 8)
     uint8b128 = ScalarType.uint(8, 128)

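The two new entries follow the existing uint{N}b{B} naming: an unsigned N-bit type with bias B = 2**(N - 1), so a stored value v represents v - B. A small illustrative snippet of that convention (plain Python, not using vllm.ScalarType):

def representable_range(bits: int) -> tuple[int, int]:
    # Biased unsigned "gptq" types: stored v in [0, 2**bits) represents v - bias,
    # with bias = 2**(bits - 1), mirroring uint2b2 / uint3b4 / uint4b8 / uint8b128.
    bias = 1 << (bits - 1)
    return -bias, (1 << bits) - 1 - bias

for bits in (2, 3, 4, 8):
    lo, hi = representable_range(bits)
    print(f"uint{bits}b{1 << (bits - 1)}: values {lo}..{hi}")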
