From 51793ce7efcc69999d2a145b4a46e186a432c315 Mon Sep 17 00:00:00 2001 From: "Li, Yifan" <52822833+LIONEFAN@users.noreply.github.com> Date: Fri, 13 Oct 2023 16:40:48 +0800 Subject: [PATCH] [CPU] Fuse the Conv2D, BN (and Activation) into FusedConv2D (#2393) Co-authored-by: Zhang, Jianyi --- itex/core/graph/remapper/remapper.cc | 229 ++++++++++++ itex/core/graph/utils/layout_utils.cc | 7 + itex/core/graph/utils/op_types.cc | 6 + itex/core/graph/utils/op_types.h | 1 + itex/core/kernels/common/conv_ops.h | 80 ++++- .../core/kernels/common/fused_batch_norm_op.h | 2 +- itex/core/ops/nn_ops.cc | 2 + itex/core/utils/onednn/onednn_post_op_util.cc | 86 ++++- itex/core/utils/onednn/onednn_post_op_util.h | 49 +++ itex/core/utils/onednn/onednn_util.h | 4 +- .../python/grappler/remapper_test.py | 340 ++++++++++++------ 11 files changed, 698 insertions(+), 108 deletions(-) diff --git a/itex/core/graph/remapper/remapper.cc b/itex/core/graph/remapper/remapper.cc index 8dce0d9aa..52a8c5e6b 100644 --- a/itex/core/graph/remapper/remapper.cc +++ b/itex/core/graph/remapper/remapper.cc @@ -300,6 +300,36 @@ struct ContractionWithBiasAndAddActivation { int bias_port = kMissingIndex; }; +// Contraction node followed by a FusedBatchNorm. +struct ContractionWithBatchNorm { + ContractionWithBatchNorm() = default; + ContractionWithBatchNorm(int contraction, int fused_batch_norm, + float epsilon = 0.0) + : contraction(contraction), + fused_batch_norm(fused_batch_norm), + epsilon(epsilon) {} + + int contraction = kMissingIndex; + int fused_batch_norm = kMissingIndex; + float epsilon = 0.0; +}; + +// Contraction node followed by a FusedBatchNorm and Activation. +struct ContractionWithBatchNormAndActivation { + ContractionWithBatchNormAndActivation() = default; + ContractionWithBatchNormAndActivation(int contraction, int fused_batch_norm, + int activation, float epsilon = 0.0) + : contraction(contraction), + fused_batch_norm(fused_batch_norm), + activation(activation), + epsilon(epsilon) {} + + int contraction = kMissingIndex; + int fused_batch_norm = kMissingIndex; + int activation = kMissingIndex; + float epsilon = 0.0; +}; + struct ContractionWithBiasAndActivationAdd { ContractionWithBiasAndActivationAdd() = default; ContractionWithBiasAndActivationAdd(int contraction, int bias_add, @@ -1633,6 +1663,90 @@ bool FindContractionWithBiasAndAddActivation( return true; } +bool FindConv2DWithBatchNorm(const RemapperContext& ctx, int node_index, + ContractionWithBatchNorm* matched) { + const auto* node_view = ctx.graph_view.GetNode(node_index); + const auto* node_def = node_view->node(); + // Root of the pattern must be a FusedBatchNorm. + if (!IsFusedBatchNorm(*node_def) && !IsITEXFusedBatchNorm(*node_def)) + return false; + + if (node_view->GetOp() != "FusedBatchNorm" && + node_view->GetOp() != "_ITEXFusedBatchNorm" && + !HasDataType(node_def, DT_FLOAT, "U")) + return false; + // oneDNN batchnorm converted binary operation doesn't support double + if (HasDataType(node_def, DT_DOUBLE, "T")) return false; + + // Check that batch normalization is in inference mode. + const auto* training_attr = node_view->GetAttr(kIsTraining); + if (training_attr != nullptr && training_attr->b()) return false; + + // Check that only 0th output is consumed by other nodes. 
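+  // (Port 0 is the normalized output; ports 1-4 are batch_mean,
+  // batch_variance and the two reserve spaces. They vanish once the node is
+  // fused away, so they must have no consumers.)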
+ if (HasControlFaninOrFanout(*node_view) || + !node_view->GetRegularFanout(1).empty() || // batch_mean + !node_view->GetRegularFanout(2).empty() || // batch_variance + !node_view->GetRegularFanout(3).empty() || // reserve_space_1 + !node_view->GetRegularFanout(4).empty()) // reserve_space_2 + return false; + + // Input to the FusedBatchNorm must be a Conv2D. + if (node_view->NumRegularFanins() < 1) return false; + const auto& regular_fanin_0 = node_view->GetRegularFanin(0); + const auto* conv2d_node_view = regular_fanin_0.node_view(); + const auto* conv2d_node_def = conv2d_node_view->node(); + if (!(IsConv2D(*conv2d_node_def) || conv2d_node_def->op() == "_ITEXConv2D") || + !HaveSameDataType(node_def, conv2d_node_def) || + HasControlFaninOrFanout(*conv2d_node_view) || + !HasAtMostOneFanoutAtPort0(*conv2d_node_view) || + IsInPreserveSet(ctx, conv2d_node_def)) + return false; + + // We successfully found a Conv2D+FusedBatchNorm pattern. + matched->contraction = conv2d_node_view->node_index(); + matched->fused_batch_norm = node_index; + if (!TryGetNodeAttr(*node_def, "epsilon", &matched->epsilon)) return false; + + return true; +} + +bool FindConv2DWithBatchNormAndActivation( + const RemapperContext& ctx, int node_index, + ContractionWithBatchNormAndActivation* matched) { + const auto* node_view = ctx.graph_view.GetNode(node_index); + if (HasControlFaninOrFanout(*node_view)) return false; + + // Root of the pattern must be an activation node. + const auto* node_def = node_view->node(); + if (!IsSupportedActivation(*node_def)) return false; + + // And input to the activation node must match Conv2DWithBatchNorm pattern. + if (node_view->NumRegularFanins() < 1) return false; + + const auto& regular_fanin_0 = node_view->GetRegularFanin(0); + const auto* batch_norm_node_view = regular_fanin_0.node_view(); + + ContractionWithBatchNorm base; + if (!FindConv2DWithBatchNorm(ctx, batch_norm_node_view->node_index(), &base)) + return false; + + const auto* fused_batch_norm_node_view = + ctx.graph_view.GetNode(base.fused_batch_norm); + const auto* fused_batch_norm_node_def = fused_batch_norm_node_view->node(); + if (!HasAtMostOneFanoutAtPort0(*fused_batch_norm_node_view) || + !HaveSameDataType(node_def, fused_batch_norm_node_def) || + IsInPreserveSet(ctx, fused_batch_norm_node_def)) + return false; + + // We successfully found a Conv2D+FusedBatchNorm+Activation pattern. 
+ matched->contraction = base.contraction; + matched->fused_batch_norm = base.fused_batch_norm; + matched->activation = node_index; + matched->epsilon = base.epsilon; + + return true; +} + bool FindContractionWithBiasAndActivationInPort( const RemapperContext& ctx, const utils::MutableNodeView& add_node_view, const NodeDef& add_node_def, int port_id) { @@ -3754,6 +3868,8 @@ Status AddFusedContractionNode(RemapperContext* ctx, if (IsConv2D(contraction)) { fused_op.set_op(kFusedConv2D); + auto* attr = fused_op.mutable_attr(); + SetAttrValue(0, &(*attr)["num_bn_args"]); } else if (IsDepthwiseConv2dNative(contraction)) { fused_op.set_op(kFusedDepthwiseConv2dNative); } else if (IsConv3D(contraction)) { @@ -4406,6 +4522,8 @@ Status AddFusedContractionNode( if (IsConv2D(contraction)) { fused_op.set_op(kFusedConv2D); + auto* attr = fused_op.mutable_attr(); + SetAttrValue(0, &(*attr)["num_bn_args"]); } else if (IsDepthwiseConv2dNative(contraction)) { fused_op.set_op(kFusedDepthwiseConv2dNative); } else if (IsConv3D(contraction)) { @@ -4559,6 +4677,93 @@ Status AddFusedContractionNode( return Status::OK(); } +Status AddFusedConv2DNode(RemapperContext* ctx, + const ContractionWithBatchNorm& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { + const GraphDef* graph = ctx->graph_view.graph(); + const NodeDef& contraction = graph->node(matched.contraction); + ITEX_DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now"; + const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm); + ITEX_VLOG(2) << "Fuse Conv2D with BatchNorm: batch_norm=" + << fused_batch_norm.name() << " conv2d=" << contraction.name(); + + NodeDef fused_conv2d; + fused_conv2d.set_name(fused_batch_norm.name()); + fused_conv2d.set_op(kFusedConv2D); + fused_conv2d.set_device(contraction.device()); + fused_conv2d.add_input(contraction.input(0)); // 0: input + fused_conv2d.add_input(contraction.input(1)); // 1: filter + fused_conv2d.add_input(fused_batch_norm.input(1)); // 2: scale + fused_conv2d.add_input(fused_batch_norm.input(2)); // 3: offset + fused_conv2d.add_input(fused_batch_norm.input(3)); // 4: mean + fused_conv2d.add_input(fused_batch_norm.input(4)); // 5: variance + + CopyAllAttrs(contraction, &fused_conv2d); + SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm"}, 0); + auto* attr = fused_conv2d.mutable_attr(); + SetAttrValue(matched.epsilon, &(*attr)["epsilon"]); + SetAttrValue(4, &(*attr)["num_bn_args"]); + + utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + Status status; + mutation->AddNode(std::move(fused_conv2d), &status); + TF_ABORT_IF_ERROR(status); + TF_ABORT_IF_ERROR(mutation->Apply()); + + (*invalidated_nodes)[matched.fused_batch_norm] = true; + (*nodes_to_delete)[matched.contraction] = true; + + return Status::OK(); +} + +Status AddFusedConv2DNode(RemapperContext* ctx, + const ContractionWithBatchNormAndActivation& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { + const GraphDef* graph = ctx->graph_view.graph(); + const NodeDef& contraction = graph->node(matched.contraction); + + ITEX_DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now"; + + const NodeDef& activation = graph->node(matched.activation); + const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm); + ITEX_VLOG(2) << "Fuse Conv2D with BatchNorm and " << activation.op() + << ": activation=" << activation.name() + << " batch_norm=" << fused_batch_norm.name() + << " conv2d=" << contraction.name(); + + NodeDef fused_conv2d; + 
fused_conv2d.set_name(activation.name()); + fused_conv2d.set_op(kFusedConv2D); + fused_conv2d.set_device(contraction.device()); + fused_conv2d.add_input(contraction.input(0)); // 0: input + fused_conv2d.add_input(contraction.input(1)); // 1: filter + fused_conv2d.add_input(fused_batch_norm.input(1)); // 2: scale + fused_conv2d.add_input(fused_batch_norm.input(2)); // 3: offset + fused_conv2d.add_input(fused_batch_norm.input(3)); // 4: mean + fused_conv2d.add_input(fused_batch_norm.input(4)); // 5: variance + + CopyAllAttrs(contraction, &fused_conv2d); + SetFusedOpAttributesWithActivation(&fused_conv2d, &activation, + {"FusedBatchNorm"}, 0); + auto* attr = fused_conv2d.mutable_attr(); + SetAttrValue(matched.epsilon, &(*attr)["epsilon"]); + SetAttrValue(4, &(*attr)["num_bn_args"]); + + utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + Status status; + mutation->AddNode(std::move(fused_conv2d), &status); + TF_ABORT_IF_ERROR(status); + TF_ABORT_IF_ERROR(mutation->Apply()); + + (*invalidated_nodes)[matched.activation] = true; + (*nodes_to_delete)[matched.contraction] = true; + (*nodes_to_delete)[matched.fused_batch_norm] = true; + + return Status::OK(); +} + // Contraction + Mul(scale). // TODO(itex): Try to combine this function with Conv + BiasAdd Status AddFusedContractionNode(RemapperContext* ctx, @@ -6755,7 +6960,31 @@ Status RunRemapper(OptimizerContext* opt_ctx, const GrapplerItem& item, &invalidated_nodes, &nodes_to_delete)); continue; } + // NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do + // it for MatMul as well, but in practice this pattern does not appear in + // real Tensorflow graphs. + + // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D; + ContractionWithBatchNormAndActivation + contract_with_batch_norm_and_activation; + if (!is_layout_opt && + FindConv2DWithBatchNormAndActivation( + ctx, i, &contract_with_batch_norm_and_activation)) { + TF_RETURN_IF_ERROR( + AddFusedConv2DNode(&ctx, contract_with_batch_norm_and_activation, + &invalidated_nodes, &nodes_to_delete)); + continue; + } + // Remap Conv2D+FusedBatchNorm into the _FusedConv2D; + ContractionWithBatchNorm contract_with_batch_norm; + if (!is_layout_opt && + FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) { + TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm, + &invalidated_nodes, + &nodes_to_delete)); + continue; + } // Remap FusedBatchNorm++ into the // _FusedBatchNormEx. 
     FusedBatchNormEx fused_batch_norm_ex;
diff --git a/itex/core/graph/utils/layout_utils.cc b/itex/core/graph/utils/layout_utils.cc
index 83e495364..e62c1f7ea 100644
--- a/itex/core/graph/utils/layout_utils.cc
+++ b/itex/core/graph/utils/layout_utils.cc
@@ -215,6 +215,13 @@ bool RewriteFusedConv(const utils::MutableNodeView& node_view) {
 }
 
 bool RewriteOneDnnFusedConv(const utils::MutableNodeView& node_view) {
+  const NodeDef& node_def = *(node_view.node());
+  std::vector<string> fused_ops;
+  ITEX_CHECK_OK(GetNodeAttr(node_def, "fused_ops", &fused_ops));
+  for (auto& post_op : fused_ops) {
+    if (post_op == "FusedBatchNorm") return false;
+  }
+
   return RewriteFusedConv(node_view) && RewriteOneDnnConv(node_view);
 }
 
diff --git a/itex/core/graph/utils/op_types.cc b/itex/core/graph/utils/op_types.cc
index 72017952d..3b07ae43e 100644
--- a/itex/core/graph/utils/op_types.cc
+++ b/itex/core/graph/utils/op_types.cc
@@ -418,6 +418,12 @@ bool IsInstanceNorm(const NodeDef& node) {
   return node.op() == "_ITEXInstanceNorm";
 }
 
+bool IsITEXFusedBatchNorm(const NodeDef& node) {
+  const auto& op = node.op();
+  return op == "_ITEXFusedBatchNorm" || op == "_ITEXFusedBatchNormV2" ||
+         op == "_ITEXFusedBatchNormV3";
+}
+
 bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }
 
 bool IsLeakyReluGrad(const NodeDef& node) {
diff --git a/itex/core/graph/utils/op_types.h b/itex/core/graph/utils/op_types.h
index 64b40869a..c9e45397e 100644
--- a/itex/core/graph/utils/op_types.h
+++ b/itex/core/graph/utils/op_types.h
@@ -121,6 +121,7 @@ bool IsImag(const NodeDef& node);
 bool IsImmutableConst(const NodeDef& node);
 bool IsInvGrad(const NodeDef& node);
 bool IsInstanceNorm(const NodeDef& node);
+bool IsITEXFusedBatchNorm(const NodeDef& node);
 bool IsLeakyRelu(const NodeDef& node);
 bool IsLeakyReluGrad(const NodeDef& node);
 bool IsLess(const NodeDef& node);
diff --git a/itex/core/kernels/common/conv_ops.h b/itex/core/kernels/common/conv_ops.h
index 97ada9a85..73691aa2e 100644
--- a/itex/core/kernels/common/conv_ops.h
+++ b/itex/core/kernels/common/conv_ops.h
@@ -43,6 +43,16 @@ using ConvFwdPd = dnnl::convolution_forward::primitive_desc;
 
 #define DNNL_SIZE_DTYPE int64_t
 
+namespace functor {
+template <typename Device, typename T>
+struct ComputeBNScale {
+  void operator()(const Device& d, typename TTypes<T>::Vec buffer,
+                  typename TTypes<T>::ConstVec var, T variance_epsilon) {
+    buffer.device(d) = (var + var.constant(variance_epsilon)).rsqrt().eval();
+  }
+};
+}  // namespace functor
+
 class OneDnnConvUtil {
  protected:
   OpKernelContext* context_;  // We don't own this.
@@ -740,6 +750,19 @@ class ConvOpBase : public OpKernel {
       bias_mem_.set_data_handle(bias_data);
     }
 
+    if (post_op_util_.HasBN()) {
+      const Tensor& bn_scale_tensor = context->input(kInputIndex_BN_Scale_);
+      const Tensor& bn_mean_tensor = context->input(kInputIndex_BN_Mean_);
+      const Tensor& bn_offset_tensor = context->input(kInputIndex_BN_Offset_);
+      const Tensor& bn_var_tensor = context->input(kInputIndex_BN_Variance_);
+      functor::ComputeBNScale<Device, float>()(
+          context->eigen_device<Device>(),
+          cached_bn_rsqrt_tensor_.tensor<float, 1>(),
+          bn_var_tensor.vec<float>(), bn_epsilon_);
+      post_op_util_.SetBNMemory(bn_scale_tensor, bn_mean_tensor,
+                                bn_offset_tensor, cached_bn_rsqrt_tensor_);
+    }
+
     // Reallocate scratchpad memory.
     OP_REQUIRES_OK(context,
                    context->allocate_temp(DataTypeToEnum::v(),
@@ -882,7 +905,6 @@ class ConvOpBase : public OpKernel {
     this->ExtendInt8PostOps(context);
     // Set post op attribution.
     dnnl::primitive_attr post_ops_attr;
-    post_op_util_.SetPostOpAttr(&post_ops_attr);
     post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
     if (std::is_same::value) {
       post_ops_attr.set_fpmath_mode(fp32_math_mode_);
@@ -893,6 +915,40 @@ class ConvOpBase : public OpKernel {
       post_ops_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 3);
     }
 
+    if (post_op_util_.HasBN()) {
+      // Since batchnorm does the following for each input x:
+      //   scale * (x - mean) / sqrt(\sigma + \epsilon) + offset
+      // BatchNorm can be decomposed into four post ops:
+      //   sub(mean), mul(rsqrt(\sigma + \epsilon)), mul(scale), add(offset)
+      const Tensor& bn_mean_tensor = context->input(kInputIndex_BN_Mean_);
+      // All inputs to FusedBatchNorm have the same 1D shape.
+      TensorShape fuse_bn_shape = bn_mean_tensor.shape();
+      OP_REQUIRES(context, fuse_bn_shape.dims() == 1,
+                  errors::InvalidArgument("FusedBatchNorm must be 1D, not: ",
+                                          fuse_bn_shape.DebugString()));
+
+      // Note: oneDNN expects {1, C, 1, 1} for a binary post-op, even for NHWC.
+      memory::dims fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1};
+      post_op_util_.BuildBNContext(&data_layout, &fuse_bn_dims,
+                                   &onednn_engine_);
+
+      const Tensor& bn_scale_tensor = context->input(kInputIndex_BN_Scale_);
+      const Tensor& bn_offset_tensor = context->input(kInputIndex_BN_Offset_);
+      const Tensor& bn_var_tensor = context->input(kInputIndex_BN_Variance_);
+      OP_REQUIRES_OK(context, context->allocate_temp(
+                                  DataTypeToEnum<float>::v(), fuse_bn_shape,
+                                  &cached_bn_rsqrt_tensor_));
+      functor::ComputeBNScale<Device, float>()(
+          context->eigen_device<Device>(),
+          cached_bn_rsqrt_tensor_.tensor<float, 1>(),
+          bn_var_tensor.vec<float>(), bn_epsilon_);
+      post_op_util_.SetBNMemory(bn_scale_tensor, bn_mean_tensor,
+                                bn_offset_tensor, cached_bn_rsqrt_tensor_);
+      post_op_util_.SetBNPostOpAttr(&post_ops_attr);
+    } else {
+      post_op_util_.SetPostOpAttr(&post_ops_attr);
+    }
+
     fwd_pd_ = ConvFwdPd(onednn_engine_, prop_kind::forward,
                         dnnl::algorithm::convolution_direct, src_md_opt,
@@ -1041,6 +1097,9 @@ class ConvOpBase : public OpKernel {
     fwd_primitives_args_.insert({DNNL_ARG_WEIGHTS, filter_mem_});
     fwd_primitives_args_.insert({DNNL_ARG_DST, dst_mem_opt_});
     fwd_primitives_args_.insert({DNNL_ARG_SCRATCHPAD, scratchpad_mem_});
+    if (post_op_util_.HasBN()) {
+      post_op_util_.AddBNPrimArgs(&fwd_primitives_args_);
+    }
     if (this->post_op_util_.HasOutputScales()) {
       float* output_scale_ptr = output_scale_cache_.GetCachedPtr(
           context, this->post_op_util_.GetOutputScale().data(),
@@ -1091,6 +1150,10 @@ class ConvOpBase : public OpKernel {
   std::vector explicit_paddings_;
   bool is_conv2d_;
   const int kSrcIndex_ = 0, kFilterIndex_ = 1, kBiasIndex_ = 2, kAddIndex_ = 3;
+  // Input indices for FusedBatchNorm
+  const int kInputIndex_BN_Scale_ = 2, kInputIndex_BN_Offset_ = 3;
+  const int kInputIndex_BN_Mean_ = 4, kInputIndex_BN_Variance_ = 5;
+
   const int kDstIndex_ = 0;
 
   PostOpUtil post_op_util_;
@@ -1139,6 +1202,8 @@ class ConvOpBase : public OpKernel {
   // This one for dnnl primitive weight when weight need reorder.
Tensor tmp_weight_; std::shared_ptr scratchpad_tensor_; + Tensor cached_bn_rsqrt_tensor_; + float bn_epsilon_; int64_t scratchpad_size_ = 0; bool enable_cache_ = false; @@ -1245,7 +1310,18 @@ class FusedConvOp : public ConvOpBasepost_op_util_.AddOps(fused_ops), errors::InvalidArgument("Found unsupported fusion in Fused Conv2D.")); - + if (this->post_op_util_.HasBN()) { + float epsilon; + int num_bn_args; + OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); + OP_REQUIRES_OK(context, context->GetAttr("num_bn_args", &num_bn_args)); + OP_REQUIRES( + context, num_bn_args == 4, + errors::InvalidArgument( + "Fused Conv2D with batchnorm must have 4 extra argument")); + this->post_op_util_.set_epsilon(epsilon); + this->bn_epsilon_ = epsilon; + } // Set alpha if get `LeakyRelu` after adding ops. if (this->post_op_util_.HasLeakyRelu()) { float alpha; diff --git a/itex/core/kernels/common/fused_batch_norm_op.h b/itex/core/kernels/common/fused_batch_norm_op.h index 6e1a1fb8b..631cdb684 100644 --- a/itex/core/kernels/common/fused_batch_norm_op.h +++ b/itex/core/kernels/common/fused_batch_norm_op.h @@ -523,7 +523,7 @@ class QuantizedFusedBatchNormOp context); if (out_dt_ == DT_QINT8) { // TODO(itex): here code may has some bugs. but we just follow Intel-TF - // implementation. It assumes the min/max of input & output of Batchnorm + // implementation. It assumes the min/max of input & output of BatchNorm // are the same. In reality, the assumption is not always true, but // currently, we don't receive model accuracy issue report. context->set_output(1, context->input(5)); diff --git a/itex/core/ops/nn_ops.cc b/itex/core/ops/nn_ops.cc index c980d160e..9c45b766f 100644 --- a/itex/core/ops/nn_ops.cc +++ b/itex/core/ops/nn_ops.cc @@ -322,9 +322,11 @@ void Register_ITEXFusedConv2DOp() { TF_OpDefinitionBuilderAddInput(op_builder, "input: T"); TF_OpDefinitionBuilderAddInput(op_builder, "filter: T"); TF_OpDefinitionBuilderAddInput(op_builder, "args: num_args * T"); + TF_OpDefinitionBuilderAddInput(op_builder, "bn_args: num_bn_args * float"); TF_OpDefinitionBuilderAddOutput(op_builder, "output: T"); TF_OpDefinitionBuilderAddAttr(op_builder, "T: {bfloat16, half, float}"); TF_OpDefinitionBuilderAddAttr(op_builder, "num_args: int >= 0"); + TF_OpDefinitionBuilderAddAttr(op_builder, "num_bn_args: int >= 0"); TF_OpDefinitionBuilderAddAttr(op_builder, "strides: list(int)"); TF_OpDefinitionBuilderAddAttr(op_builder, "is_filter_const: bool = false"); TF_OpDefinitionBuilderAddAttr(op_builder, diff --git a/itex/core/utils/onednn/onednn_post_op_util.cc b/itex/core/utils/onednn/onednn_post_op_util.cc index f08bdbe51..dbf8dff8d 100644 --- a/itex/core/utils/onednn/onednn_post_op_util.cc +++ b/itex/core/utils/onednn/onednn_post_op_util.cc @@ -17,8 +17,6 @@ limitations under the License. namespace itex { -using algorithm = dnnl::algorithm; -using kind = dnnl::primitive::kind; using memory = dnnl::memory; const std::vector& PostOpUtil::GetAllPostOpInfo() { @@ -106,6 +104,10 @@ bool PostOpUtil::AddOps(const std::vector& fused_ops) { // Simply record status of `BiasAdd` instead of putting it to table. if (name == "BiasAdd") { this->has_bias_ = true; + } else if (name == "FusedBatchNorm") { + // BatchNorm will also be fused in primitive construction. + // Just record it. + this->has_bn_ = true; } else if (name == "Quantized" || name == "Requantize" || name == "Dequantize") { // Handle Quantized kernel. 
@@ -125,6 +127,59 @@ bool PostOpUtil::AddOps(const std::vector<string>& fused_ops) {
   return true;
 }
 
+void PostOpUtil::BuildBNContext(memory::format_tag* data_layout,
+                                memory::dims* fuse_bn_dims,
+                                dnnl::engine* engine) {
+  auto reset_md = [&fuse_bn_dims,
+                   &data_layout](std::shared_ptr<memory::desc>& ptr) {
+    ptr.reset(
+        new memory::desc({*fuse_bn_dims}, OneDnnType<float>(), *data_layout));
+  };
+  reset_md(bn_context_.bn_scale_md);
+  reset_md(bn_context_.bn_mean_md);
+  reset_md(bn_context_.bn_rsqrt_md);
+  reset_md(bn_context_.bn_offset_md);
+
+  auto reset_mem = [&engine](std::shared_ptr<memory>& mem_ptr,
+                             std::shared_ptr<memory::desc>& md_ptr) {
+    mem_ptr.reset(new memory(*md_ptr, *engine, nullptr));
+  };
+  reset_mem(bn_context_.bn_scale_mem, bn_context_.bn_scale_md);
+  reset_mem(bn_context_.bn_mean_mem, bn_context_.bn_mean_md);
+  reset_mem(bn_context_.bn_offset_mem, bn_context_.bn_offset_md);
+  reset_mem(bn_context_.bn_rsqrt_mem, bn_context_.bn_rsqrt_md);
+}
+
+void PostOpUtil::SetBNMemory(const Tensor& bn_scale_tensor,
+                             const Tensor& bn_mean_tensor,
+                             const Tensor& bn_offset_tensor,
+                             const Tensor& bn_rsqrt_tensor) {
+  bn_context_.bn_scale_mem->set_data_handle(
+      GetTensorBuffer<float>(&bn_scale_tensor));
+  bn_context_.bn_mean_mem->set_data_handle(
+      GetTensorBuffer<float>(&bn_mean_tensor));
+  bn_context_.bn_rsqrt_mem->set_data_handle(
+      GetTensorBuffer<float>(&bn_rsqrt_tensor));
+  bn_context_.bn_offset_mem->set_data_handle(
+      GetTensorBuffer<float>(&bn_offset_tensor));
+}
+
+void PostOpUtil::AddBNPrimArgs(
+    std::unordered_map<int, memory>* fwd_primitives_args_) {
+  fwd_primitives_args_->insert(
+      {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
+       *bn_context_.bn_mean_mem});
+  fwd_primitives_args_->insert(
+      {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1,
+       *bn_context_.bn_rsqrt_mem});
+  fwd_primitives_args_->insert(
+      {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1,
+       *bn_context_.bn_scale_mem});
+  fwd_primitives_args_->insert(
+      {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1,
+       *bn_context_.bn_offset_mem});
+}
+
 void PostOpUtil::SetLeakyReluAlpha(float alpha) {
   ITEX_CHECK(this->has_leaky_relu_)
       << "PostOpUtil: can't find LeakyRelu when set alpha";
@@ -239,6 +294,33 @@ void PostOpUtil::SetPostOpAttr(dnnl::primitive_attr* attr,
   }
 }
 
+// Since batchnorm does the following for each input x:
+//   scale * (x - mean) / sqrt(\sigma + \epsilon) + offset
+// BatchNorm can be decomposed into the following post ops:
+//   1. A sub op `x - mean`
+//   2. A mul op `1 / sqrt(\sigma + \epsilon) * x`
+//   3. A mul op `scale * x`
+//   4. An add op `x + offset`
+// where x means the output of the previous step.
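+//
+// Illustrative example only (values match the remapper tests added below):
+// with scale = 0.2, offset = 0.3, mean = 0.1, variance = 4.0 and
+// epsilon = 1e-4, the chain evaluates to
+//   0.2 * (x - 0.1) * rsqrt(4.0001) + 0.3 ~= 0.1 * x + 0.29
+// for every element x of the convolution output.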
+void PostOpUtil::SetBNPostOpAttr(dnnl::primitive_attr* attr,
+                                 const std::vector<dnnl::memory::desc>& md_list) {
+  ITEX_DCHECK(attr);
+  dnnl::post_ops post_ops = dnnl::post_ops();
+  post_ops.append_binary(algorithm::binary_sub, *bn_context_.bn_mean_md);
+  post_ops.append_binary(algorithm::binary_mul, *bn_context_.bn_rsqrt_md);
+  post_ops.append_binary(algorithm::binary_mul, *bn_context_.bn_scale_md);
+  post_ops.append_binary(algorithm::binary_add, *bn_context_.bn_offset_md);
+  if (postop_scale_list_.size() != 0) {
+    SetPostOp(&post_ops, md_list);
+  }
+  attr->set_post_ops(post_ops);
+  if (has_output_scales_) {
+    if (output_scale_param_.scales.size()) {
+      attr->set_scales_mask(DNNL_ARG_WEIGHTS, output_scale_param_.mask);
+    }
+  }
+}
+
 const PostOpInfo* PostOpUtil::GetPostOpInfoByName(
     const absl::string_view op_name) {
   const std::vector<PostOpInfo>& info_vec = PostOpUtil::GetAllPostOpInfo();
diff --git a/itex/core/utils/onednn/onednn_post_op_util.h b/itex/core/utils/onednn/onednn_post_op_util.h
index a0caf6138..53a178c16 100644
--- a/itex/core/utils/onednn/onednn_post_op_util.h
+++ b/itex/core/utils/onednn/onednn_post_op_util.h
@@ -16,7 +16,9 @@ limitations under the License.
 #ifndef ITEX_CORE_UTILS_ONEDNN_ONEDNN_POST_OP_UTIL_H_
 #define ITEX_CORE_UTILS_ONEDNN_ONEDNN_POST_OP_UTIL_H_
 
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -25,6 +27,10 @@ limitations under the License.
 
 namespace itex {
 
+using kind = dnnl::primitive::kind;
+using algorithm = dnnl::algorithm;
+using memory = dnnl::memory;
+
 // Helper data struct to record necessary info for post op fusion.
 struct PostOpInfo {
   string name;
@@ -47,6 +53,22 @@ struct OutputScaleParam {
   std::vector<float> scales;
 };
 
+// FusedBNContext stores the FusedBatchNorm memory objects and descriptors
+// needed to construct the oneDNN binary post-ops.
+struct FusedBNContext {
+  // FusedBatchNorm related memory
+  std::shared_ptr<memory> bn_scale_mem;
+  std::shared_ptr<memory> bn_mean_mem;
+  std::shared_ptr<memory> bn_rsqrt_mem;
+  std::shared_ptr<memory> bn_offset_mem;
+
+  // FusedBatchNorm related memory desc
+  std::shared_ptr<memory::desc> bn_scale_md;
+  std::shared_ptr<memory::desc> bn_mean_md;
+  std::shared_ptr<memory::desc> bn_rsqrt_md;
+  std::shared_ptr<memory::desc> bn_offset_md;
+};
+
 class PostOpUtil {
  public:
   PostOpUtil() = default;
@@ -60,6 +82,10 @@ class PostOpUtil {
   // Set alpha for `LeakyRelu`.
   // Will report error if no `LeakyRelu` in post ops.
   void SetLeakyReluAlpha(float alpha);
+
+  // Set epsilon for `FusedBatchNorm`.
+  void set_epsilon(float epsilon) { bn_epsilon_ = epsilon; }
+
   // Set alpha/beta for `Linear`.
   // Will report error if no `Linear` in post ops.
   void SetLinearAlphaBeta(float alpha, float beta);
@@ -81,6 +107,21 @@ class PostOpUtil {
   void SetPostOpAttr(dnnl::primitive_attr* attr,
                      const std::vector<dnnl::memory::desc>& md_list = {});
 
+  // Set batchnorm and other post op attributes on `attr`.
+  void SetBNPostOpAttr(dnnl::primitive_attr* attr,
+                       const std::vector<dnnl::memory::desc>& md_list = {});
+
+  // Build the batchnorm context needed for post-op primitive construction.
+  void BuildBNContext(memory::format_tag* data_layout,
+                      memory::dims* fuse_bn_dims, dnnl::engine* engine);
+
+  // Point the batchnorm post-op oneDNN memory objects at the prepared buffers.
+  void SetBNMemory(const Tensor& bn_scale_tensor, const Tensor& bn_mean_tensor,
+                   const Tensor& bn_offset_tensor,
+                   const Tensor& bn_rsqrt_tensor);
+
+  void AddBNPrimArgs(std::unordered_map<int, memory>* fwd_primitives_args_);
+
   // Check the given elewise op is supported by oneDNN or not.
static bool IsSupportedActivation(const absl::string_view op_name); @@ -88,6 +129,7 @@ class PostOpUtil { inline bool HasActivation() { return has_activation_; } inline bool HasAdd() { return has_add_; } inline bool HasBias() { return has_bias_; } + inline bool HasBN() { return has_bn_; } inline bool HasBinary() { return binary_num_ != 0; } inline bool HasLeakyRelu() { return has_leaky_relu_; } inline bool HasLinear() { return has_linear_; } @@ -127,6 +169,8 @@ class PostOpUtil { // Note `BiasAdd` is a special case, it doesn't have post op info because // it will be fused in primitive directly. bool has_bias_ = false; + // Whether have batchnorm fusion + bool has_bn_ = false; // Use this flag to check whether need to set alpha for `LeakyRelu`. bool has_leaky_relu_ = false; // Use this flag to check whether has linear post op @@ -136,12 +180,17 @@ class PostOpUtil { bool has_output_scales_ = false; bool has_requantize_ = false; + // Stores fused batchnorm constructing message + FusedBNContext bn_context_; + // Helper var for multilpe Binary post op fusion. int binary_num_ = 0; // Helper vars for post op execution. float leaky_relu_alpha_ = NAN; float linear_alpha_ = NAN; float linear_beta_ = NAN; + + float bn_epsilon_ = 0.0001; }; } // namespace itex diff --git a/itex/core/utils/onednn/onednn_util.h b/itex/core/utils/onednn/onednn_util.h index ec8433b34..28dffcd2a 100644 --- a/itex/core/utils/onednn/onednn_util.h +++ b/itex/core/utils/onednn/onednn_util.h @@ -298,7 +298,7 @@ inline dnnl::memory::format_tag OneDnnTensorFormatToTag( } /// Map TensorFlow data format into oneDNN data format. This is used in TF -/// kernels which have `data_format` attributes, such as Conv/Batchnorm/... +/// kernels which have `data_format` attributes, such as Conv/BatchNorm/... /// `TensorFormat` is original TF tensor attr, it's always NCHW or NHWC no /// matter the rank is 4D or 5D. 
/// @@ -368,7 +368,7 @@ inline dnnl::memory::dims TFShapeToOneDnnDims(const TensorShape& shape) { /// /// Commonly used in below scenarios: /// 1) Create oneDNN primitive from TF tensor in kernel which has `data_format` -/// attr, such as Conv/Batchnorm/Pooling; +/// attr, such as Conv/BatchNorm/Pooling; /// 2) Reorder TF/oneDNN tensors to same oneDNN format in kernel which has /// multiply inputs, such as AddN/Concat; /// diff --git a/test/tensorflow/python/grappler/remapper_test.py b/test/tensorflow/python/grappler/remapper_test.py index 95555c5b1..2cd53fdd8 100644 --- a/test/tensorflow/python/grappler/remapper_test.py +++ b/test/tensorflow/python/grappler/remapper_test.py @@ -39,10 +39,23 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import nn_impl from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.util import _pywrap_utils - +from intel_extension_for_tensorflow.python.ops.load_ops_library import load_ops_library + +Activation_op_dict ={ + "Elu":nn_ops.elu, + "LeakyRelu": nn_ops.leaky_relu, + "Relu":nn_ops.relu, + "Relu6": nn.relu6, + "Sigmoid": nn.sigmoid, + "Tanh": nn.tanh, + "GeluApproximate":load_ops_library.gelu, + "GeluExact":load_ops_library.gelu, +} def _input(shape): """Generates an input of a given shape.""" @@ -59,9 +72,15 @@ def _bias(shape): """Generates a bias of a given shape.""" return constant_op.constant(0.1, shape=shape) -def _conv2d(x, w): +def _conv2d(x, w, + data_format="NHWC", + padding="SAME", + strides=[1, 1, 1, 1], + dilations=[1, 1, 1, 1]): """Returns a 2d convolution layer with full stride.""" - return nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') + if isinstance(padding, (list, tuple)): + padding = [(0, 0)] + padding + [(0, 0)] + return nn.conv2d(x, w, strides, padding, True, data_format, dilations) def _depthwise_conv2d(x, w): """Returns a 2d depthwise convolution layer with full stride.""" @@ -71,6 +90,16 @@ def _conv3d(x, w): """Returns a 3d convolution layer with full stride.""" return nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding='SAME') +def _batch_norm(x, mean, + var, offset, scale, + epsilon=0.0001, + data_format = "NHWC"): + """Returns a batchnorm layer.""" + return nn_impl.fused_batch_norm(x, scale, + offset, mean=mean, + variance=var, epsilon=epsilon, + data_format=data_format, + is_training=False) def _get_config(remapping_on=False): """Returns a CongfigProto with remapper optimizer on/off.""" @@ -94,6 +123,49 @@ def _maybe_skip(self, mode): if mode == 'mkl' and not test_util.IsMklEnabled(): self.skipTest('MKL is not enabled.') + def _verify_value(self, output, + found_pattern, + expect_fused_ops=[], + atol=1e-5, rtol=1e-5): + run_options = config_pb2.RunOptions(output_partition_graphs=True) + metadata = config_pb2.RunMetadata() + # Compute reference value. + config = _get_config(remapping_on=False) + with session.Session(config=config) as sess: + sess.run(variables.global_variables_initializer()) + output_val_ref = sess.run( + output, options=run_options, run_metadata=metadata) + # Compute output with fusion. + config = _get_config(remapping_on=True) + with session.Session(config=config) as sess: + sess.run(variables.global_variables_initializer()) + output_val = sess.run(output, options=run_options, run_metadata=metadata) + graph = metadata.partition_graphs[0] + + # Graph should contain fused op. 
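+    # (Scan the optimized partition graph for a node whose op name contains
+    # the expected substring, e.g. an op such as '_ITEXFusedConv2D' would
+    # match 'FusedConv2D'.)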
+ found_fused_op = False + for node in graph.node: + if found_pattern in node.op: + found_fused_op = True + found_fused_node = node + break + self.assertTrue( + found_fused_op, "can not find " + found_pattern) + if expect_fused_ops != 0: + # Only when specifying expect_fused_ops need to verify fused ops. + fused_ops = found_fused_node.attr['fused_ops'].list.s + self.assertTrue(len(fused_ops) == len(expect_fused_ops), + "not match number of fused ops") + matched_patterns = True + for i, val in enumerate(fused_ops): + if val.decode('utf-8') != expect_fused_ops[i]: + matched_patterns = False + break + self.assertTrue(matched_patterns, "not match fused ops " + + str(fused_ops) + str(expect_fused_ops)) + # Computed output value should be close to reference value. + self.assertAllClose(output_val_ref, output_val, atol, rtol) + @test_util.run_deprecated_v1 @test_util.disable_xla('This test does not pass with XLA') def test_matmul_biasadd_gelu_fusion(self): @@ -283,8 +355,6 @@ def test_depth_conv2d_biassemantic_fusion(self): @test_util.disable_xla('This test does not pass with XLA') def test_depth_conv2d_bias_and_add_activation_fusion(self): is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU() - run_options = config_pb2.RunOptions(output_partition_graphs=True) - metadata = config_pb2.RunMetadata() for precision in ('float32', 'bfloat16', 'float16'): if precision == 'bfloat16': @@ -319,102 +389,9 @@ def test_depth_conv2d_bias_and_add_activation_fusion(self): add_z = math_ops.add_n([z, add]) out = array_ops.identity(tf.math.sigmoid(add_z)) - # Compute reference value. - config = _get_config(remapping_on=False) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val_ref = sess.run( - out, options=run_options, run_metadata=metadata) - # Compute output with fusion. - config = _get_config(remapping_on=True) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val = sess.run(out, options=run_options, run_metadata=metadata) - graph = metadata.partition_graphs[0] - - # Graph should contain fused op. - found_fused_op = False - for node in graph.node: - if 'FusedDepthwiseConv2dNative' in node.op: - found_fused_op = True - found_fused_node = node - break - self.assertTrue( - found_fused_op, "this pattern has fusion issue!!") - fused_ops = found_fused_node.attr['fused_ops'].list.s - self.assertEqual( - len(fused_ops), 3, "the number of fused ops is not equal to 3") - existing_pattern = ( - fused_ops[0] == b'BiasAdd' - and fused_ops[1] == b'Add' - and fused_ops[2] == b'Sigmoid') - self.assertTrue(existing_pattern, "invalid fused ops") - - # Computed output value should be close to reference value. 
- tol = 1e-5 if precision == 'float32' else 1e-2 - self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) - - @test_util.run_deprecated_v1 - @test_util.disable_xla('This test does not pass with XLA') - def test_conv3d_biassemantic_fusion(self): - is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU() - run_options = config_pb2.RunOptions(output_partition_graphs=True) - metadata = config_pb2.RunMetadata() - - for precision in ('float32', 'bfloat16', 'float16'): - if precision == 'bfloat16': - if not is_bf16_supported: - self.skipTest('Device do not support bfloat16') - - if precision == 'float16': - if not tf.config.list_physical_devices("XPU"): - self.skipTest('CPU do not support float16') - - for bias_shape in ((6,), (1, 1, 6), (1, 1, 1, 6)): - - ops.reset_default_graph() - x = _input((5, 8, 8, 8, 1)) - f = _weight([3, 3, 3, 1, 6]) - bias = constant_op.constant( - [0.13, 0.12, -0.1, 0.23, 0.19, 0.6], shape=bias_shape) - - if precision == 'bfloat16': - x = math_ops.cast(x, dtypes.bfloat16) - f = math_ops.cast(f, dtypes.bfloat16) - bias = math_ops.cast(bias, dtypes.bfloat16) - - if precision == 'float16': - x = math_ops.cast(x, dtypes.float16) - f = math_ops.cast(f, dtypes.float16) - bias = math_ops.cast(bias, dtypes.float16) - - y = _conv3d(x, f) - z = math_ops.add(bias, y) - out = array_ops.identity(z) - - # Compute reference value. - config = _get_config(remapping_on=False) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val_ref = sess.run( - out, options=run_options, run_metadata=metadata) - # Compute output with fusion. - config = _get_config(remapping_on=True) - with session.Session(config=config) as sess: - sess.run(variables.global_variables_initializer()) - output_val = sess.run(out, options=run_options, run_metadata=metadata) - graph = metadata.partition_graphs[0] - - # Graph should contain fused op. - found_fused_op = False - for node in graph.node: - if 'FusedConv3D' in node.op: - found_fused_op = 1 - self.assertTrue(found_fused_op) - - # Computed output value should be close to reference value. - tol = 1e-5 if precision == 'float32' else 1e-2 - self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) + tol = 1e-5 if precision == 'float32' else 1e-2 + self._verify_value(out, 'FusedDepthwiseConv2dNative', ['BiasAdd', 'Add', 'Sigmoid'], tol, tol) + @test_util.run_deprecated_v1 @test_util.disable_xla('This test does not pass with XLA') @@ -634,5 +611,166 @@ def testSpaceToBatchNDConv2dBatchToSpaceND(self): # Computed output value should be close to reference value. 
self.assertAllCloseAccordingToType(output_val_ref, output_val) + def _build_fused_conv2d_batchnorm_activation(self, + input_sizes, + weight_sizes, + padding = "SAME", + strides = [1, 1, 1, 1], + dilations = [1, 1, 1, 1], + data_format='NHWC', + activation=None): + os.environ['ITEX_LAYOUT_OPT'] = '0' + is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU() + + if test.is_gpu_available(): + self.skipTest("Skip on GPU due to the pattern not supported") + + for precision in ('float32', 'bfloat16', 'float16'): + if precision == 'bfloat16': + if not is_bf16_supported: + self.skipTest('Device do not support bfloat16') + + if precision == 'float16': + if not tf.config.list_physical_devices("XPU"): + self.skipTest('CPU do not support float16') + + ops.reset_default_graph() + input = _input(input_sizes) + filter = _weight(weight_sizes) + + if precision == 'bfloat16': + input = math_ops.cast(input, dtypes.bfloat16) + filter = math_ops.cast(filter, dtypes.bfloat16) + + if precision == 'float16': + input = math_ops.cast(input, dtypes.float16) + filter = math_ops.cast(filter, dtypes.float16) + + conv_out = _conv2d(input, filter, + data_format=data_format, + padding=padding, + strides=strides, + dilations=dilations) + + bn_sizes = weight_sizes[3] + bn_scale = [0.2] * bn_sizes + bn_offset = [0.3] * bn_sizes + bn_mean = [0.1] * bn_sizes + bn_var = [4.0] * bn_sizes + + out, _, _ = _batch_norm(conv_out, mean = bn_mean, + var = bn_var, offset=bn_offset, + scale=bn_scale, data_format=data_format) + + if activation == 'GeluExact': + out = Activation_op_dict[activation](out, approximate=False) + elif activation is not None: + out = Activation_op_dict[activation](out) + out = array_ops.identity(out) + + tol = 1e-5 if precision == 'float32' else 1e-2 + if activation: + expect_fused_ops = ['FusedBatchNorm', activation] + else: + expect_fused_ops = ['FusedBatchNorm'] + self._verify_value(out, 'FusedConv2D', expect_fused_ops, tol, tol) + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 2, 3, 3], + weight_sizes=[1, 2, 3, 3]) + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_nchw_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 3, 3, 2], + weight_sizes=[1, 2, 3, 3], + data_format='NCHW') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_elu_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 2, 3, 3], + weight_sizes=[1, 2, 3, 3], + activation='Elu') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_elu_nchw_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 3, 2, 3], + weight_sizes=[1, 2, 3, 3], + data_format='NCHW', + activation='Elu') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_leakyrelu_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 3, 6, 1], + weight_sizes=[2, 2, 1, 1], + activation='LeakyRelu') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_relu_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 2, 2, 1], + weight_sizes=[2, 2, 1, 2], + 
activation='Relu') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_relu_nchw_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 1, 2, 2], + weight_sizes=[2, 2, 1, 2], + data_format='NCHW', + activation='Relu') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_relu6_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 2, 2, 1], + weight_sizes=[2, 2, 1, 2], + activation='Relu6') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_sigmoid_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 3, 2, 1], + weight_sizes=[1, 2, 1, 2], + activation='Sigmoid') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_tanh_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 3, 2, 1], + weight_sizes=[1, 2, 1, 2], + activation='Tanh') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_geluapproximate_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 2, 3, 3], + weight_sizes=[1, 2, 3, 3], + activation='GeluApproximate') + + @test_util.run_deprecated_v1 + @test_util.disable_xla('This test does not pass with XLA') + def test_conv2d_batchnorm_geluexact_fusion(self): + self._build_fused_conv2d_batchnorm_activation( + input_sizes=[1, 3, 6, 1], + weight_sizes=[2, 2, 1, 1], + activation='GeluExact') + if __name__ == '__main__': test.main()