Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] AQ AZP 4/4: Integrate asymmetric quantization to linear method #7271

Merged
merged 19 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_symmetric_weight = weight_quant.symmetric
is_static = not weight_quant.dynamic and not input_quant.dynamic

return is_8_bits and is_tensor and is_symmetric and is_static
return is_8_bits and is_tensor and is_symmetric_weight and is_static
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved

def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
Expand All @@ -151,10 +151,10 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_symmetric_weight = weight_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

return is_8_bits and is_token and is_symmetric and is_dynamic
return is_8_bits and is_token and is_symmetric_weight and is_dynamic
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved

def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
Expand Down Expand Up @@ -265,12 +265,14 @@ def _get_scheme_from_parts(
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=True)
is_static_input_scheme=True,
input_symmetric=input_quant.symmetric)

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False)
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric)

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch.nn import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
Expand All @@ -14,12 +15,16 @@
ModelWeightParameter,
PerTensorScaleParameter)

logger = init_logger(__name__)


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

def __init__(self, strategy: str, is_static_input_scheme: bool):
def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric

@classmethod
def get_min_capability(cls) -> int:
Expand All @@ -46,10 +51,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
if self.input_symmetric:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
# Static asymmetric quantization has not been tested yet.
# Kernel and ops support exists and is tested, it's just the
# following integration code that is untested.
logger.warning(
"Static asymmetric quantization currently untested")
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved

# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
range_max = (layer.input_scale *
(int8_traits.max - layer.input_zero_point)).max()
range_min = (layer.input_scale *
(int8_traits.min - layer.input_zero_point)).min()

scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
layer.input_scale = Parameter(scale, requires_grad=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add an accuracy test to make sure this works as expected

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where should we add the accuracy test?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is now complete with the compressed tensors quantization tests - thanks!


azp = int8_traits.min - range_min / scale
layer.input_zero_point = Parameter(azp, requires_grad=False)

else:
layer.input_scale = None
layer.input_zero_point = None

if not self.input_symmetric:
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved
# azp_adj is the AZP adjustment term, used to account for weights.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
layer.azp_adj = layer.weight.sum(dim=0,
mgoin marked this conversation as resolved.
Show resolved Hide resolved
keepdim=True,
dtype=torch.int32)
else:
layer.azp_adj = None

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
Expand Down Expand Up @@ -90,11 +128,19 @@ def create_weights(self, layer: torch.nn.Module,
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)

if not self.input_symmetric:
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved
# Static asymmetric quantization has not been tested yet
logger.warning(
"Static asymmetric quantization currently untested")
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved
input_zero_point = Parameter(torch.zeros(1, dtype=torch.int32))
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved
layer.register_parameter("input_zero_point", input_zero_point)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias)
19 changes: 17 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,28 @@ def apply_int8_linear(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
input_zero_point: Optional[torch.Tensor] = None,
azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)

symmetric = azp_adj is None
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)

if x_zp is not None:
return ops.cutlass_scaled_mm_azp(x_q,
ProExpertProg marked this conversation as resolved.
Show resolved Hide resolved
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
azp_adj=azp_adj,
azp=x_zp,
bias=bias)
return ops.cutlass_scaled_mm(x_q,
weight,
scale_a=x_scale,
Expand Down
Loading