
[Kernel] AQ AZP 4/4: Integrate asymmetric quantization to linear method #7271

Merged · 19 commits · Sep 27, 2024
11 changes: 11 additions & 0 deletions .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.764
- name: "exact_match,flexible-extract"
value: 0.764
limit: 250
num_fewshot: 5
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
@@ -1,6 +1,7 @@
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base-FP8.yaml
7 changes: 6 additions & 1 deletion .buildkite/lm-eval-harness/test_lm_eval_correctness.py
@@ -49,10 +49,15 @@ def test_lm_eval_correctness():
results = launch_lm_eval(eval_config)

# Confirm scores match ground truth.
success = True
for task in eval_config["tasks"]:
for metric in task["metrics"]:
ground_truth = metric["value"]
measured_value = results["results"][task["name"]][metric["name"]]
print(f'{task["name"]} | {metric["name"]}: '
f'ground_truth={ground_truth} | measured={measured_value}')
assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
success = success and numpy.isclose(
ground_truth, measured_value, rtol=RTOL)

# Assert at the end so all scores are printed even on failure, for debugging.
assert success
36 changes: 27 additions & 9 deletions tests/quantization/test_compressed_tensors.py
@@ -2,6 +2,7 @@

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from typing import Optional

import pytest
import torch
@@ -14,14 +15,16 @@
QuantizationType)


@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560),
])
@pytest.mark.parametrize(
"model_args",
[("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560, True),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560, True),
("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
QuantizationType.INT, 2560, False)])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0 = model_args
model_path, strategy, quant_type, shape_0, is_symmetric = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -31,6 +34,18 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj

# assert zp for symmetric and asymmetric cases
def zp_valid(zp: Optional[torch.Tensor]):
if is_symmetric:
return zp is None

return zp is not None and zp.dtype is torch.int32

assert zp_valid(qkv_proj.input_zero_point)
assert zp_valid(o_proj.input_zero_point)
assert zp_valid(gate_up_proj.input_zero_point)
assert zp_valid(down_proj.input_zero_point)

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(gate_up_proj.quant_method,
@@ -69,9 +84,12 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):

@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
"channel"),
])
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
@@ -160,4 +178,4 @@ def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output
assert output
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -138,10 +138,11 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_static = not weight_quant.dynamic and not input_quant.dynamic

return is_8_bits and is_tensor and is_symmetric and is_static
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and weight_quant.symmetric and is_static

def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -151,10 +152,11 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

return is_8_bits and is_token and is_symmetric and is_dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
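
For context, a checkpoint taking this asymmetric path would declare quantization args along these lines. This is only a sketch using compressed-tensors-style fields (`num_bits`, `type`, `strategy`, `symmetric`, `dynamic`); the exact checkpoint schema may differ:

```python
# Illustrative only: dynamic per-token W8A8 with symmetric weights
# and asymmetric input activations.
quant_args_sketch = {
    "weights": {
        "num_bits": 8,
        "type": "int",
        "strategy": "channel",
        "symmetric": True,   # only symmetric weights are supported
        "dynamic": False,
    },
    "input_activations": {
        "num_bits": 8,
        "type": "int",
        "strategy": "token",
        "symmetric": False,  # asymmetric inputs are now accepted
        "dynamic": True,
    },
}
```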

def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -265,12 +267,14 @@ def _get_scheme_from_parts(
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=True)
is_static_input_scheme=True,
input_symmetric=input_quant.symmetric)

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False)
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric)

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -3,6 +3,7 @@
import torch
from torch.nn import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
@@ -14,12 +15,16 @@
ModelWeightParameter,
PerTensorScaleParameter)

logger = init_logger(__name__)


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

def __init__(self, strategy: str, is_static_input_scheme: bool):
def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric

@classmethod
def get_min_capability(cls) -> int:
@@ -46,10 +51,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
if self.input_symmetric:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
layer.input_zero_point = None
else:
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = layer.input_zero_point.to(dtype=torch.int32)
range_max = (layer.input_scale *
(int8_traits.max - azps)).max()
range_min = (layer.input_scale *
(int8_traits.min - azps)).min()

scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
layer.input_scale = Parameter(scale, requires_grad=False)
Review thread (resolved):

Contributor: I think we should add an accuracy test to make sure this works as expected.

Contributor (author): Where should we add the accuracy test?

Contributor: This is now complete with the compressed-tensors quantization tests - thanks!

# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
layer.input_zero_point = Parameter(azp, requires_grad=False)

else:
layer.input_scale = None
layer.input_zero_point = None

# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
layer.azp_adj = layer.weight.sum(dim=0,
keepdim=True,
dtype=torch.int32)
else:
layer.azp_adj = None
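
The asymmetric branch above collapses a vector of (scale, zero-point) pairs into a single pair by reconstructing the real-valued ranges first. A quick standalone sanity check of that math, with made-up stand-ins for the loaded parameters (`input_scale` and `azps` below are illustrative, not from a real checkpoint):

```python
import torch

int8_min = torch.iinfo(torch.int8).min  # -128
int8_max = torch.iinfo(torch.int8).max  # 127

# Illustrative stand-ins for the loaded per-partition parameters.
input_scale = torch.tensor([0.020, 0.025])
azps = torch.tensor([-10, 4], dtype=torch.int32)

# Collapse to a single (scale, azp), mirroring process_weights_after_loading.
range_max = (input_scale * (int8_max - azps)).max()
range_min = (input_scale * (int8_min - azps)).min()
scale = (range_max - range_min) / (int8_max - int8_min)
azp = (int8_min - range_min / scale).to(torch.int32)

# Dequantization is real = scale * (q - azp), so the int8 endpoints should
# land on the collapsed range, up to one step from truncating azp.
assert torch.isclose(scale * (int8_min - azp), range_min, atol=2 * scale.item())
assert torch.isclose(scale * (int8_max - azp), range_max, atol=2 * scale.item())
```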

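The `azp_adj` precomputation works because, for a per-tensor zero point, (a_q - azp) @ W == a_q @ W - azp * W.sum(dim=0): the correction term depends only on the weights, which is why it is identical for static and dynamic activation quantization. A small check of that identity in plain torch (int64 tensors so the matmul runs on CPU; the real kernel consumes int8 with int32 accumulation):

```python
import torch

torch.manual_seed(0)
a_q = torch.randint(-128, 128, (4, 8))     # quantized activations (int64)
weight = torch.randint(-128, 128, (8, 3))  # quantized weights (int64)
azp = 7                                    # per-tensor zero point

# Depends only on the weights, exactly like layer.azp_adj above.
azp_adj = weight.sum(dim=0, keepdim=True)  # shape (1, 3)

direct = (a_q - azp) @ weight              # subtract the zero point first
adjusted = a_q @ weight - azp * azp_adj    # matmul first, then correct
assert torch.equal(direct, adjusted)
```
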
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
@@ -90,11 +128,22 @@ def create_weights(self, layer: torch.nn.Module,
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)

if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias)
19 changes: 17 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -191,13 +191,28 @@ def apply_int8_linear(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
input_zero_point: Optional[torch.Tensor] = None,
azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)

symmetric = azp_adj is None
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)

if x_zp is not None:
return ops.cutlass_scaled_mm_azp(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
azp_adj=azp_adj,
azp=x_zp,
bias=bias)
return ops.cutlass_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
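
The `symmetric` flag above decides whether `ops.scaled_int8_quant` also returns a zero point. As a conceptual reference for the dynamic per-token case, the plain-torch sketch below mimics that contract (the helper name is made up; this models the op's semantics, not the CUDA kernel):

```python
import torch

def per_token_int8_quant_sketch(x: torch.Tensor, symmetric: bool):
    """Reference per-token int8 quantization; returns (x_q, scale, zp)."""
    qmin = torch.iinfo(torch.int8).min
    qmax = torch.iinfo(torch.int8).max
    if symmetric:
        scale = x.abs().amax(dim=-1, keepdim=True) / qmax
        x_q = torch.clamp(torch.round(x / scale), qmin, qmax).to(torch.int8)
        return x_q, scale, None
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min) / (qmax - qmin)
    zp = (qmin - torch.round(x_min / scale)).to(torch.int32)
    x_q = torch.clamp(torch.round(x / scale) + zp, qmin, qmax).to(torch.int8)
    return x_q, scale, zp

x = torch.randn(2, 16)
x_q, scale, zp = per_token_int8_quant_sketch(x, symmetric=False)
# Dequantize with real = scale * (q - zp); error stays within ~1 step.
x_dq = (x_q.to(torch.int32) - zp) * scale
assert torch.allclose(x, x_dq, atol=1.5 * float(scale.max()))
```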