This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 0e8e31e

Browse files
dsikkarobertgshaw2-redhat
authored andcommitted
[Misc] Add per channel support for static activation quantization; update w8a8 schemes to share base classes (vllm-project#5650)
1 parent 4f4cea6 commit 0e8e31e

File tree: 5 files changed, +121 -136 lines
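In per-tensor quantization a single scale covers the whole weight matrix, while per-channel keeps one scale per output row, which typically preserves more accuracy for the same int8 storage. A rough illustration of the two weight schemes (plain PyTorch, not vLLM's kernels):

import torch

def quantize_per_tensor(w: torch.Tensor):
    # One scale for the entire weight matrix.
    scale = w.abs().max() / 127.0
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale

def quantize_per_channel(w: torch.Tensor):
    # One scale per output channel (row), shaped (out, 1) for broadcasting.
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale

w = torch.randn(8, 16)
wq_t, s_t = quantize_per_tensor(w)    # s_t: scalar
wq_c, s_c = quantize_per_channel(w)   # s_c: shape (8, 1)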

tests/quantization/test_compressed_tensors.py (+10 -4)
@@ -18,8 +18,12 @@
                 allow_module_level=True)


-def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
+@pytest.mark.parametrize("model_args", [
+    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
+])
+def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
+    model_path, strategy = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -38,12 +42,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):

     assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

+    assert qkv_proj.scheme.strategy == strategy
     assert qkv_proj.weight.dtype is torch.int8
     assert o_proj.weight.dtype is torch.int8
     assert gate_up_proj.weight.dtype is torch.int8

-    assert qkv_proj.weight_scale.shard_splitter is not None
-    assert qkv_proj.weight_scale.logical_widths is not None
+    if qkv_proj.scheme.strategy == "tensor":
+        assert qkv_proj.weight_scale.shard_splitter is not None
+        assert qkv_proj.weight_scale.logical_widths is not None
     assert qkv_proj.input_scale.dtype is torch.float32
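The conditional assertions reflect how the two strategies load scales: per-tensor scales need the shard_splitter and logical_widths metadata, while per-channel scales use the default row-wise loader. A toy sketch of the resulting weight_scale shapes for a fused QKV projection (hypothetical sizes, not taken from the test models):

import torch

# Hypothetical fused QKV logical widths; real values depend on the model.
logical_widths = [1024, 256, 256]
out_features = sum(logical_widths)

# "tensor" strategy: 1-D buffer, one scalar per logical shard, broadcast
# across that shard's rows by scales_shard_splitter.
tensor_scale = torch.empty(out_features, dtype=torch.float32)

# "channel" strategy: 2-D (out_features, 1) so it broadcasts against the
# int8 weight of shape (out_features, in_features) during dequantization.
channel_scale = torch.empty(out_features, 1, dtype=torch.float32)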
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py (+7 -3)

@@ -85,8 +85,11 @@ def get_config_filenames(cls) -> List[str]:
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
-        is_tensor = (weight_quant.strategy == input_quant.strategy ==
-                     QuantizationStrategy.TENSOR.value)
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_tensor = (weight_strategy and input_quant.strategy
+                     == QuantizationStrategy.TENSOR.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic

@@ -131,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,

         if self.quant_format == CompressionFormat.int_quantized.value:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8StaticTensor()
+                return CompressedTensorsW8A8StaticTensor(
+                    strategy=weight_quant.strategy)

             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8DynamicToken(
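The relaxed check now accepts per-channel weight scales as long as activations remain per-tensor, symmetric, and static. A standalone restatement of the strategy portion of the predicate (hypothetical helper names, for illustration only):

from enum import Enum


class Strategy(str, Enum):
    # Assumed to mirror QuantizationStrategy's string values.
    TENSOR = "tensor"
    CHANNEL = "channel"


def weight_and_input_strategies_ok(weight_strategy: str,
                                   input_strategy: str) -> bool:
    # Weights: per-tensor OR per-channel; activations: per-tensor only.
    weight_ok = weight_strategy in (Strategy.TENSOR, Strategy.CHANNEL)
    input_ok = input_strategy == Strategy.TENSOR
    return weight_ok and input_ok


assert weight_and_input_strategies_ok("channel", "tensor")
assert not weight_and_input_strategies_ok("tensor", "channel")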
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py (new file, +84)
@@ -0,0 +1,84 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class CompressedTensorsW8A8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        assert isinstance(shard_id, str)
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shard_id = self._shard_id_as_int(shard_id)
+        offset = sum(logical_widths[:shard_id])
+        size = logical_widths[shard_id]
+        # update loaded weight with copies for broadcast.
+        loaded_weight = loaded_weight.repeat(size)
+        return param[offset:offset + size], loaded_weight
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        is_tensor_partitioned = len(output_partition_sizes) != 1
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1
+
+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)
+
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "weight_loader": weight_loader,
+                "logical_widths": output_partition_sizes
+            })
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
+            })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })
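The scales_shard_splitter above is what lets a fused layer (e.g. QKV) load each per-shard scale into the right slice of a single flat buffer: it selects the shard's rows by offset and repeats the loaded scalar across them for broadcast. A self-contained rerun of that logic with toy numbers:

import torch

def scales_shard_splitter(param, loaded_weight, shard_id, logical_widths):
    # Same logic as the base class above, reproduced standalone for illustration.
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    shard_id = qkv_idxs[shard_id] if isinstance(shard_id, str) else shard_id
    offset = sum(logical_widths[:shard_id])
    size = logical_widths[shard_id]
    return param[offset:offset + size], loaded_weight.repeat(size)

param = torch.zeros(8)                    # fused scale buffer, widths [4, 2, 2]
view, scale = scales_shard_splitter(param, torch.tensor([0.5]), "k", [4, 2, 2])
view.copy_(scale)                         # rows 4..5 now hold k's scale
print(param)                              # tensor([0., 0., 0., 0., 0.5, 0.5, 0., 0.])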

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py (+10 -79)
@@ -1,97 +1,28 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List

 import torch
-from torch.nn import Parameter

 from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
+    CompressedTensorsW8A8)

 __all__ = ["CompressedTensorsW8A8DynamicToken"]


-class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
-
-    def __init__(self, strategy: str):
-        self.strategy = strategy
-
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

-        # When the scales have a single value, it is required that they be
-        # on the CPU for performance and CUDA Graphs compatibility. Please
-        # refer to the comment in
-        # CompressedTensorsW8A8StaticTensor::create_weights for further
-        # information.
-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        # when doing channel-wise quantization, number of scales
-        # is equal to output_dim
-        weight_scale_dim = sum(output_partition_sizes) if (
-            is_tensor_partitioned
-            or self.strategy == QuantizationStrategy.CHANNEL) else 1
-
-        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            shape = (weight_scale_dim, 1)
-
-        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(
-            weight, {
-                "input_dim": 1,
-                "output_dim": 0,
-                "weight_loader": weight_loader,
-                "logical_widths": output_partition_sizes
-            })
-
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
-
-        # Don't need a shard_splitter for channel-wise quantization
-        # Use the default loading method
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            set_weight_attrs(weight_scale, {
-                "output_dim": 0,
-            })
-        else:
-            set_weight_attrs(
-                weight_scale, {
-                    "logical_widths": output_partition_sizes,
-                    "shard_splitter": self.scales_shard_splitter,
-                })
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
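With the shared base class, both schemes follow a template-method shape: CompressedTensorsW8A8 owns the weight and weight_scale creation, and each subclass adds only what it needs (the static scheme an input_scale, the dynamic-token scheme nothing extra). A stripped-down sketch of the pattern, with hypothetical names:

import torch

class BaseScheme:
    def create_weights(self, layer: torch.nn.Module, out_features: int,
                       in_features: int) -> None:
        # Shared parameters live in the base class.
        layer.register_parameter(
            "weight",
            torch.nn.Parameter(torch.empty(out_features, in_features,
                                           dtype=torch.int8),
                               requires_grad=False))

class StaticScheme(BaseScheme):
    def create_weights(self, layer: torch.nn.Module, out_features: int,
                       in_features: int) -> None:
        super().create_weights(layer, out_features, in_features)
        # Only the static scheme needs a pre-calibrated activation scale.
        layer.register_parameter(
            "input_scale",
            torch.nn.Parameter(torch.empty(1, dtype=torch.float32),
                               requires_grad=False))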

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py (+10 -50)
@@ -1,79 +1,39 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List

 import torch
 from torch.nn import Parameter

 from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
+    CompressedTensorsW8A8)
 from vllm.model_executor.utils import set_weight_attrs

 __all__ = ["CompressedTensorsW8A8StaticTensor"]


-class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
-
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)

         input_scale = Parameter(torch.empty(1, dtype=torch.float32),
                                 requires_grad=False)

-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            "weight_loader": weight_loader,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
         layer.register_parameter("input_scale", input_scale)
         set_weight_attrs(input_scale, {
             "weight_loader": weight_loader,
             "ignore_warning": True,
         })
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes,
-                "ignore_warning": True,
-            })

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight