Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support hybrid parallel in quantization_pass #52219

Merged
merged 2 commits into from
Apr 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 83 additions & 43 deletions python/paddle/static/quantization/quantization_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def _quant_preprocess(op_node):
def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
op.op()._set_attr("with_quant_attr", True)
op_role = op.op().attr("op_role")
inputs = op.inputs
for var_node in inputs:
if var_node.name() not in op.input_arg_names():
Expand Down Expand Up @@ -368,21 +369,36 @@ def _transform_forward(graph, op):
quant_var_node,
scale_var_node,
) = self._insert_channel_quant_op(
graph, var_node, name, quant_bits, quant_axis
graph,
var_node,
name,
quant_bits,
quant_axis,
op_role,
)
dequant_var_node = self._insert_channel_dequant_op(
graph,
quant_var_node,
[scale_var_node],
[quant_bits],
quant_axis,
op_role,
)
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, name, quant_bits, quant_type
graph,
var_node,
name,
quant_bits,
quant_type,
op_role,
)
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node, quant_bits
graph,
quant_var_node,
scale_var_node,
quant_bits,
op_role,
)
dequantized_vars[name] = dequant_var_node
graph.update_input_link(var_node, dequant_var_node, op)
Expand Down Expand Up @@ -476,24 +492,28 @@ def _create_global_step(self, graph):
graph.link_to(increment_op, global_step_out)
self._global_step = global_step_out

def _insert_quant_op(
    self, graph, var_node, name, quant_bits, quant_type, op_role
):
    """
    Insert fake_quantize_op in the graph.

    Dispatches on ``quant_type`` to the matching ``_insert_quant_*_op``
    helper, forwarding ``op_role`` so the inserted quantize op carries the
    same ``op_role`` attribute as the op being quantized instead of a
    hard-coded Forward role.

    Args:
        graph: graph the quantize op is inserted into.
        var_node: input variable node to be quantized.
        name: name passed through to the helper for the new variables.
        quant_bits: bit length used for quantization.
        quant_type: one of 'abs_max', 'range_abs_max' or
            'moving_average_abs_max'.
        op_role: value written to the 'op_role' attribute of the
            inserted op (callers read it from the quantized op via
            ``op.op().attr("op_role")``).

    Returns:
        The helper's result, which callers unpack as a
        ``(quant_var_node, scale_var_node)`` pair.

    Raises:
        ValueError: if ``quant_type`` is not one of the supported types.
    """
    if quant_type == 'abs_max':
        return self._insert_quant_abs_max_op(
            graph, var_node, name, quant_bits, op_role
        )
    elif quant_type == 'range_abs_max':
        return self._insert_quant_range_abs_max_op(
            graph, var_node, name, quant_bits, op_role
        )
    elif quant_type == 'moving_average_abs_max':
        return self._insert_quant_moving_average_abs_max_op(
            graph, var_node, name, quant_bits, op_role
        )
    # Falling through would return None, which callers immediately try to
    # unpack into two values; fail here with a message naming the bad type.
    raise ValueError(
        f"Unsupported quant_type {quant_type!r}; expected 'abs_max', "
        "'range_abs_max' or 'moving_average_abs_max'."
    )

def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits):
def _insert_quant_abs_max_op(
self, graph, var_node, name, quant_bits, op_role
):
"""
Insert fake_quantize_abs_max op in the graph.
"""
Expand Down Expand Up @@ -528,10 +548,7 @@ def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits):

quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
},
attrs={'bit_length': quant_bits, 'op_role': op_role},
inputs={'X': var_node},
outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
)
Expand All @@ -540,7 +557,9 @@ def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits):
graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node

def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits):
def _insert_quant_range_abs_max_op(
self, graph, var_node, name, quant_bits, op_role
):
"""
Insert fake_quantize_range_abs_max on the graph.
"""
Expand Down Expand Up @@ -605,7 +624,7 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits):
'window_size': self._window_size,
'bit_length': quant_bits,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max',
Expand All @@ -626,7 +645,7 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits):
return quant_var_node, scale_out_node

def _insert_quant_moving_average_abs_max_op(
self, graph, var_node, name, quant_bits
self, graph, var_node, name, quant_bits, op_role
):
"""Insert fake_quantize_moving_average_abs_max"""
quant_var_node = graph.create_var_node(
Expand Down Expand Up @@ -706,7 +725,7 @@ def _insert_quant_moving_average_abs_max_op(
'bit_length': quant_bits,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
}

quant_op_node = graph.create_op_node(
Expand All @@ -730,7 +749,7 @@ def _insert_quant_moving_average_abs_max_op(
return quant_var_node, scale_out_node

def _insert_channel_quant_op(
self, graph, var_node, name, quant_bits, quant_axis
self, graph, var_node, name, quant_bits, quant_axis, op_role
):
"""
Insert fake_channel_wise_quantize_abs_max op in the graph.
Expand Down Expand Up @@ -771,7 +790,7 @@ def _insert_channel_quant_op(
'bit_length': quant_bits,
'quant_axis': quant_axis,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
},
inputs={'X': var_node},
outputs={'Out': quant_var_node, 'OutScale': scale_var_node},
Expand All @@ -781,7 +800,9 @@ def _insert_channel_quant_op(
graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node

def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
def _insert_dequant_op(
self, graph, var_node, scale_var_node, quant_bits, op_role
):
"""
Insert fake_dequantize_op in the graph.
"""
Expand All @@ -796,10 +817,7 @@ def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
},
attrs={'max_range': float(max_range), 'op_role': op_role},
inputs={'X': var_node, 'Scale': scale_var_node},
outputs={'Out': dequant_var_node},
)
Expand All @@ -809,7 +827,7 @@ def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
return dequant_var_node

def _insert_channel_dequant_op(
self, graph, var_node, scale_var_nodes, quant_bits, quant_axis
self, graph, var_node, scale_var_nodes, quant_bits, quant_axis, op_role
):
"""
Insert fake_channel_wise_dequantize_max_abs in the graph.
Expand All @@ -827,7 +845,7 @@ def _insert_channel_dequant_op(
attrs={
'quant_bits': quant_bits,
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
},
inputs={'X': var_node, 'Scales': scale_var_nodes},
outputs={'Out': dequant_var_node},
Expand Down Expand Up @@ -1628,11 +1646,15 @@ def apply(self, graph):
in_node = graph._find_node_by_name(
op.outputs, output_var_name
)
if in_node.dtype() not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]:
if (
in_node.dtype()
not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]
or '@GRAD' in in_node.name()
):
continue

if in_node.dtype() == core.VarDesc.VarType.FP64:
Expand Down Expand Up @@ -1710,7 +1732,7 @@ def apply(self, graph):
attrs = {
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op.op().attr("op_role"),
}
scale_op_node = graph.create_op_node(
op_type='moving_average_abs_max_scale',
Expand Down Expand Up @@ -1953,7 +1975,10 @@ def apply(self, graph):
quant_var_node,
_,
) = self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits
graph,
in_node,
self._quant_bits,
op_node.op().attr("op_role"),
)
dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(
Expand All @@ -1978,7 +2003,7 @@ def apply(self, graph):
return graph

def _inser_quant_dequant_moving_average_abs_max_op(
self, graph, var_node, quant_bits
self, graph, var_node, quant_bits, op_role
):
"""Insert fake_quantize_dequantize_moving_average_abs_max op."""
quant_var_node = graph.create_var_node(
Expand Down Expand Up @@ -2068,7 +2093,7 @@ def _inser_quant_dequant_moving_average_abs_max_op(
'bit_length': quant_bits,
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'op_role': op_role,
}

quant_op_node = graph.create_op_node(
Expand Down Expand Up @@ -2131,7 +2156,12 @@ def __init__(
self._scale_dict = scale_dict

def insert_quant_op(
self, graph, var_node, var_name=None, scale_var_node=None
self,
graph,
var_node,
var_name=None,
scale_var_node=None,
op_role=core.op_proto_and_checker_maker.OpRole.Forward,
):
assert var_node.is_var(), f'{var_node.name()} is not a var'
var_name = var_node.name() if not var_name else var_name
Expand Down Expand Up @@ -2200,7 +2230,7 @@ def insert_quant_op(
inputs["ZeroPoint"] = zero_point_node

attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
attrs["op_role"] = op_role
outputs = {"Y": quant_var_node}
if not self._is_test:
scale_out_node = graph.create_var_node_from_desc(
Expand Down Expand Up @@ -2271,7 +2301,7 @@ def insert_quant_op(
graph.link_to(quant_op_node, scale_out_node)
return quant_var_node, scale_var_node

def insert_dequant_op(self, graph, var_node, scale_var_node):
def insert_dequant_op(self, graph, var_node, scale_var_node, op_role):
assert var_node.is_var(), f'{var_node.name()} is not a var'

dequant_var_node = graph.create_var_node(
Expand Down Expand Up @@ -2301,7 +2331,7 @@ def insert_dequant_op(self, graph, var_node, scale_var_node):
inputs["ZeroPoint"] = zero_point_node

attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
attrs["op_role"] = op_role

quant_op_node = graph.create_op_node(
op_type="dequantize_linear",
Expand Down Expand Up @@ -2513,6 +2543,7 @@ def _quant_preprocess(self, op_node):

def _transform_forward(self, graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
op_role = op.op().attr("op_role")
weight_scale_node = None
inputs = op.inputs
for var_node in inputs:
Expand Down Expand Up @@ -2592,10 +2623,10 @@ def _transform_forward(self, graph, op):
quant_var_node,
scale_var_node,
) = insert_quant_pass.insert_quant_op(
graph, var_node, var_name=name
graph, var_node, var_name=name, op_role=op_role
)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node
graph, quant_var_node, scale_var_node, op_role
)

self.dequantized_vars[name] = dequant_var_node
Expand Down Expand Up @@ -2676,9 +2707,13 @@ def _quant_conv1d(self, graph, op):
var_node,
var_name=var_node.name(),
scale_var_node=scale_var_node,
op_role=op.op().attr("op_role"),
)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node
graph,
quant_var_node,
scale_var_node,
op.op().attr("op_role"),
)
graph.update_input_link(var_node, dequant_var_node, op)

Expand Down Expand Up @@ -2913,11 +2948,16 @@ def apply(self, graph):
quant_var_node,
scale_var_node,
) = insert_quant_pass.insert_quant_op(
graph, in_node
graph,
in_node,
op_role=op_node.op().attr("op_role"),
)
dequant_var_node = (
insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node
graph,
quant_var_node,
scale_var_node,
op_node.op().attr("op_role"),
)
)
dequantized_vars_map[arg_name] = dequant_var_node
Expand Down