Skip to content

Commit

Permalink
QDQ tool modification part3 (#9904)
Browse files Browse the repository at this point in the history
* refine per channel quantization for qdq

* remove old option

* add comment

* add import itertools
  • Loading branch information
chilo-ms authored Dec 3, 2021
1 parent 4ff78aa commit 02aa16e
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 16 deletions.
2 changes: 0 additions & 2 deletions onnxruntime/python/tools/quantization/onnx_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
is_weight_int8 = weight_qType == QuantType.QInt8
self.is_weight_symmetric = is_weight_int8 if 'WeightSymmetric' not in self.extra_options else self.extra_options['WeightSymmetric']
self.is_activation_symmetric = False if 'ActivationSymmetric' not in self.extra_options else self.extra_options['ActivationSymmetric']
self.op_types_support_per_channel_quantization = [] if 'OpTypesSupportPerChannelQuantization' not in extra_options \
else extra_options['OpTypesSupportPerChannelQuantization']

self.input_qType = onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
self.weight_qType = onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8
Expand Down
23 changes: 23 additions & 0 deletions onnxruntime/python/tools/quantization/operators/matmul.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import onnx
import itertools
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase
from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType
from onnx import onnx_pb as onnx_proto
'''
Expand Down Expand Up @@ -98,3 +100,24 @@ def quantize(self):
self.quantizer.quantized_value_map[node.output[0]] = q_output

self.quantizer.new_nodes += nodes

class QDQMatMul(QDQOperatorBase):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)

def quantize(self):
node = self.node
assert (node.op_type == "MatMul")

if self.disable_qdq_for_node_output:
nodes_to_iterate = node.input
else:
nodes_to_iterate = itertools.chain(node.input, node.output)

for tensor_name in nodes_to_iterate:
# only support per-channel quantization on weight
if self.quantizer.is_per_channel() and find_by_name(tensor_name, self.quantizer.model.initializer()) :
channel_axis = self.quantizer.qdq_op_type_per_channel_support_to_axis.get(node.op_type, 1)
self.quantizer.quantize_tensor_per_channel(tensor_name, channel_axis)
else:
self.quantizer.quantize_tensor(tensor_name)
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,4 @@ def quantize(self):
nodes_to_iterate = itertools.chain(node.input, node.output)

for tensor_name in nodes_to_iterate:
if self.quantizer.is_per_channel():
if node.op_type in self.quantizer.op_types_support_per_channel_quantization :
self.quantizer.quantize_tensor_per_channel(tensor_name, self.quantizer.qdq_channel_axis)
else:
self.quantizer.quantize_tensor(tensor_name)
else:
self.quantizer.quantize_tensor(tensor_name)
self.quantizer.quantize_tensor(tensor_name)
10 changes: 4 additions & 6 deletions onnxruntime/python/tools/quantization/qdq_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
self.op_types_to_exclude_output_quantization = [] if 'OpTypesToExcludeOutputQuantizatioin' not in extra_options \
else extra_options['OpTypesToExcludeOutputQuantizatioin']

# In some cases, for example QDQ BERT model for TensorRT,
# QDQ should always appear as a pair.
# For our quantization tool, we do quantization on Dequantizelinear's input
# to remove Quantizelinear as optimization for weight.
# We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization.
# In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair.
# Therefore, we need to disable this optimization and add qdq pair to weight.
self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \
else extra_options['AddQDQPairToWeight']
Expand All @@ -57,8 +55,8 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType,
if self.dedicated_qdq_pair:
self.tensor_to_its_receiving_nodes = {}

# Channel axis when per_channel is True
self.qdq_channel_axis = 0 if 'QDQChannelAxis' not in extra_options else extra_options['QDQChannelAxis']
# Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True.
self.qdq_op_type_per_channel_support_to_axis = {} if 'QDQOpTypePerChannelSupportToAxis' not in extra_options else extra_options['QDQOpTypePerChannelSupportToAxis']

def quantize_tensor(self, tensor_name):
weight = find_by_name(tensor_name, self.model.initializer())
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/python/tools/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def quantize_static(model_input,
the output of ops with this specific op types.
DedicatedQDQPair = True/False : Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their inputs.
If True, it will create identical and dedicated QDQ pair for each node.
QDQOpTypePerChannelSupportToAxis = dictionary : Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1},
and it's effective only when per channel quantization is supported and per_channel is True.
If specific op type supports per channel quantization but not explicitly specified with channel axis,
default channel axis will be used.
'''

mode = QuantizationMode.QLinearOps
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/quantization/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .quant_utils import QuantizationMode
from .operators.base_operator import QuantOperatorBase
from .operators.qdq_base_operator import QDQOperatorBase
from .operators.matmul import MatMulInteger, QLinearMatMul
from .operators.matmul import MatMulInteger, QLinearMatMul, QDQMatMul
from .operators.attention import AttentionQuant
from .operators.embed_layernorm import EmbedLayerNormalizationQuant
from .operators.gather import GatherQuant
Expand Down Expand Up @@ -66,6 +66,7 @@
"MaxPool": QDQMaxPool,
"AveragePool" : QDQDirect8BitOp,
"Concat": QDQConcat,
"MatMul": QDQMatMul,
}


Expand Down

0 comments on commit 02aa16e

Please sign in to comment.