Skip to content

Commit

Permalink
[Bugfix] Fix gptq failure on T4s (vllm-project#7264)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson authored and fialhocoelho committed Aug 22, 2024
1 parent a3079aa commit 5af0dfb
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 15 deletions.
3 changes: 1 addition & 2 deletions vllm/model_executor/layers/quantization/awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):

return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
group_size=group_size,
has_zp=has_zp,
min_capability=cls.get_min_capability())
has_zp=has_zp)


class AWQMarlinLinearMethod(LinearMethodBase):
Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
return False

return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size,
min_capability=cls.get_min_capability())
group_size=group_size)


class GPTQMarlinLinearMethod(LinearMethodBase):
Expand Down
23 changes: 12 additions & 11 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None):
if min_capability is None:
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor

if min_capability < 80:
if device_capability < 80:
return []

if has_zp:
Expand All @@ -48,20 +49,20 @@ def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:

if min_capability is None:
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor

supported_types = query_marlin_supported_quant_types(
has_zp, min_capability)
has_zp, device_capability)

if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"min_capability = {min_capability}, zp = {has_zp}).")
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
Expand All @@ -73,9 +74,9 @@ def _check_marlin_supported(
def check_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
min_capability: Optional[int] = None) -> bool:
device_capability: Optional[int] = None) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
min_capability)
device_capability)
return cond


Expand Down

0 comments on commit 5af0dfb

Please sign in to comment.