
Commit

Support Fake GroupWise Quant (PaddlePaddle#61900)
RachelXu7 authored Feb 21, 2024
1 parent 96c2aaf commit 2175de0
Showing 4 changed files with 221 additions and 3 deletions.
39 changes: 37 additions & 2 deletions python/paddle/nn/quant/format.py
@@ -46,7 +46,14 @@ def from_quanter(quanter):


class LinearQuanter(Layer):
def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8):
def __init__(
self,
scales,
zero_point=None,
quant_axis=None,
bit_length=8,
group_size=128,
):
super().__init__()
scales = paddle.to_tensor(scales, dtype="float32")
scale_attr = paddle.framework.ParamAttr(
@@ -65,9 +72,21 @@ def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8):
)
self._quant_axis = -1 if quant_axis is None else quant_axis
self._bit_length = bit_length
self._group_size = group_size

def forward(self, input):
if in_dynamic_mode():
if len(self._scales.shape) > 1:
bnt = (1 << (self._bit_length - 1)) - 1
new_s = paddle.repeat_interleave(
self._scales, self._group_size, 0
)
quant_weight = paddle.clip(
paddle.round(input.cast('float32') / new_s * bnt),
-bnt - 1,
bnt,
)
return quant_weight.cast(input.dtype)
return _C_ops.quantize_linear(
input.cast('float32'),
self._scales,
@@ -105,7 +124,14 @@ def from_quanter(quanter):


class LinearDequanter(Layer):
def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8):
def __init__(
self,
scales,
zero_point=None,
quant_axis=None,
bit_length=8,
group_size=128,
):
super().__init__()
scales = paddle.to_tensor(scales, dtype="float32")
scale_attr = paddle.framework.ParamAttr(
@@ -124,9 +150,18 @@ def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8):
)
self._quant_axis = -1 if quant_axis is None else quant_axis
self._bit_length = bit_length
self._group_size = group_size

def forward(self, input):
if in_dynamic_mode():
if len(self._scales.shape) > 1:
bnt = (1 << (self._bit_length - 1)) - 1
new_s = paddle.repeat_interleave(
self._scales, self._group_size, 0
)
quant_dequant_weight = input.cast('float32') / bnt * new_s
return quant_dequant_weight.cast(input.dtype)

return _C_ops.dequantize_linear(
input.cast('float32'),
self._scales,
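The two new dynamic-mode branches in LinearQuanter and LinearDequanter implement the group-wise fake quant path: when the scales tensor is 2D, each scale row is repeated group_size times along axis 0 so it lines up with the weight rows, the weight is rounded onto the signed bit_length grid, and the dequanter scales it back. Below is a minimal standalone sketch of that round trip; the [256, 64] weight, the 4-bit setting, and the random scales are illustrative assumptions, not values taken from the patch.

import paddle

# Assumed shapes: weight [in_features, out_features], one scale per (group, out_feature).
bit_length, group_size = 4, 128
weight = paddle.rand([256, 64])
scales = paddle.rand([256 // group_size, 64]) + 0.5

bnt = (1 << (bit_length - 1)) - 1                        # 7 for 4-bit
new_s = paddle.repeat_interleave(scales, group_size, 0)  # expand scales back to [256, 64]
quant = paddle.clip(paddle.round(weight / new_s * bnt), -bnt - 1, bnt)  # LinearQuanter branch
dequant = quant / bnt * new_s                            # LinearDequanter branch

Round-tripping the weight this way keeps it in floating point while baking in the quantization error, which is what makes the quantization "fake" (simulated) rather than an actual low-bit storage format.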
3 changes: 2 additions & 1 deletion python/paddle/quantization/observers/__init__.py
@@ -14,5 +14,6 @@
# limitations under the License.

from .abs_max import AbsmaxObserver
from .groupwise import GroupWiseWeightObserver

__all__ = ["AbsmaxObserver"]
__all__ = ["AbsmaxObserver", "GroupWiseWeightObserver"]
113 changes: 113 additions & 0 deletions python/paddle/quantization/observers/groupwise.py
@@ -0,0 +1,113 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle

from ..base_observer import BaseObserver
from ..factory import ObserverFactory


class GroupWiseWeightObserver(ObserverFactory):
r"""
It collects channel-wise maximum absolute values of target weights.
Args:
bit_length(int, optional): Number of bits to represent an quantized integer in binary.
dtype(str, optional): The data type of input tensor.
name (str, optional): This parameter is used by developers to print debugging information. \
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import AbsMaxChannelWiseWeightObserver
quanter = AbsMaxChannelWiseWeightObserver()
q_config = QuantConfig(activation=None, weight=quanter)
"""

    def __init__(self, quant_bits=8, group_size=128):
        super().__init__(quant_bits=quant_bits, group_size=group_size)

def _get_class(self):
return GroupWiseWeightObserverLayer


class GroupWiseWeightObserverLayer(BaseObserver):
def __init__(self, layer, quant_bits=8, group_size=128):
super().__init__()
self.quant_bits = quant_bits
self.group_size = group_size
self._layer = layer
self._max = None
self._scale = None
self._zero_point = None

def forward(self, inputs):
self._max = self._cal_abs_max(inputs)
return inputs

def _cal_abs_max(self, inputs):
"""Use group_size to group the input, then use the
absmax method to calculate the scale
"""
input_shape = inputs.shape
assert (
self.group_size == 64 or self.group_size == 128
), "group_size only support 64 or 128"
assert (
inputs.shape[0] % self.group_size == 0
), "group_size must be a factor of input channels"
        assert len(inputs.shape) == 2, "Currently only 2D tensors are supported"
input_processed = inputs.transpose([1, 0]).reshape(
[input_shape[1], input_shape[0] // self.group_size, self.group_size]
)

abs_max_values = paddle.max(paddle.abs(input_processed), axis=2).cast(
"float32"
)
abs_max_values = paddle.where(
abs_max_values == np.float32(0), np.float32(1e-8), abs_max_values
)
abs_max_values = abs_max_values.transpose([1, 0])
return abs_max_values

def min_value(self) -> float:
return 0.0

def max_value(self) -> float:
return self._max

    def bit_length(self):
        return self.quant_bits

def quant_axis(self):
return -1

def cal_thresholds(self):
"""Compute thresholds for MAX function."""
if self._scale is None:
self._scale = self._max
self._zero_point = paddle.zeros_like(self._scale)

def scales(self):
"""Return output scales."""
if self._scale is None:
self.cal_thresholds()
return self._scale

def zero_points(self):
"""Return output zero points."""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
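For intuition, the reshape in _cal_abs_max walks a 2D weight from [in_channels, out_channels] to one absolute-maximum value per (group, output channel) pair. A short sketch of the shape flow, using a hypothetical [256, 64] weight and group_size=128 (the values are random; only the shapes matter):

import paddle

inputs = paddle.rand([256, 64])   # [in_channels, out_channels]
group_size = 128

# [256, 64] -> [64, 256] -> [64, 2, 128]: each output channel's 256 inputs split into 2 groups of 128
grouped = inputs.transpose([1, 0]).reshape([64, 256 // group_size, group_size])
abs_max = paddle.max(paddle.abs(grouped), axis=2)   # [64, 2]: one abs-max per group per output channel
scales = abs_max.transpose([1, 0])                  # [2, 64]: the layout LinearQuanter repeats along axis 0

The resulting [in_channels // group_size, out_channels] scale tensor is exactly what the 2D-scales branch in format.py expands again with repeat_interleave.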
69 changes: 69 additions & 0 deletions test/quantization/test_groupwise.py
@@ -0,0 +1,69 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import paddle
from paddle.nn import Linear, Sequential
from paddle.quantization import PTQ, QuantConfig
from paddle.quantization.observers import GroupWiseWeightObserver


class LinearDygraph(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fc = Sequential(
Linear(128, 128), Linear(128, 128), Linear(128, 128)
)

def forward(self, inputs):
out = self.fc(inputs)
return out


class TestPTQGroupWise(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.path = os.path.join(self.temp_dir.name, 'ptq')

def tearDown(self):
self.temp_dir.cleanup()

def _get_model_for_ptq(self):
observer = GroupWiseWeightObserver(quant_bits=4, group_size=128)
model = LinearDygraph()
model.eval()
q_config = QuantConfig(activation=None, weight=observer)
ptq = PTQ(q_config)
quant_model = ptq.quantize(model)
return quant_model, ptq

def _count_layers(self, model, layer_type):
count = 0
for _layer in model.sublayers(True):
if isinstance(_layer, layer_type):
count += 1
return count

def test_quantize(self):
ptq_model, _ = self._get_model_for_ptq()
inputs = paddle.rand([128, 128], dtype="float32")
out = ptq_model(inputs)
self.assertIsNotNone(out)


if __name__ == '__main__':
unittest.main()
