Flexible Pack DType #1158

Merged (15 commits) on Jan 26, 2025
2 changes: 2 additions & 0 deletions gptqmodel/models/base.py
@@ -286,6 +286,7 @@ def quantize(
device=DEVICE(self.quantize_config.device),
pack=True,
format=self.quantize_config.format,
pack_dtype=self.quantize_config.pack_dtype,
)

# Use the provided tokenizer if one is passed to quantize()
@@ -842,6 +843,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
lm_head_name=self.lm_head,
dynamic=self.quantize_config.dynamic,
parallel_packing=self.quantize_config.parallel_packing,
pack_dtype=self.quantize_config.pack_dtype,
)

self.model.config.use_cache = forward_pass_use_cache
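The two `base.py` changes simply thread `quantize_config.pack_dtype` into the packer calls. A minimal usage sketch of what that enables at quantization time; the `pack_dtype` keyword comes from this PR, while the model id, calibration text, and `load`/`quantize`/`save` call shapes are assumptions based on the existing GPTQModel API rather than part of this diff:

```python
import torch
from gptqmodel import GPTQModel, QuantizeConfig

# pack_dtype rides along on the quantize config and is forwarded to the
# packing step inside quantize(); torch.int32 mirrors the historical container.
qcfg = QuantizeConfig(bits=4, group_size=128, pack_dtype=torch.int32)

calibration = ["gptqmodel is an easy-to-use model quantization library."]  # placeholder data

model = GPTQModel.load("facebook/opt-125m", qcfg)  # placeholder model id
model.quantize(calibration)
model.save("opt-125m-4bit-gptq")
```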
82 changes: 42 additions & 40 deletions gptqmodel/models/loader.py
@@ -302,13 +302,13 @@ def from_quantized(
if config.model_type not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")

quantize_config = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)
qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)

quantize_config.calculate_bits_per_weight()
qcfg.calculate_bits_per_weight()

if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
if quantize_config.format != FORMAT.GPTQ:
raise ValueError(f"{backend} backend only supports FORMAT.GPTQ: actual = {quantize_config.format}")
if qcfg.format != FORMAT.GPTQ:
raise ValueError(f"{backend} backend only supports FORMAT.GPTQ: actual = {qcfg.format}")
if backend == BACKEND.VLLM:
from ..utils.vllm import load_model_by_vllm, vllm_generate

@@ -335,12 +335,12 @@ def from_quantized(
return cls(
model,
quantized=True,
quantize_config=quantize_config,
quantize_config=qcfg,
qlinear_kernel=None,
model_local_path=model_local_path,
)

if quantize_config.format == FORMAT.MARLIN:
if qcfg.format == FORMAT.MARLIN:
# format marlin requires marlin kernel
if backend != BACKEND.MARLIN and backend != BACKEND.AUTO:
raise TypeError(f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.MARLIN: actual = `{backend}`.")
@@ -350,13 +350,13 @@ def from_quantized(

# check for marlin compat for cuda device only
if backend != BACKEND.MARLIN and device == DEVICE.CUDA:
unsupported = _validate_marlin_compatibility(quantize_config)
unsupported = _validate_marlin_compatibility(qcfg)
if unsupported is None and marlin_compatible:
logger.info(
"You passed a model that is compatible with the Marlin kernel. Use `BACKEND.MARLIN` for optimal inference with batching on Nvidia GPU: `model = GPTQModel.load(..., backend=BACKEND.MARLIN)`."
)

if quantize_config.format == FORMAT.BITBLAS:
if qcfg.format == FORMAT.BITBLAS:
# format bitblas requires bitblas kernel
if backend != BACKEND.BITBLAS and backend != BACKEND.AUTO:
raise TypeError(f"FORMAT.BITBLAS requires BACKEND.AUTO or BACKEND.BITBLAS: actual = `{backend}`.")
@@ -368,7 +368,7 @@ def from_quantized(
raise ValueError(BITBLAS_INSTALL_HINT)

possible_model_basenames = [
f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g",
f"gptq_model-{qcfg.bits}bit-{qcfg.group_size}g",
"model",
]

@@ -390,7 +390,7 @@ def from_quantized(
"Loading of .bin files are not allowed due to safety. Please convert your model to safetensor or pytorch format."
)

quantize_config.runtime_format = quantize_config.format
qcfg.runtime_format = qcfg.format

model_save_name = resolved_archive_file # In case a model is sharded, this would be `model.safetensors.index.json` which may later break.
if verify_hash:
@@ -443,7 +443,7 @@ def skip(*args, **kwargs):

for name in list(layers.keys()):
# allow loading of quantized lm_head
if quantize_config.lm_head and name == cls.lm_head:
if qcfg.lm_head and name == cls.lm_head:
continue

if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all(
@@ -457,22 +457,23 @@ def skip(*args, **kwargs):
preload_qlinear_kernel = make_quant(
model,
layers,
quantize_config.bits,
quantize_config.group_size,
backend=backend.AUTO if (backend == BACKEND.MARLIN and quantize_config.format == FORMAT.MARLIN) or backend == BACKEND.BITBLAS else backend,
format=quantize_config.format,
qcfg.bits,
qcfg.group_size,
backend=backend.AUTO if (backend == BACKEND.MARLIN and qcfg.format == FORMAT.MARLIN) or backend == BACKEND.BITBLAS else backend,
format=qcfg.format,
lm_head_name=cls.lm_head,
desc_act=quantize_config.desc_act,
sym=quantize_config.sym,
dynamic=quantize_config.dynamic,
desc_act=qcfg.desc_act,
sym=qcfg.sym,
dynamic=qcfg.dynamic,
device=device,
pack_dtype=qcfg.pack_dtype,
)
if preload_qlinear_kernel == IPEXQuantLinear:
quantize_config.runtime_format = FORMAT.IPEX
qcfg.runtime_format = FORMAT.IPEX

load_checkpoint_in_model = False
# compat: runtime convert checkpoint gptq(v1) to gptq_v2 format
if quantize_config.format == FORMAT.GPTQ and backend != BACKEND.IPEX:
if qcfg.format == FORMAT.GPTQ and backend != BACKEND.IPEX:
load_checkpoint_in_model_then_tie_weights(
model,
dtype=torch_dtype,
@@ -483,7 +484,7 @@ def skip(*args, **kwargs):
offload_buffers=True,
)
# validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
if not quantize_config.sym and not quantize_config.is_quantized_by_v2():
if not qcfg.sym and not qcfg.is_quantized_by_v2():
raise ValueError(
f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
)
@@ -492,15 +493,15 @@ def skip(*args, **kwargs):
logger.info(f"Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`.")
model = convert_gptq_v1_to_v2_format(
model,
quantize_config=quantize_config,
cfg=qcfg,
qlinear_kernel=preload_qlinear_kernel,
)
logger.info(f"Conversion complete: {time.time()-t}s")
load_checkpoint_in_model = True
quantize_config.runtime_format = FORMAT.GPTQ_V2
qcfg.runtime_format = FORMAT.GPTQ_V2

if backend == BACKEND.MARLIN and (
preload_qlinear_kernel == ExllamaV2QuantLinear or quantize_config.format == FORMAT.MARLIN):
preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN):
if is_sharded:
raise ValueError(
"The loading of sharded checkpoints with Marlin is currently not supported."
@@ -514,19 +515,19 @@ def skip(*args, **kwargs):
if torch_dtype != torch.float16:
raise ValueError("Marlin kernel requires torch_dtype=torch.float16.")

_validate_marlin_compatibility(quantize_config, throw_error=True)
_validate_marlin_compatibility(qcfg, throw_error=True)

# Prepare model for marlin load.
# If is marlin serialized load then load directly. Otherwise, convert to marlin.
model = prepare_model_for_marlin_load(
model=model,
quantize_config=quantize_config,
quantize_config=qcfg,
quant_linear_class=preload_qlinear_kernel,
torch_dtype=torch_dtype,
current_model_save_name=model_save_name,
device_map=device_map,
desc_act=quantize_config.desc_act,
sym=quantize_config.sym,
desc_act=qcfg.desc_act,
sym=qcfg.sym,
load_checkpoint_in_model=load_checkpoint_in_model,
)

@@ -537,13 +538,13 @@ def skip(*args, **kwargs):
# If is bitblas serialized load then load directly. Otherwise, convert to bitblas.
model = prepare_model_for_bitblas_load(
model=model,
quantize_config=quantize_config,
quantize_config=qcfg,
quant_linear_class=preload_qlinear_kernel,
torch_dtype=torch_dtype,
model_save_name=model_save_name,
device_map=device_map,
desc_act=quantize_config.desc_act,
sym=quantize_config.sym,
desc_act=qcfg.desc_act,
sym=qcfg.sym,
load_checkpoint_in_model=load_checkpoint_in_model,
)

@@ -564,14 +565,15 @@ def skip(*args, **kwargs):
model = simple_dispatch_model(model, device_map)

qlinear_kernel = select_quant_linear(
bits=quantize_config.bits,
dynamic=quantize_config.dynamic,
group_size=quantize_config.group_size,
desc_act=quantize_config.desc_act,
sym=quantize_config.sym,
bits=qcfg.bits,
dynamic=qcfg.dynamic,
group_size=qcfg.group_size,
desc_act=qcfg.desc_act,
sym=qcfg.sym,
backend=backend,
format=quantize_config.format,
format=qcfg.format,
device=device,
pack_dtype=qcfg.pack_dtype,
)

# == step4: set seqlen == #
@@ -587,7 +589,7 @@ def skip(*args, **kwargs):
model.seqlen = 4096

# Any post-initialization that require device information, for example buffers initialization on device.
model = gptqmodel_post_init(model, use_act_order=quantize_config.desc_act, quantize_config=quantize_config)
model = gptqmodel_post_init(model, use_act_order=qcfg.desc_act, quantize_config=qcfg)

model.eval()

@@ -607,7 +609,7 @@ def skip(*args, **kwargs):
)

with tempfile.TemporaryDirectory() as temp_dir:
mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, model, quantize_config.to_dict())
mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, model, qcfg.to_dict())

save_weights(temp_dir, mlx_weights, donate_weights=True)
save_config(mlx_config, config_path=temp_dir + "/config.json")
@@ -621,7 +623,7 @@ def skip(*args, **kwargs):
return cls(
model,
quantized=True,
quantize_config=quantize_config,
quantize_config=qcfg,
tokenizer=tokenizer,
qlinear_kernel=qlinear_kernel,
load_quantized_model=True,
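Aside from the mechanical `quantize_config` to `qcfg` rename, the substantive loader change is that `qcfg.pack_dtype` is now forwarded into both `make_quant` and `select_quant_linear`, so kernel selection can take the packing container into account. A hedged sketch of what that filtering amounts to; the helper below is illustrative and is not the actual `select_quant_linear` implementation:

```python
from typing import List, Type

import torch

def filter_kernels_by_pack_dtype(candidates: List[Type], pack_dtype: torch.dtype) -> List[Type]:
    # Keep only QuantLinear classes whose declared SUPPORTS_PACK_DTYPES list
    # contains the requested packing container (mirrors the new _validate check).
    return [k for k in candidates if pack_dtype in (k.SUPPORTS_PACK_DTYPES or [])]
```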
42 changes: 35 additions & 7 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -14,8 +14,10 @@
# limitations under the License.

import sys
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

import numpy as np
import torch as t # conflict with torch.py
import torch.nn as nn

from ...models._const import DEVICE, PLATFORM
@@ -32,23 +34,45 @@ class BaseQuantLinear(nn.Module):
SUPPORTS_IN_FEATURES_DIVISIBLE_BY: List[int] = None
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY: List[int] = None

SUPPORTS_PACK_DTYPES: List[t.dtype] = None
SUPPORTS_DEVICES: List[DEVICE] = None
SUPPORTS_PLATFORM: List[PLATFORM] = None

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, *args,
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, pack_dtype: t.dtype, *args,
**kwargs):
super().__init__()
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, infeatures=infeatures,outfeatures=outfeatures)

self.bits = bits

self.pack_dtype = pack_dtype

if self.pack_dtype == t.int8:
self.pack_dtype_bits = 8
self.pack_np_dtype = np.int8
elif self.pack_dtype == t.int16:
self.pack_dtype_bits = 16
self.pack_np_dtype = np.int16
elif self.pack_dtype == t.int32:
self.pack_dtype_bits = 32
self.pack_np_dtype = np.int32
elif self.pack_dtype == t.int64:
self.pack_dtype_bits = 64
self.pack_np_dtype = np.int64
else:
raise ValueError("Unsupported weight_dtype. Only int16 and int32 are supported.")

self.tensors_per_storage_dtype = self.pack_dtype_bits // self.bits
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, infeatures=infeatures,outfeatures=outfeatures, pack_dtype=pack_dtype)
if err:
raise err

@classmethod
# custom quant linear class can override this and add custom checks
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
outfeatures:int=None, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
validate, err = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym,
infeatures=infeatures, outfeatures=outfeatures, dynamic=dynamic,
infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, dynamic=dynamic,
device=device, trainable=trainable)
return validate, err

@@ -86,10 +110,14 @@ def verify_supports_params(cls):
raise ValueError(f"{cls.__name__}.{name} cannot be None or an empty list.")

@classmethod
def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic:Optional[dict]=None, infeatures:int=None,
outfeatures:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]:
def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, infeatures:int=None,
outfeatures:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]:
cls.verify_supports_params()

if pack_dtype not in cls.SUPPORTS_PACK_DTYPES:
err = f"{cls} does not support `pack_dtype`: {pack_dtype}"
return False, NotImplementedError(err)

if PLATFORM.ALL not in cls.SUPPORTS_PLATFORM and sys.platform not in cls.SUPPORTS_PLATFORM:
err = f"{cls} does not support platform: {sys.platform}"
return False, NotImplementedError(err)
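The mapping added to `BaseQuantLinear.__init__` is the heart of the feature: each packed storage element holds `pack_dtype_bits // bits` quantized weights, which is what `tensors_per_storage_dtype` records. A small self-contained check of that arithmetic; the dict and function names below are local to this sketch, not part of the diff:

```python
import torch

PACK_DTYPE_BITS = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64}

def weights_per_element(bits: int, pack_dtype: torch.dtype) -> int:
    # e.g. 4-bit weights: 2 per int8, 4 per int16, 8 per int32, 16 per int64
    return PACK_DTYPE_BITS[pack_dtype] // bits

assert weights_per_element(4, torch.int32) == 8
assert weights_per_element(4, torch.int16) == 4
assert weights_per_element(8, torch.int32) == 4
assert weights_per_element(2, torch.int64) == 32
```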
5 changes: 3 additions & 2 deletions gptqmodel/nn_modules/qlinear/bitblas.py
@@ -102,6 +102,7 @@ class BitBLASQuantLinear(BaseQuantLinear):

SUPPORTS_DEVICES = [DEVICE.CUDA]
SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32]
SUPPORTS_PACK_DTYPES = [torch.int32]

OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512]
zeros_mode = "quantized" # "original" or "rescale" or "quantized"
@@ -125,6 +126,7 @@ def __init__(
sym: bool,
infeatures: int,
outfeatures: int,
pack_dtype: torch.dtype,
bias: bool,
enable_tuning: bool = True,
fast_decoding: bool = True,
@@ -133,13 +135,12 @@
layout: str = "nt",
**kwargs,
):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs)

import_bitblas()

self._validate_parameters(group_size, infeatures, outfeatures)

self.bits = bits
self.infeatures = infeatures
self.outfeatures = outfeatures
self.group_size = self._set_group_size(group_size, infeatures)
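With `SUPPORTS_PACK_DTYPES = [torch.int32]`, asking the BitBLAS kernel to validate any other container should now fail fast through the new `_validate` check. An illustrative expectation, assuming the optional `bitblas` dependency is installed so the import succeeds; the keyword names follow the `validate()` signature shown earlier in this diff:

```python
import torch
from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear

# int16 packing is not declared in SUPPORTS_PACK_DTYPES, so validation is
# expected to return (False, NotImplementedError(...)) rather than raise.
ok, err = BitBLASQuantLinear.validate(
    bits=4, group_size=128, desc_act=False, sym=True,
    infeatures=4096, outfeatures=4096, pack_dtype=torch.int16,
)
assert not ok and isinstance(err, NotImplementedError)
```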
7 changes: 3 additions & 4 deletions gptqmodel/nn_modules/qlinear/exllama.py
@@ -72,13 +72,14 @@ class ExllamaQuantLinear(BaseQuantLinear):

SUPPORTS_DEVICES = [DEVICE.CUDA, DEVICE.ROCM]
SUPPORTS_PLATFORM = [PLATFORM.LINUX]
SUPPORTS_PACK_DTYPES = [torch.int32]

# for transformers/optimum tests compat
QUANT_TYPE = "exllama"

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, pack_dtype: torch.dtype, bias: bool, **kwargs,):
if exllama_import_exception is not None:
raise ValueError(
f"Trying to use the exllama backend, but could not import the C++/CUDA dependencies with the following error: {exllama_import_exception}"
@@ -88,9 +89,7 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=self.infeatures, outfeatures=self.outfeatures, **kwargs)

self.bits = bits
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=self.infeatures, outfeatures=self.outfeatures, pack_dtype=pack_dtype, **kwargs)

# backup original values
self.original_outfeatures = outfeatures
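One small detail in the Exllama path worth spelling out: `outfeatures + (-outfeatures % 32)` (and the matching `infeatures` line) rounds a dimension up to the next multiple of the kernel's required alignment. A tiny standalone check of that identity:

```python
def pad_to_multiple(x: int, m: int) -> int:
    # x + (-x % m) is the smallest multiple of m that is >= x
    return x + (-x % m)

assert pad_to_multiple(4096, 32) == 4096   # already aligned, unchanged
assert pad_to_multiple(4097, 32) == 4128   # bumped to the next multiple of 32
assert pad_to_multiple(1000, 128) == 1024  # group_size-style alignment
```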