[CPU] Fuse the Conv2D, BN (and Activation) into FusedConv2D (#2393)
Co-authored-by: Zhang, Jianyi <jianyi.zhang@intel.com>
LIONEFAN and jianyizh authored Oct 13, 2023
1 parent 1ddc3df commit 51793ce
Showing 11 changed files with 698 additions and 108 deletions.
229 changes: 229 additions & 0 deletions itex/core/graph/remapper/remapper.cc
@@ -300,6 +300,36 @@ struct ContractionWithBiasAndAddActivation {
  int bias_port = kMissingIndex;
};

// Contraction node followed by a FusedBatchNorm.
struct ContractionWithBatchNorm {
  ContractionWithBatchNorm() = default;
  ContractionWithBatchNorm(int contraction, int fused_batch_norm,
                           float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  float epsilon = 0.0;
};

// Contraction node followed by a FusedBatchNorm and Activation.
struct ContractionWithBatchNormAndActivation {
  ContractionWithBatchNormAndActivation() = default;
  ContractionWithBatchNormAndActivation(int contraction, int fused_batch_norm,
                                        int activation, float epsilon = 0.0)
      : contraction(contraction),
        fused_batch_norm(fused_batch_norm),
        activation(activation),
        epsilon(epsilon) {}

  int contraction = kMissingIndex;
  int fused_batch_norm = kMissingIndex;
  int activation = kMissingIndex;
  float epsilon = 0.0;
};
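
// Editor's note (a sketch, not part of this commit): the subgraphs these two
// structs index, with Relu standing in for any supported activation:
//
//   Conv2D -> FusedBatchNorm                    (ContractionWithBatchNorm)
//   Conv2D -> FusedBatchNorm -> Relu  (ContractionWithBatchNormAndActivation)
//
// The int members are node indices into the remapper's graph view, and
// `epsilon` is carried along so it can be replayed on the fused op.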

struct ContractionWithBiasAndActivationAdd {
  ContractionWithBiasAndActivationAdd() = default;
  ContractionWithBiasAndActivationAdd(int contraction, int bias_add,
@@ -1633,6 +1663,90 @@ bool FindContractionWithBiasAndAddActivation(
  return true;
}

bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index,
                             ContractionWithBatchNorm* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  const auto* node_def = node_view->node();
  // Root of the pattern must be a FusedBatchNorm.
  if (!IsFusedBatchNorm(*node_def) && !IsITEXFusedBatchNorm(*node_def))
    return false;

  if (node_view->GetOp() != "FusedBatchNorm" &&
      node_view->GetOp() != "_ITEXFusedBatchNorm" &&
      !HasDataType(node_def, DT_FLOAT, "U"))
    return false;
  // The oneDNN batch norm, converted to a binary operation, doesn't support
  // double.
  if (HasDataType(node_def, DT_DOUBLE, "T")) return false;

  // Check that batch normalization is in inference mode.
  const auto* training_attr = node_view->GetAttr(kIsTraining);
  if (training_attr != nullptr && training_attr->b()) return false;

  // Check that only the 0th output is consumed by other nodes.
  if (HasControlFaninOrFanout(*node_view) ||
      !node_view->GetRegularFanout(1).empty() ||  // batch_mean
      !node_view->GetRegularFanout(2).empty() ||  // batch_variance
      !node_view->GetRegularFanout(3).empty() ||  // reserve_space_1
      !node_view->GetRegularFanout(4).empty())    // reserve_space_2
    return false;

  // Input to the FusedBatchNorm must be a Conv2D.
  if (node_view->NumRegularFanins() < 1) return false;
  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* conv2d_node_view = regular_fanin_0.node_view();
  const auto* conv2d_node_def = conv2d_node_view->node();
  if (!(IsConv2D(*conv2d_node_def) || conv2d_node_def->op() == "_ITEXConv2D") ||
      !HaveSameDataType(node_def, conv2d_node_def) ||
      HasControlFaninOrFanout(*conv2d_node_view) ||
      !HasAtMostOneFanoutAtPort0(*conv2d_node_view) ||
      IsInPreserveSet(ctx, conv2d_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm pattern.
  matched->contraction = conv2d_node_view->node_index();
  matched->fused_batch_norm = node_index;
  if (!TryGetNodeAttr(*node_def, "epsilon", &matched->epsilon)) return false;

  return true;
}

bool FindConv2DWithBatchNormAndActivation(
    const RemapperContext& ctx, int node_index,
    ContractionWithBatchNormAndActivation* matched) {
  const auto* node_view = ctx.graph_view.GetNode(node_index);
  if (HasControlFaninOrFanout(*node_view)) return false;

  // Root of the pattern must be an activation node.
  const auto* node_def = node_view->node();
  if (!IsSupportedActivation(*node_def)) return false;

  // And input to the activation node must match Conv2DWithBatchNorm pattern.
  if (node_view->NumRegularFanins() < 1) return false;

  const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
  const auto* batch_norm_node_view = regular_fanin_0.node_view();

  ContractionWithBatchNorm base;
  if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(), &base))
    return false;

  const auto* fused_batch_norm_node_view =
      ctx.graph_view.GetNode(base.fused_batch_norm);
  const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node();
  if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) ||
      !HaveSameDataType(node_def, fused_batch_norm_node_def) ||
      IsInPreserveSet(ctx, fused_batch_norm_node_def))
    return false;

  // We successfully found a Conv2D+FusedBatchNorm+Activation pattern.
  matched->contraction = base.contraction;
  matched->fused_batch_norm = base.fused_batch_norm;
  matched->activation = node_index;
  matched->epsilon = base.epsilon;

  return true;
}

bool FindContractionWithBiasAndActivationInPort(
    const RemapperContext& ctx, const utils::MutableNodeView& add_node_view,
    const NodeDef& add_node_def, int port_id) {
@@ -3754,6 +3868,8 @@ Status AddFusedContractionNode(RemapperContext* ctx,

  if (IsConv2D(contraction)) {
    fused_op.set_op(kFusedConv2D);
    auto* attr = fused_op.mutable_attr();
    SetAttrValue(0, &(*attr)["num_bn_args"]);
  } else if (IsDepthwiseConv2dNative(contraction)) {
    fused_op.set_op(kFusedDepthwiseConv2dNative);
  } else if (IsConv3D(contraction)) {
@@ -4406,6 +4522,8 @@ Status AddFusedContractionNode(

  if (IsConv2D(contraction)) {
    fused_op.set_op(kFusedConv2D);
    auto* attr = fused_op.mutable_attr();
    SetAttrValue(0, &(*attr)["num_bn_args"]);
  } else if (IsDepthwiseConv2dNative(contraction)) {
    fused_op.set_op(kFusedDepthwiseConv2dNative);
  } else if (IsConv3D(contraction)) {
@@ -4559,6 +4677,93 @@ Status AddFusedContractionNode(
  return Status::OK();
}

Status AddFusedConv2DNode(RemapperContext* ctx,
                          const ContractionWithBatchNorm& matched,
                          std::vector<bool>* invalidated_nodes,
                          std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);
  ITEX_DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";
  const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
  ITEX_VLOG(2) << "Fuse Conv2D with BatchNorm: batch_norm="
               << fused_batch_norm.name() << " conv2d=" << contraction.name();

  NodeDef fused_conv2d;
  fused_conv2d.set_name(fused_batch_norm.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));       // 0: input
  fused_conv2d.add_input(contraction.input(1));       // 1: filter
  fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
  fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
  fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
  fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance

  CopyAllAttrs(contraction, &fused_conv2d);
  SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm"}, 0);
  auto* attr = fused_conv2d.mutable_attr();
  SetAttrValue(matched.epsilon, &(*attr)["epsilon"]);
  SetAttrValue(4, &(*attr)["num_bn_args"]);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_ABORT_IF_ERROR(status);
  TF_ABORT_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.fused_batch_norm] = true;
  (*nodes_to_delete)[matched.contraction] = true;

  return Status::OK();
}
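
// Editor's sketch, not part of this commit: the per-channel arithmetic that
// an inference-mode FusedBatchNorm folds down to, and the reason the fused op
// only needs the four BN tensors plus `epsilon` (names are illustrative):
//
//   float alpha = scale[c] / std::sqrt(variance[c] + epsilon);
//   y[c]        = alpha * conv_out[c] + (offset[c] - alpha * mean[c]);
//
// i.e. the batch norm collapses into a per-channel scale and shift that the
// backend can apply after the convolution.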

Status AddFusedConv2DNode(RemapperContext* ctx,
                          const ContractionWithBatchNormAndActivation& matched,
                          std::vector<bool>* invalidated_nodes,
                          std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& contraction = graph->node(matched.contraction);

  ITEX_DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now";

  const NodeDef& activation = graph->node(matched.activation);
  const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm);
  ITEX_VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op()
               << ": activation=" << activation.name()
               << " batch_norm=" << fused_batch_norm.name()
               << " conv2d=" << contraction.name();

  NodeDef fused_conv2d;
  fused_conv2d.set_name(activation.name());
  fused_conv2d.set_op(kFusedConv2D);
  fused_conv2d.set_device(contraction.device());
  fused_conv2d.add_input(contraction.input(0));       // 0: input
  fused_conv2d.add_input(contraction.input(1));       // 1: filter
  fused_conv2d.add_input(fused_batch_norm.input(1));  // 2: scale
  fused_conv2d.add_input(fused_batch_norm.input(2));  // 3: offset
  fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
  fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance

  CopyAllAttrs(contraction, &fused_conv2d);
  SetFusedOpAttributesWithActivation(&fused_conv2d, &activation,
                                     {"FusedBatchNorm"}, 0);
  auto* attr = fused_conv2d.mutable_attr();
  SetAttrValue(matched.epsilon, &(*attr)["epsilon"]);
  SetAttrValue(4, &(*attr)["num_bn_args"]);

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_conv2d), &status);
  TF_ABORT_IF_ERROR(status);
  TF_ABORT_IF_ERROR(mutation->Apply());

  (*invalidated_nodes)[matched.activation] = true;
  (*nodes_to_delete)[matched.contraction] = true;
  (*nodes_to_delete)[matched.fused_batch_norm] = true;

  return Status::OK();
}
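
// Editor's note: in both overloads the fused node reuses the name of the
// pattern's root (the batch norm in the BN-only overload, the activation
// here), so downstream consumers keep resolving without any input rewiring;
// the root is marked invalidated while the absorbed inner nodes are queued
// for deletion.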

// Contraction + Mul(scale).
// TODO(itex): Try to combine this function with Conv + BiasAdd
Status AddFusedContractionNode(RemapperContext* ctx,
@@ -6755,7 +6960,31 @@ Status RunRemapper(OptimizerContext* opt_ctx, const GrapplerItem& item,
                                         &invalidated_nodes, &nodes_to_delete));
      continue;
    }
    // NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we could
    // do it for MatMul as well, but in practice this pattern does not appear
    // in real TensorFlow graphs.

    // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D.
    ContractionWithBatchNormAndActivation
        contract_with_batch_norm_and_activation;
    if (!is_layout_opt &&
        FindConv2DWithBatchNormAndActivation(
            ctx, i, &contract_with_batch_norm_and_activation)) {
      TF_RETURN_IF_ERROR(
          AddFusedConv2DNode(&ctx, contract_with_batch_norm_and_activation,
                             &invalidated_nodes, &nodes_to_delete));
      continue;
    }

    // Remap Conv2D+FusedBatchNorm into the _FusedConv2D.
    ContractionWithBatchNorm contract_with_batch_norm;
    if (!is_layout_opt &&
        FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
      TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
                                            &invalidated_nodes,
                                            &nodes_to_delete));
      continue;
    }
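
    // Editor's sketch (illustrative node names, not from this commit) of the
    // net rewrite performed by the two arms above; the three-node pattern is
    // tried first, so an activation is absorbed whenever one is present:
    //
    //   Before: conv = Conv2D(x, w)
    //           bn   = FusedBatchNorm(conv, scale, offset, mean, var)
    //           y    = Relu(bn)
    //   After:  y    = _FusedConv2D(x, w, scale, offset, mean, var)
    //                  with fused_ops = {"FusedBatchNorm", "Relu"},
    //                  epsilon copied from the BN, and num_bn_args = 4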
    // Remap FusedBatchNorm+<SideInput>+<Activation> into the
    // _FusedBatchNormEx.
    FusedBatchNormEx fused_batch_norm_ex;
7 changes: 7 additions & 0 deletions itex/core/graph/utils/layout_utils.cc
@@ -215,6 +215,13 @@ bool RewriteFusedConv(const utils::MutableNodeView& node_view) {
}

bool RewriteOneDnnFusedConv(const utils::MutableNodeView& node_view) {
  const NodeDef& node_def = *(node_view.node());
  std::vector<string> fused_ops;
  ITEX_CHECK_OK(GetNodeAttr(node_def, "fused_ops", &fused_ops));
  for (auto& post_op : fused_ops) {
    if (post_op == "FusedBatchNorm") return false;
  }

  return RewriteFusedConv(node_view) && RewriteOneDnnConv(node_view);
}

6 changes: 6 additions & 0 deletions itex/core/graph/utils/op_types.cc
@@ -418,6 +418,12 @@ bool IsInstanceNorm(const NodeDef& node) {
  return node.op() == "_ITEXInstanceNorm";
}

bool IsITEXFusedBatchNorm(const NodeDef& node) {
  const auto& op = node.op();
  return op == "_ITEXFusedBatchNorm" || op == "_ITEXFusedBatchNormV2" ||
         op == "_ITEXFusedBatchNormV3";
}

bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }

bool IsLeakyReluGrad(const NodeDef& node) {
1 change: 1 addition & 0 deletions itex/core/graph/utils/op_types.h
@@ -121,6 +121,7 @@ bool IsImag(const NodeDef& node);
bool IsImmutableConst(const NodeDef& node);
bool IsInvGrad(const NodeDef& node);
bool IsInstanceNorm(const NodeDef& node);
bool IsITEXFusedBatchNorm(const NodeDef& node);
bool IsLeakyRelu(const NodeDef& node);
bool IsLeakyReluGrad(const NodeDef& node);
bool IsLess(const NodeDef& node);