
Commit

[Urgent for mlperf] Fixed issues and cleaned up fuse_quantized_convolution.cc (#75)
mdfaijul authored and karthikvadla committed May 17, 2019
1 parent 90433ef commit 7361a87
Showing 1 changed file with 177 additions and 135 deletions.
312 changes: 177 additions & 135 deletions tensorflow_quantization/graph_transforms/fuse_quantized_convolution.cc
@@ -97,22 +97,16 @@ Status FuseQuantizedConvolutionAndRequantize(
AddNodeInput(const_requantize_range_min_node.name(), &fused_conv);
AddNodeInput(const_requantize_range_max_node.name(), &fused_conv);

// Add additional inputs to
// QuantizedConv2DWithBiasSumAndReluAndRequantize
// Ensure QuantizedConv2DWithBiasSumAndReluAndRequantize receives an
// integer summand, since requantization fusion is registered
// only for an integer summand.
if (quantized_conv2D_op_name.compare(
"QuantizedConv2DWithBiasSumAndRelu") == 0) {
const NodeDef *in_requantize = node_map[node_map[
quantized_conv2D_node.input(n_input)]->input(0)];
const NodeDef *summand_node = node_map[quantized_conv2D_node.input(
n_input)];
bool quantized_summand = str_util::StrContains(
in_requantize->op(), "Quantized");
// If the summand is not quantized, we need to quantize it since the
// convolution kernel assumes that the summand is always quantized.
if (!quantized_summand &&
!is_perchannel &&
in_requantize->op() != "Requantize" &&
in_requantize->op() != "QuantizeV2") {
NodeDef* new_summand_node = nullptr;
NodeDef quantize_node;
if (summand_node->op() != "Dequantize") {
// Quantizing the summand.
// Add some common constants we need for reshaping inputs.
NodeDef reshape_dims;
@@ -156,10 +150,20 @@ Status FuseQuantizedConvolutionAndRequantize(
AddNodeInput(reshape_node.name(), &max_node);
AddNodeInput(reduction_dims.name(), &max_node);
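// Illustrative sketch of the summand-quantization subgraph assembled in
// this block (assuming reshape_dims flattens the tensor to 1-D and
// reduction_dims is {0}; not literal code):
//   flat      = Reshape(summand, reshape_dims)
//   range_min = Min(flat, reduction_dims)
//   range_max = Max(flat, reduction_dims)
//   (q_summand, q_min, q_max) = QuantizeV2(summand, range_min, range_max)
// These three outputs are what gets wired into fused_conv below.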

NodeDef quantize_node;
// NodeDef quantize_node;
quantize_node.set_op("QuantizeV2");
quantize_node.set_name(summand_node->name() + "/quantize");
SetNodeAttr("T", DT_QUINT8, &quantize_node);
// Decide data type of quantize op
std::vector<string> relu_ops = {
"Relu",
"Relu6"
};
bool is_relu = std::find(relu_ops.begin(), relu_ops.end(),
summand_node->op()) != relu_ops.end();
if (is_relu)
SetNodeAttr("T", DT_QUINT8, &quantize_node);
else
SetNodeAttr("T", DT_QINT8, &quantize_node);
SetNodeAttr("mode", "SCALED", &quantize_node);

AddNodeInput(summand_node->name(), &reshape_node);
@@ -169,41 +173,71 @@ Status FuseQuantizedConvolutionAndRequantize(
AddNodeInput(min_node.name(), &quantize_node);
AddNodeInput(max_node.name(), &quantize_node);

AddNodeInput(quantize_node.name(), &fused_conv);
AddNodeInput(quantize_node.name() + ":1", &fused_conv);
AddNodeInput(quantize_node.name() + ":2", &fused_conv);

new_nodes->push_back(reshape_dims);
new_nodes->push_back(reduction_dims);
new_nodes->push_back(reshape_node);
new_nodes->push_back(min_node);
new_nodes->push_back(max_node);
new_nodes->push_back(quantize_node);
// Set the new summand node for fused_conv
new_summand_node = &quantize_node;
} else {
string summand(in_requantize->name());
string min_summand(in_requantize->name() + ":1");
string max_summand(in_requantize->name() + ":2");
AddNodeInput(summand, &fused_conv);
AddNodeInput(min_summand, &fused_conv);
AddNodeInput(max_summand, &fused_conv);
// If the summand node is "Dequantize", then either "QuantizeV2" or
// "Requantize{PerChannel}" is feeding the Dequantize op.
// Set new_summand_node to the input of the summand node.
new_summand_node = const_cast<NodeDef*>(node_map[
summand_node->input(0)]);
}

// Signed version QuantizedConv2DWithBiasSumAndReluAndRequantize
// if Relu does not follow the convolution operation
std::vector<string> signed_ops = {
"QuantizedConv2DWithBias",
"QuantizedConv2D"
};
bool is_signed_summand =
string summand(new_summand_node->name());
string min_summand(new_summand_node->name() + ":1");
string max_summand(new_summand_node->name() + ":2");
AddNodeInput(summand, &fused_conv);
AddNodeInput(min_summand, &fused_conv);
AddNodeInput(max_summand, &fused_conv);

DataType summand_type = DT_QUINT8;
// New summand node should be QuantizeV2 or
// Requantize{PerChannel}
if (new_summand_node->op() == "QuantizeV2") {
TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
"T", &summand_type));
} else if (new_summand_node->op() == "RequantizePerChannel") {
TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
"out_type", &summand_type));
} else if (new_summand_node->op() == "Requantize") {
// The Requantize op is an Eigen kernel that does non-SCALED quantization
// and always maps into quint8. However, for MKLDNN fusion, which uses
// SCALED quantization, the fused requantize op producing the summand may
// have qint8 or quint8 as its output type. Therefore, summand_type needs
// to be set correctly.
std::vector<string> signed_ops = {
"QuantizedConv2DWithBias",
"QuantizedConv2D"
};
bool is_signed_summand =
std::find(signed_ops.begin(), signed_ops.end(),
node_map[in_requantize->input(0)]->op()) != signed_ops.end();
if (is_signed_summand) {
fused_conv.set_op(
"QuantizedConv2DWithBiasSignedSumAndReluAndRequantize");
SetNodeAttr("Tsummand", DT_QINT8, &fused_conv);
node_map[new_summand_node->input(0)]->op())
!= signed_ops.end();
summand_type = is_signed_summand ? DT_QINT8 : DT_QUINT8;
} else if (str_util::StartsWith(new_summand_node->op(),
"Quantized")) {
if (HasNodeAttr(*new_summand_node, "T")) {
TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
"T", &summand_type));
} else if (HasNodeAttr(*new_summand_node, "out_type")) {
TF_RETURN_IF_ERROR(GetNodeAttr(*new_summand_node,
"out_type", &summand_type));
}
} else {
SetNodeAttr("Tsummand", DT_QUINT8, &fused_conv);
return Status(error::Code::FAILED_PRECONDITION,
"Fusion is not supported, a fix is required.");
}
SetNodeAttr("Tsummand", summand_type, &fused_conv);
// Decide whether to use the signed version of
// QuantizedConv2DWithBiasSumAndReluAndRequantize.
if (summand_type == DT_QINT8)
fused_conv.set_op(
"QuantizedConv2DWithBiasSignedSumAndReluAndRequantize");
}
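// Illustrative sketch of the Tsummand choice above (assuming symmetric
// SCALED-mode ranges; the numbers are hypothetical, not taken from a graph):
//   a Relu/Relu6 summand is non-negative, so SCALED quantization maps
//   [0, max] onto quint8 [0, 255], e.g. 0.5f with max = 1.0f -> ~128;
//   a signed summand maps [-max, max] onto qint8 [-127, 127],
//   e.g. -0.5f with max = 1.0f -> ~-64.
// When summand_type ends up as DT_QINT8, the fused op must be the
// ...SignedSumAndReluAndRequantize variant, as set right above.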

// Add control input to the very end of the input list
@@ -216,32 +250,21 @@ Status FuseQuantizedConvolutionAndRequantize(
CopyNodeAttr(quantized_conv2D_node, "strides", "strides", &fused_conv);
CopyNodeAttr(quantized_conv2D_node, "padding", "padding", &fused_conv);

if (is_perchannel) {
std::vector<std::string> fused_quantized_bias_ops = {
"QuantizedConv2DWithBias",
"QuantizedConv2DWithBiasAndRelu",
"QuantizedDepthwiseConv2DWithBias",
"QuantizedDepthwiseConv2DWithBiasAndRelu",
"QuantizedConv2DWithBiasSumAndRelu",
"QuantizedConv2DWithBiasSignedSumAndRelu"
};

if (std::find(fused_quantized_bias_ops.begin(),
fused_quantized_bias_ops.end(),
quantized_conv2D_node.op()) != fused_quantized_bias_ops.end()) {
SetNodeAttr("Tbias", DT_FLOAT, &fused_conv);
}
std::vector<std::string> fused_quantized_bias_ops = {
"QuantizedConv2DWithBias",
"QuantizedConv2DWithBiasAndRelu",
"QuantizedDepthwiseConv2DWithBias",
"QuantizedDepthwiseConv2DWithBiasAndRelu",
"QuantizedConv2DWithBiasSumAndRelu",
};
if (std::find(fused_quantized_bias_ops.begin(),
fused_quantized_bias_ops.end(),
quantized_conv2D_node.op()) != fused_quantized_bias_ops.end()) {
SetNodeAttr("Tbias", DT_FLOAT, &fused_conv);
}

CopyNodeAttr(quantized_conv2D_node, "Tinput", "Tinput", &fused_conv);
CopyNodeAttr(quantized_conv2D_node, "Tfilter", "Tfilter", &fused_conv);
CopyNodeAttr(quantized_conv2D_node, "strides", "strides", &fused_conv);
CopyNodeAttr(quantized_conv2D_node, "padding", "padding", &fused_conv);

if (HasNodeAttr(quantized_conv2D_node, "padding_list"))
CopyNodeAttr(quantized_conv2D_node, "padding_list",
"padding_list", &fused_conv);

// Copy the dilations attribute if it exists in the original node
if (HasNodeAttr(quantized_conv2D_node, "dilations"))
CopyNodeAttr(quantized_conv2D_node, "dilations",
@@ -259,93 +282,112 @@ Status FuseQuantizedConvolutionAndRequantize(
},
{}, &replaced_graph_def));

if (!is_perchannel) {
// Convert bias float -> int32 on replaced_graph_def
std::vector<std::string> fused_requantized_bias_ops = {
"QuantizedConv2DWithBiasAndRequantize",
"QuantizedConv2DWithBiasAndReluAndRequantize",
"QuantizedConv2DWithBiasSumAndReluAndRequantize",
"QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
};

node_map.clear();
MapNamesToNodes(replaced_graph_def, &node_map);
for (auto& node_pair : node_map) {
const NodeDef *node = node_pair.second;
if (str_util::StartsWith(node->op(), "Dequantize")) {
// dequant node should accept DT_QINT8 if the input node is
// "QuantizedConv2DAndRequantize" and
// "QuantizedConv2DWithBiasAndRequantize"
std::string input_node_op =
node_map[NodeNameFromInput(node->input(0))]->op();
if (str_util::StartsWith(input_node_op,
"QuantizedConv2DAndRequantize") ||
str_util::StartsWith(input_node_op,
"QuantizedConv2DWithBiasAndRequantize")) {
SetNodeAttr("T", DT_QINT8, const_cast<NodeDef*>(node));
SetNodeAttr("mode", "SCALED", const_cast<NodeDef*>(node));
}
// After Requantize op fusion, fix attributes for nodes in the graph
// if there is some discrepancy, and also quantize the bias (float -> int32).
// List of requantize fused ops that have biases.
std::vector<std::string> fused_requantized_bias_ops = {
"QuantizedConv2DWithBiasAndRequantize",
"QuantizedConv2DWithBiasAndReluAndRequantize",
"QuantizedConv2DWithBiasSumAndReluAndRequantize",
"QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
"QuantizedDepthwiseConv2DWithBiasAndRequantize",
"QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"
};

node_map.clear();
MapNamesToNodes(replaced_graph_def, &node_map);
for (auto& node_pair : node_map) {
const NodeDef *node = node_pair.second;
// A workaround to fix attributes of the "Dequantize" op with non-perchannel
// quantization. The "Dequantize" node should accept DT_QINT8 if the input node
// is "QuantizedConv2DAndRequantize" or
// "QuantizedConv2DWithBiasAndRequantize".
if (str_util::StartsWith(node->op(), "Dequantize")) {
std::string input_node_op =
node_map[NodeNameFromInput(node->input(0))]->op();
if (str_util::StartsWith(input_node_op,
"QuantizedConv2DAndRequantize") ||
str_util::StartsWith(input_node_op,
"QuantizedConv2DWithBiasAndRequantize")) {
SetNodeAttr("T", DT_QINT8, const_cast<NodeDef*>(node));
SetNodeAttr("mode", "SCALED", const_cast<NodeDef*>(node));
}
continue;
}
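// Illustrative sketch of the attribute fix above (assuming standard
// SCALED-mode Dequantize semantics; the values are hypothetical):
//   with T = DT_QINT8 and mode = "SCALED", Dequantize recovers roughly
//     float_val ~= q * max(|min_range|, |max_range|) / 127.0f,
//   e.g. q = -64 with a range of +/-1.0f dequantizes to ~-0.504f.
// Without the fix, the Dequantize node would still be typed for quint8
// while the requantized convolutions named above emit qint8.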

// Quantize the bias to int32 if the input min-max values are constants.
// This is guaranteed if the preceding op is a fused requantize op.
bool is_fused_requantized_conv_op =
std::find(fused_requantized_bias_ops.begin(),
fused_requantized_bias_ops.end(), node->op())
!= fused_requantized_bias_ops.end();
if (is_fused_requantized_conv_op) {
// If the op is fed by a Quantize op then we keep the bias as float
std::string input_op = node_map[NodeNameFromInput(
node->input(0))]->op();
if (str_util::StartsWith(input_op, "QuantizedConv2D") &&
str_util::EndsWith(input_op, "AndRequantize")) {
NodeDef *bias_node = const_cast<NodeDef*>(node_map[NodeNameFromInput(
node->input(2))]);
const NodeDef *min_input_node = node_map[NodeNameFromInput(
std::find(fused_requantized_bias_ops.begin(),
fused_requantized_bias_ops.end(), node->op())
!= fused_requantized_bias_ops.end();
if (is_fused_requantized_conv_op) {
std::string preceeding_op = node_map[NodeNameFromInput(
node->input(0))]->op();
if (str_util::StartsWith(preceeding_op, "Quantized") &&
str_util::StrContains(preceeding_op, "Conv2D") &&
str_util::EndsWith(preceeding_op, "AndRequantize")) {
NodeDef *bias_node = const_cast<NodeDef*>(node_map[NodeNameFromInput(
node->input(2))]);
const NodeDef *min_input_node = node_map[NodeNameFromInput(
node_map[node->input(0)]->input(7))];
const NodeDef *max_input_node = node_map[NodeNameFromInput(
const NodeDef *max_input_node = node_map[NodeNameFromInput(
node_map[node->input(0)]->input(8))];
const NodeDef *min_filter_node = node_map[NodeNameFromInput(
const NodeDef *min_filter_node = node_map[NodeNameFromInput(
node->input(5))];
const NodeDef *max_filter_node = node_map[NodeNameFromInput(
const NodeDef *max_filter_node = node_map[NodeNameFromInput(
node->input(6))];
const float min_input =
const float min_input =
GetNodeTensorAttr(*min_input_node, "value").flat<float>()(0);
const float max_input =
const float max_input =
GetNodeTensorAttr(*max_input_node, "value").flat<float>()(0);
const float min_filter =
GetNodeTensorAttr(*min_filter_node, "value").flat<float>()(0);
const float max_filter =
GetNodeTensorAttr(*max_filter_node, "value").flat<float>()(0);

TensorProto float_tensor_proto =
bias_node->attr().at("value").tensor();
Tensor float_tensor;
CHECK(float_tensor.FromProto(float_tensor_proto));
CHECK_EQ(float_tensor.dtype(), DT_FLOAT);
float *p_bias_float = float_tensor.flat<float>().data();

Tensor int32_tensor = Tensor(DT_QINT32, float_tensor.shape());
qint32 *p_bias_int32 = int32_tensor.flat<qint32>().data();

float bias_scale = 255.0 * 127.0 /
const Tensor& min_filter_tensor =
GetNodeTensorAttr(*min_filter_node, "value");
const Tensor& max_filter_tensor =
GetNodeTensorAttr(*max_filter_node, "value");
const float* min_filter = min_filter_tensor.flat<float>().data();
const float* max_filter = max_filter_tensor.flat<float>().data();
size_t num_scale_factors = min_filter_tensor.NumElements();

TensorProto float_tensor_proto =
bias_node->attr().at("value").tensor();
Tensor float_bias_tensor;
CHECK(float_bias_tensor.FromProto(float_tensor_proto));
CHECK_EQ(float_bias_tensor.dtype(), DT_FLOAT);
float *float_bias = float_bias_tensor.flat<float>().data();

Tensor int32_bias_tensor = Tensor(DT_QINT32, float_bias_tensor.shape());
qint32 *int32_bias = int32_bias_tensor.flat<qint32>().data();
std::vector<float> scales(num_scale_factors);
for (size_t i = 0; i < num_scale_factors; ++i) {
scales[i] = 255.0 * 127.0 /
(std::max(std::abs(max_input), std::abs(min_input)) *
std::max(std::abs(max_filter), std::abs(min_filter)));
int64 nelems = float_tensor.NumElements();
for (int64 n = 0; n < nelems; n++)
p_bias_int32[n] = (int32_t) (p_bias_float[n] * bias_scale);

bias_node->clear_attr();
AttrValue attr_type;
attr_type.set_type(int32_tensor.dtype());
bias_node->mutable_attr()->insert({"dtype", attr_type});
AttrValue attr_tensor;
TensorProto* t = attr_tensor.mutable_tensor();
int32_tensor.AsProtoTensorContent(t);
bias_node->mutable_attr()->insert({"value", attr_tensor});
SetNodeAttr("Tbias", DT_QINT32, const_cast<NodeDef*>(node));
std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
}
int64 bias_length = float_bias_tensor.NumElements();
if (num_scale_factors > 1) {
if (bias_length != num_scale_factors) {
return Status(error::Code::FAILED_PRECONDITION,
"Number of filter output channels is not"
"equal to bias size");
} else {
for (int64 i = 0; i < bias_length; i++)
int32_bias[i] = (int32_t) (float_bias[i] * scales[i]);
}
} else {
SetNodeAttr("Tbias", DT_FLOAT, const_cast<NodeDef*>(node));
for (int64 i = 0; i < bias_length; i++)
int32_bias[i] = (int32_t) (float_bias[i] * scales[0]);
}
bias_node->clear_attr();
AttrValue attr_type;
attr_type.set_type(int32_bias_tensor.dtype());
bias_node->mutable_attr()->insert({"dtype", attr_type});
AttrValue attr_tensor;
TensorProto* t = attr_tensor.mutable_tensor();
int32_bias_tensor.AsProtoTensorContent(t);
bias_node->mutable_attr()->insert({"value", attr_tensor});
SetNodeAttr("Tbias", DT_QINT32, const_cast<NodeDef*>(node));
} else {
SetNodeAttr("Tbias", DT_FLOAT, const_cast<NodeDef*>(node));
}
}
}
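// Illustrative worked example of the bias re-scaling above (the min/max
// values are hypothetical, not taken from a real graph): with
//   max(|min_input|,  |max_input|)  = 5.0f and
//   max(|min_filter|, |max_filter|) = 2.0f,
// the per-tensor scale is
//   scales[0] = 255.0f * 127.0f / (5.0f * 2.0f) = 3238.5f,
// so a float bias element of 0.1f becomes
//   static_cast<int32_t>(0.1f * 3238.5f) == 323 (stored as qint32),
// and Tbias on the fused op is switched from DT_FLOAT to DT_QINT32.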
