Skip the int input operator when inserting a quant node & fix some bug (
yghstill authored Feb 1, 2023
1 parent 3a73d34 commit 0361903
Showing 1 changed file with 53 additions and 33 deletions.
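
The headline change guards quantization on input dtype: if an op input is not a floating-point tensor, the pass logs a warning and skips the layer instead of inserting quant/dequant nodes. A minimal standalone sketch of that guard, using plain paddle dtypes and a hypothetical is_quantizable helper rather than the pass's IrGraph machinery:

import paddle

# Float dtypes that may receive a quant/dequant pair; anything else
# (e.g. integer inputs) causes the layer to be skipped with a warning.
QUANTIZABLE_DTYPES = {paddle.float64, paddle.float32, paddle.float16}

def is_quantizable(dtype) -> bool:
    """Hypothetical helper mirroring the dtype check added in apply()."""
    return dtype in QUANTIZABLE_DTYPES

print(is_quantizable(paddle.float32))  # True  -> quant/dequant nodes are inserted
print(is_quantizable(paddle.int32))    # False -> the layer is skipped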
86 changes: 53 additions & 33 deletions python/paddle/static/quantization/quantization_pass.py
@@ -2890,6 +2890,19 @@ def apply(self, graph):
                         )
                         if in_node.persistable():
                             continue
+
+                        if in_node.dtype() not in [
+                            paddle.float64,
+                            paddle.float32,
+                            paddle.float16,
+                        ]:
+                            _logger.warning(
+                                "Since the {} contains an input of type INT, the quantization of this layer is skipped.".format(
+                                    op_node.name()
+                                )
+                            )
+                            break
+
                         if arg_name in dequantized_vars_map:
                             dequant_var_node = dequantized_vars_map[arg_name]
                         else:
@@ -3137,7 +3150,7 @@ def __init__(
         self._save_int_weight = save_int_weight
         assert self._scope is not None, "scope must not be None."
         assert self._place is not None, "place must not be None."
-        self._quantized_ops = set()
+        self._quantized_ops = {}
 
     def apply(self, graph):
         assert isinstance(
@@ -3176,7 +3189,6 @@ def apply(self, graph):
                 quant_axis = _op.op().attr("quant_axis")
                 bits_length = _op.op().attr("bit_length")
                 if x_node.name() not in self._quantized_ops:
-                    self._quantized_ops.add(x_node.name())
                     quantized_param_v = utils.quant_tensor(
                         param_v.copy(),
                         scale_v,
Expand Down Expand Up @@ -3211,10 +3223,13 @@ def apply(self, graph):
self._scope,
self._place,
)
self._quantized_ops[x_node.name()] = quant_weight_node

for next_op_node in out_node.outputs:
graph.update_input_link(
out_node, quant_weight_node, next_op_node
out_node,
self._quantized_ops[x_node.name()],
next_op_node,
)
graph.safe_remove_nodes(_op)
self._remove_unused_var_nodes(graph)
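
Turning self._quantized_ops from a set into a dict is what lets the pass reuse the quantized weight node: the first fake-quant op that touches a parameter creates and caches the node, and every later consumer is rewired to the cached node instead of a freshly quantized copy. A simplified sketch of that caching pattern, with a hypothetical quantize_weight stand-in for the real utils.quant_tensor/_init_var_node path:

import numpy as np

_quantized_weights = {}  # parameter name -> quantized tensor, playing the role of self._quantized_ops

def quantize_weight(weight, scale, bits=8):
    """Hypothetical stand-in for utils.quant_tensor: symmetric per-tensor quantization."""
    bound = float(2 ** (bits - 1) - 1)
    return np.clip(np.round(weight / scale * bound), -bound, bound)

def get_quantized_weight(name, weight, scale):
    # Quantize each parameter once; later consumers reuse the cached result.
    if name not in _quantized_weights:
        _quantized_weights[name] = quantize_weight(weight, scale)
    return _quantized_weights[name]

w = np.random.rand(4, 4).astype("float32")
first = get_quantized_weight("conv2d_0.w_0", w, scale=float(np.abs(w).max()))
second = get_quantized_weight("conv2d_0.w_0", w, scale=float(np.abs(w).max()))
assert first is second  # the second lookup hits the cache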
@@ -3298,9 +3313,9 @@ def apply(self, graph):
                         op_node.outputs, var_name
                     )
                     if out_node.dtype() not in [
-                        core.VarDesc.VarType.FP64,
-                        core.VarDesc.VarType.FP32,
-                        core.VarDesc.VarType.FP16,
+                        paddle.float64,
+                        paddle.float32,
+                        paddle.float16,
                     ]:
                         continue
                     if var_name in dequantized_vars_map:
@@ -3319,7 +3334,10 @@ def apply(self, graph):
             else:
                 var_names = utils._get_op_input_var_names(op_node)
                 for var_name in var_names:
-                    if var_name in dequant_node_map:
+                    if (
+                        var_name in dequant_node_map
+                        and dequant_node_map[var_name]
+                    ):
                         in_node = graph._find_node_by_name(
                             op_node.inputs, var_name
                         )
@@ -3345,39 +3363,41 @@ def _insert_quant_dequant_op(self, graph, var_node):
             shape=var_node.shape(),
             var_dtype=var_node.dtype(),
         )
-        if not self._calibration_range_dict:
-            try:
-                scale_var_node = graph._find_node_by_name(
-                    graph.all_persistable_nodes(), self._scale_name(var_name)
-                )
-            except:
-                _logger.warning(
-                    "Cannot find the target node {} in scope, so skip adding quant node.".format(
-                        var_name
-                    )
-                )
-                return None
-        elif var_name in self._calibration_range_dict:
-            scale_value = self._calibration_range_dict[var_name]
-            scale_var_node = graph.create_persistable_node(
-                name=self._scale_name(var_name),
-                var_type=var_node.type(),
-                shape=[1],
-                var_dtype=var_node.dtype(),
-            )
-            data_type = (
-                'float64'
-                if var_node.dtype() == core.VarDesc.VarType.FP64
-                else 'float32'
-            )
-            _init_var_node(
-                scale_var_node,
-                np.array(scale_value, dtype=data_type),
-                self._scope,
-                self._place,
-            )
-        else:
-            return None
+
+        try:
+            scale_var_node = graph._find_node_by_name(
+                graph.all_persistable_nodes(), self._scale_name(var_name)
+            )
+        except:
+            if (
+                self._calibration_range_dict
+                and var_name in self._calibration_range_dict
+            ):
+                scale_value = self._calibration_range_dict[var_name]
+                scale_var_node = graph.create_persistable_node(
+                    name=self._scale_name(var_name),
+                    var_type=var_node.type(),
+                    shape=[1],
+                    var_dtype=var_node.dtype(),
+                )
+                data_type = (
+                    'float64'
+                    if var_node.dtype() == core.VarDesc.VarType.FP64
+                    else 'float32'
+                )
+                _init_var_node(
+                    scale_var_node,
+                    np.array(scale_value, dtype=data_type),
+                    self._scope,
+                    self._place,
+                )
+            else:
+                _logger.warning(
+                    "Cannot find the target node {} in scope, so skip adding quant node.".format(
+                        var_name
+                    )
+                )
+                return None
         try:
             zero_point_node = graph._find_node_by_name(
                 graph.all_persistable_nodes(),
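
The rewritten scale lookup in _insert_quant_dequant_op now always tries the persistable scale variable in the graph first and only falls back to the calibration range dict inside the except branch; if neither source provides a scale, it warns and returns None so the caller skips the variable. A simplified sketch of that control flow, with plain dicts standing in for the IrGraph and scope lookups:

import logging

_logger = logging.getLogger(__name__)

def resolve_scale(var_name, graph_scales, calibration_range_dict=None):
    """Sketch of the fallback order: graph/scope first, then calibration data, else skip."""
    try:
        return graph_scales[var_name]  # scale variable already present in the graph
    except KeyError:
        if calibration_range_dict and var_name in calibration_range_dict:
            # Build the scale from calibration statistics instead.
            return calibration_range_dict[var_name]
        _logger.warning(
            "Cannot find the target node %s in scope, so skip adding quant node.",
            var_name,
        )
        return None  # caller skips inserting the quant/dequant pair


# Usage: found in the graph, found via calibration data, and skipped.
print(resolve_scale("fc_0.tmp_0", {"fc_0.tmp_0": 0.12}))
print(resolve_scale("fc_1.tmp_0", {}, {"fc_1.tmp_0": 0.34}))
print(resolve_scale("fc_2.tmp_0", {}))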
