[TPU] Add Load-time W8A16 quantization for TPU Backend #7005

Merged · 7 commits · Aug 9, 2024
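This change adds a `tpu_int8` quantization method: weights are quantized to int8 at load time on the TPU backend while activations stay in fp16/bf16 (W8A16). A minimal usage sketch, assuming the standard vLLM Python API on a TPU host; the model name and prompt are illustrative, and only `quantization="tpu_int8"` comes from this PR:

```python
# Hypothetical usage sketch; run on a TPU host with torch_xla installed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model choice
    quantization="tpu_int8",  # load-time W8A16 added by this PR
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```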
6 changes: 6 additions & 0 deletions vllm/config.py
@@ -244,6 +244,7 @@ def _verify_quantization(self) -> None:
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors"
]
tpu_supported_quantization = ["tpu_int8"]
if self.quantization is not None:
self.quantization = self.quantization.lower()

@@ -282,6 +283,11 @@ def _verify_quantization(self) -> None:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm.")
            if is_tpu(
            ) and self.quantization not in tpu_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in TPU Backend.")
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
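The new check means that, on a TPU host, any quantization method outside `tpu_supported_quantization` is rejected during config verification. A standalone sketch of that behavior; `check_tpu_quantization` and the `on_tpu` flag are stand-ins mirroring the diff, not vLLM APIs:

```python
# Standalone sketch mirroring the validation added above (hypothetical names).
tpu_supported_quantization = ["tpu_int8"]

def check_tpu_quantization(quantization: str, on_tpu: bool) -> None:
    if on_tpu and quantization not in tpu_supported_quantization:
        raise ValueError(f"{quantization} quantization is currently not "
                         "supported in TPU Backend.")

check_tpu_quantization("tpu_int8", on_tpu=True)  # passes
# check_tpu_quantization("awq", on_tpu=True)     # raises ValueError
```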
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -22,11 +22,13 @@
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
# The order of gptq methods is important for config.py iteration over
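With the registry entry in place, the string passed as `quantization="tpu_int8"` resolves to `Int8TpuConfig`. A small sketch of that lookup, assuming the `get_quantization_config` helper already exported by this module:

```python
# Sketch: resolving the new method name through the registry.
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)

assert "tpu_int8" in QUANTIZATION_METHODS
config_cls = get_quantization_config("tpu_int8")
print(config_cls().get_name())  # -> "tpu_int8"
```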
118 changes: 118 additions & 0 deletions vllm/model_executor/layers/quantization/tpu_int8.py
@@ -0,0 +1,118 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs

ACTIVATION_SCHEMES = ["none"]


class Int8TpuConfig(QuantizationConfig):
"""Int8 Quantization Config class for TPU Backend."""

def __init__(
self,
activation_scheme: str = "none",
) -> None:
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme

    def get_name(self) -> str:
        return "tpu_int8"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "This function should not be called with TPU Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(activation_scheme=activation_scheme)

    def get_quant_method(self, layer: Module,
                         prefix: str) -> Optional["TPUInt8LinearMethod"]:
        if isinstance(layer, LinearBase):
            return TPUInt8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class TPUInt8LinearMethod(LinearMethodBase):
"""Int8 Linear method for TPU Quant. """

def __init__(self, quant_config: Int8TpuConfig):
self.quant_config = quant_config

def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})

def _quantize_weight(
self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
weight_dtype = weight.dtype
weight = weight.cpu().to(torch.float32)
n_bit = 8
eps = 1e-5
max_int = 2**(n_bit - 1) - 1
min_int = -(2**(n_bit - 1))
max_val = weight.abs().amax(dim=-1, keepdim=True)
max_val = max_val.clamp(min=eps)
qscale = max_val / max_int
qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
max_int).to(torch.int8)
qscale = qscale.squeeze().to(weight_dtype)
return qweight, qscale

def process_weights_after_loading(self, layer: Module) -> None:
device = layer.weight.device
qweight, qscale = self._quantize_weight(layer.weight)
qweight = qweight.to(device)
qscale = qscale.to(device)
layer.weight = Parameter(qweight, requires_grad=False)
layer.scale = Parameter(qscale, requires_grad=False)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
try:
import torch_xla.experimental.xla_quantized_matmul # noqa: F401
except ImportError as err:
raise ImportError(
"Please install torch_xla by following the instructions at "
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
"to run vLLM on TPU.") from err
weight = layer.weight
scale = layer.scale
out = torch.ops.xla.quantized_matmul(x, weight, scale)
if bias is not None:
out = out + bias
return out
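The scheme above is symmetric, per-output-channel int8: each weight row gets a scale `max(|w|) / 127`, values are rounded and clamped to [-128, 127], and the TPU matmul rescales the int8 product back to the activation dtype. A self-contained plain-PyTorch sketch of the same numerics (no torch_xla), useful for sanity-checking expected quantization error; `torch.ops.xla.quantized_matmul` is replaced here by an explicit dequantize-and-matmul for illustration only:

```python
import torch

def quantize_weight(weight: torch.Tensor):
    # Per-output-channel symmetric int8, mirroring _quantize_weight above.
    wf = weight.to(torch.float32)
    max_val = wf.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    qscale = max_val / 127
    qweight = torch.clamp(torch.round(wf / qscale), -128, 127).to(torch.int8)
    return qweight, qscale.squeeze(-1).to(weight.dtype)

w = torch.randn(256, 128, dtype=torch.bfloat16)  # [out_features, in_features]
x = torch.randn(4, 128, dtype=torch.bfloat16)    # activations stay 16-bit (A16)
qw, scale = quantize_weight(w)

# Dequantize-on-the-fly matmul: y ~= x @ w.T
y_quant = (x.float() @ (qw.float() * scale.float().unsqueeze(-1)).t()).to(x.dtype)
y_ref = x @ w.t()
print((y_quant - y_ref).abs().max())  # small, weight-quantization-only error
```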
17 changes: 9 additions & 8 deletions vllm/model_executor/model_loader/loader.py
@@ -94,14 +94,15 @@ def _get_quantization_config(
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
if not is_tpu():
+            capability = current_platform.get_device_capability()
+            capability = capability[0] * 10 + capability[1]
+            if capability < quant_config.get_min_capability():
+                raise ValueError(
+                    f"The quantization method {model_config.quantization} "
+                    "is not supported for the current GPU. "
+                    f"Minimum capability: {quant_config.get_min_capability()}. "
+                    f"Current capability: {capability}.")
         supported_dtypes = quant_config.get_supported_act_dtypes()
         if model_config.dtype not in supported_dtypes:
             raise ValueError(
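The `is_tpu()` guard above matters because `Int8TpuConfig.get_min_capability()` deliberately raises: TPUs have no CUDA compute capability, so the loader must skip that check entirely. A short sketch of what would otherwise go wrong:

```python
# Sketch: the capability check cannot run for the TPU config class.
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

try:
    Int8TpuConfig().get_min_capability()
except NotImplementedError:
    # This is why loader.py only runs the capability check when not is_tpu().
    print("no device capability on TPU; check skipped")
```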