[prototype] Refine the API of quantization for dygraph #47530

Closed
Wants to merge 9 commits
1 change: 1 addition & 0 deletions python/paddle/nn/quant/__init__.py
@@ -22,5 +22,6 @@
from .functional_layers import concat # noqa: F401
from .functional_layers import flatten # noqa: F401
from .quant_layers import QuantStub # noqa: F401
from . import qat

__all__ = []
18 changes: 18 additions & 0 deletions python/paddle/nn/quant/qat/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .conv import QuantConv2D
from .linear import QuantLinear

__all__ = ["QuantConv2D", "QuantLinear"]
81 changes: 81 additions & 0 deletions python/paddle/nn/quant/qat/conv.py
@@ -0,0 +1,81 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Layers used for QAT.
"""
from paddle.nn import functional as F
from paddle.nn import Layer

__all__ = ["QuantConv2D"]


class QuantConv2D(Layer):
    """
    The computational logic of QuantConv2D is the same as that of Conv2D.
    The only difference is that its inputs are fake quantized.
    """

    def __init__(self, layer: Layer, q_config):
        super(QuantConv2D, self).__init__()

        # For Conv2D
        self._groups = getattr(layer, '_groups')
        self._stride = getattr(layer, '_stride')
        self._padding = getattr(layer, '_padding')
        self._padding_mode = getattr(layer, '_padding_mode')
        if self._padding_mode != 'zeros':
            self._reversed_padding_repeated_twice = getattr(
                layer, '_reversed_padding_repeated_twice'
            )
        self._dilation = getattr(layer, '_dilation')
        self._data_format = getattr(layer, '_data_format')
        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')

        # For FakeQuant
        self.weight_quanter = None
        self.activation_quanter = None
        if q_config.weight is not None:
            self.weight_quanter = q_config.weight.instance(layer)
        if q_config.activation is not None:
            self.activation_quanter = q_config.activation.instance(layer)

    def forward(self, input):
        quant_input = input
        quant_weight = self.weight
        if self.activation_quanter is not None:
            quant_input = self.activation_quanter(input)
        if self.weight_quanter is not None:
            quant_weight = self.weight_quanter(self.weight)
        return self._conv_forward(quant_input, quant_weight)

    def _conv_forward(self, inputs, weights):
        if self._padding_mode != 'zeros':
            inputs = F.pad(
                inputs,
                self._reversed_padding_repeated_twice,
                mode=self._padding_mode,
                data_format=self._data_format,
            )
            self._padding = 0

        return F.conv2d(
            inputs,
            weights,
            bias=self.bias,
            padding=self._padding,
            stride=self._stride,
            dilation=self._dilation,
            groups=self._groups,
            data_format=self._data_format,
        )
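
For context, a minimal runnable sketch of how QuantConv2D is meant to be wired up is shown below. The ToyFakeQuant, ToyFactory, and ToyConfig helpers are hypothetical stand-ins for a real q_config and are not part of this PR; the layer only requires that q_config.weight / q_config.activation be either None or objects whose instance(layer) returns a callable fake-quant Layer.

import paddle
from paddle.nn.quant.qat import QuantConv2D


class ToyFakeQuant(paddle.nn.Layer):
    # Toy per-tensor abs-max fake quantizer; stands in for a real quanter layer.
    def forward(self, x):
        scale = paddle.max(paddle.abs(x)) / 127.0  # 8-bit range assumed
        return paddle.round(x / scale) * scale


class ToyFactory:
    # Mimics the protocol QuantConv2D expects: instance(layer) -> Layer.
    def instance(self, layer):
        return ToyFakeQuant()


class ToyConfig:
    # Hypothetical q_config with .weight / .activation factory entries.
    weight = ToyFactory()
    activation = ToyFactory()


conv = paddle.nn.Conv2D(3, 8, kernel_size=3, padding=1)
qconv = QuantConv2D(conv, ToyConfig())
out = qconv(paddle.rand([1, 3, 32, 32]))
print(out.shape)  # [1, 8, 32, 32]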
53 changes: 53 additions & 0 deletions python/paddle/nn/quant/qat/linear.py
@@ -0,0 +1,53 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.nn import functional as F
from paddle.nn import Layer

__all__ = ["QuantLinear"]


class QuantLinear(Layer):
    """
    The computational logic of QuantLinear is the same as that of Linear.
    The only difference is that its inputs are fake quantized.
    """

    def __init__(self, layer: Layer, q_config):
        super(QuantLinear, self).__init__()
        # For Linear
        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')
        self.name = getattr(layer, 'name')

        # For FakeQuant
        self.weight_quanter = None
        self.activation_quanter = None
        if q_config.weight is not None:
            self.weight_quanter = q_config.weight.instance(layer)
        if q_config.activation is not None:
            self.activation_quanter = q_config.activation.instance(layer)

    def forward(self, input):
        quant_input = input
        quant_weight = self.weight
        if self.activation_quanter is not None:
            quant_input = self.activation_quanter(input)
        if self.weight_quanter is not None:
            quant_weight = self.weight_quanter(self.weight)
        return self._linear_forward(quant_input, quant_weight)

    def _linear_forward(self, input, weight):
        out = F.linear(x=input, weight=weight, bias=self.bias, name=self.name)
        return out
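
A similar sketch for QuantLinear follows, this time with a hypothetical weight-only config (activation set to None), which exercises the branch where activation_quanter stays None and the input passes through unquantized. AbsMaxFakeQuant and WeightOnlyConfig are illustrative stand-ins, not part of this PR.

import paddle
from paddle.nn.quant.qat import QuantLinear


class AbsMaxFakeQuant(paddle.nn.Layer):
    # Toy abs-max fake quantizer, a stand-in for a real weight quanter.
    def forward(self, x):
        scale = paddle.max(paddle.abs(x)) / 127.0
        return paddle.round(x / scale) * scale


class WeightFactory:
    def instance(self, layer):
        return AbsMaxFakeQuant()


class WeightOnlyConfig:
    # Hypothetical config: fake-quantize weights only, leave activations alone.
    weight = WeightFactory()
    activation = None


linear = paddle.nn.Linear(16, 4)
qlinear = QuantLinear(linear, WeightOnlyConfig())
y = qlinear(paddle.rand([2, 16]))
print(y.shape)  # [2, 4]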
3 changes: 0 additions & 3 deletions python/paddle/nn/quant/quant_layers.py
@@ -25,7 +25,6 @@
from paddle import _legacy_C_ops
from paddle import in_dynamic_mode
from paddle.nn import Layer
from paddle.nn.quant.lsq import FakeQuantActLSQPlus, FakeQuantWeightLSQPlus

__all__ = [
    'FakeQuantAbsMax',
@@ -1116,8 +1115,6 @@ def _get_fake_quant_type(quant_type, **kwargs):
        'abs_max': FakeQuantAbsMax,
        'moving_average_abs_max': FakeQuantMovingAverageAbsMax,
        'channel_wise_abs_max': FakeQuantChannelWiseAbsMax,
        'lsq_weight': FakeQuantWeightLSQPlus,
        'lsq_act': FakeQuantActLSQPlus,
    }

    return fake_quant_map[quant_type](**call_args)
73 changes: 41 additions & 32 deletions python/paddle/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -12,35 +12,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...fluid.contrib.slim.quantization.imperative.ptq_config import (
    PTQConfig,
    default_ptq_config,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    BaseQuantizer,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    AbsmaxQuantizer,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    PerChannelAbsmaxQuantizer,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    KLQuantizer,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    HistQuantizer,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    SUPPORT_ACT_QUANTIZERS,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    SUPPORT_WT_QUANTIZERS,
)
from ...fluid.contrib.slim.quantization.imperative.ptq_registry import (
    PTQRegistry,
)
from ...fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
from ...fluid.contrib.slim.quantization.imperative.qat import (
    ImperativeQuantAware,
)
# from ...fluid.contrib.slim.quantization.imperative.ptq_config import PTQConfig, default_ptq_config
# from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import BaseQuantizer
# from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import AbsmaxQuantizer
# from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import PerChannelAbsmaxQuantizer
# from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import KLQuantizer
# from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import HistQuantizer
# from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import SUPPORT_ACT_QUANTIZERS
# from ...fluid.contrib.slim.quantization.imperative.ptq_quantizer import SUPPORT_WT_QUANTIZERS
# from ...fluid.contrib.slim.quantization.imperative.ptq_registry import PTQRegistry
# from ...fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
# from ...fluid.contrib.slim.quantization.imperative.qat import ImperativeQuantAware
from .qat import QAT
from . import qat
from .ptq import PTQ
from . import ptq
from .config import QuantConfig, TRTQuantConfig
from . import config
from .quanters import (
    ActLSQPlusQuanter,
    FakeQuanterWithAbsMaxObserver,
)
from . import quanters
from .stubs import Stub
from . import stubs
from .factory import ObserverFactory, QuanterFactory
from . import factory
from .quanter import BaseQuanter
from . import quanter
from .observer import BaseObserver
from . import observer

__all__ = []
__all__ += qat.__all__
__all__ += ptq.__all__
__all__ += config.__all__
__all__ += quanters.__all__
__all__ += stubs.__all__
__all__ += factory.__all__
__all__ += quanter.__all__
__all__ += observer.__all__
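
Since this prototype reworks the public paddle.quantization surface, a hedged end-to-end sketch of how the re-exported pieces are expected to compose is shown below. It follows the shape of the API as documented in later Paddle releases (QuantConfig(activation=..., weight=...), QAT(config).quantize(model)); the exact signatures in this prototype may still differ, so treat it as illustrative rather than as this PR's final contract.

import paddle
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

# Default fake quanter applied to both activations and weights.
quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=quanter, weight=quanter)

# Swap Conv2D/Linear layers of a float model for their fake-quantized
# counterparts (the QuantConv2D / QuantLinear layers added by this PR).
model = paddle.vision.models.LeNet()
qat = QAT(q_config)
quant_model = qat.quantize(model, inplace=True)
print(quant_model)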