diff --git a/intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp b/intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp
index 9bb02458e..93b592ef4 100644
--- a/intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp
+++ b/intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp
@@ -65,6 +65,14 @@ struct attr_t : public dnnl::primitive_attr {
     return attr;
   }
 
+  static attr_t fuse_binary(algorithm alg, memory::desc src_desc) {
+    attr_t attr;
+    post_ops po;
+    po.append_binary(alg, src_desc);
+    attr.set_post_ops(po);
+    return attr;
+  }
+
   static attr_t fuse_relu(
       float scale = 1.0,
       float alpha = 0.f,
@@ -162,6 +170,7 @@ struct attr_t : public dnnl::primitive_attr {
 
     algorithm alg = algorithm::undef;
     float scale = 1.0, alpha = 1.0, beta = 0.0;
+    memory::desc binary_src_desc;
 
     auto akind = po.kind(index);
     switch (akind) {
@@ -171,6 +180,9 @@ struct attr_t : public dnnl::primitive_attr {
       case kind::eltwise:
         po.get_params_eltwise(index, scale, alg, alpha, beta);
        break;
+      case kind::binary:
+        po.get_params_binary(index, alg, binary_src_desc);
+        break;
       default:
         error::wrap_c_api(dnnl_invalid_arguments, "could not get params");
         break;
@@ -243,6 +255,12 @@ struct attr_t : public dnnl::primitive_attr {
         utils::to_bytes(bytes, beta);
         bytes.append(1, '.');
         utils::to_bytes(bytes, alg);
+        break;
+      case kind::binary:
+        utils::to_bytes(bytes, akind);
+        bytes.append(1, '.');
+        utils::to_bytes(bytes, alg);
+        break;
       default:
         break;
     }
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp b/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp
index 2c8d4fe2d..848d93495 100644
--- a/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp
@@ -121,5 +121,30 @@ at::Tensor dil_matmul_div(
   }
 }
 
+at::Tensor dil_bmm_add(
+    const at::Tensor& input,
+    const at::Tensor& batch1,
+    const at::Tensor& batch2,
+    const c10::Scalar& alpha) {
+#if defined(IPEX_PROFILE_OP)
+  RECORD_FUNCTION("dil_bmm_add", std::vector<c10::IValue>({}));
+#endif
+  auto batch1_dim = batch1.dim();
+  auto batch2_dim = batch2.dim();
+  if (batch1_dim == batch2_dim && batch1_dim >= 3) {
+    auto _input = input.is_contiguous() ? input : input.contiguous();
+    ideep::tensor onednn_input = itensor_view_from_dense(_input);
+
+    // the binary_add post-op accumulates the unscaled input, i.e. alpha == 1
+    auto op_attr = ideep::attr_t::fuse_binary(
+        dnnl::algorithm::binary_add, onednn_input.get_desc());
+    return bmm_impl(
+        batch1, batch2, at::Tensor(), op_attr, {onednn_input}, 1.0f);
+  } else {
+    // ranks differ or inputs are not batched: fall back to the aten kernel
+    return at::baddbmm(input, batch1, batch2);
+  }
+}
+
 } // namespace cpu
 } // namespace torch_ipex
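Note on dil_bmm_add above: the oneDNN binary_add post-op accumulates the extra input without scaling, so the fused kernel computes bmm(batch1, batch2) + input, which is exactly at::baddbmm(input, batch1, batch2) with the default beta = alpha = 1 (the same reference the new test uses). A minimal plain-PyTorch sketch of that identity, with shapes borrowed from the test below:

import torch

M = torch.randn(10, 3, 5)       # the tensor the binary_add post-op accumulates
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)

# Semantics of the fused kernel: matmul result plus the unscaled input.
fused_semantics = torch.bmm(batch1, batch2) + M
# The aten fallback taken when the ranks of batch1/batch2 differ.
reference = torch.baddbmm(M, batch1, batch2)
assert torch.allclose(fused_semantics, reference, atol=1e-6)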
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.h b/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.h
index bc74227f1..1f9a483b4 100644
--- a/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.h
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.h
@@ -15,6 +15,7 @@ namespace jit {
 // So we fake some op namespaces to workaround that.
 namespace ipex {
 static auto matmul_div = Symbol::fromQualString("ipex::matmul_div");
+static auto bmm_add = Symbol::fromQualString("ipex::bmm_add");
 } // namespace ipex
 
 } // namespace jit
@@ -36,5 +37,11 @@ at::Tensor dil_matmul_div(
     at::Tensor out_opt,
     const c10::Scalar& div_input);
 
+at::Tensor dil_bmm_add(
+    const at::Tensor& input,
+    const at::Tensor& batch1,
+    const at::Tensor& batch2,
+    const c10::Scalar& alpha);
+
 } // namespace cpu
 } // namespace torch_ipex
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp b/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp
index 71787de0c..bad1e6704 100644
--- a/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp
@@ -432,6 +432,42 @@ void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph) {
   }
 }
 
+void fuseBmmAdd(std::shared_ptr<Graph>& graph) {
+  std::array<std::string, 2> add_operators = {"add", "add_"};
+
+  auto bmm_add_rstring_v1 = R"(
+    graph(%input, %batch1, %batch2, %alpha):
+        %x = aten::bmm(%batch1, %batch2)
+        %res = aten::add(%x, %input, %alpha)
+        return (%res))";
+  std::string bmm_add_fused = R"(
+    graph(%input, %batch1, %batch2, %alpha):
+        %res = ipex::bmm_add(%input, %batch1, %batch2, %alpha)
+        return (%res))";
+  // filter out the unsupported cases
+  auto fusion_filter = [](const Match& match,
+                          const std::unordered_map<std::string, Value*>& vmap) {
+    Node* node = match.anchor;
+    const auto& match_vmap = match.values_map;
+
+    auto batch1 = node->input(1)->type()->cast<TensorType>();
+    auto batch2 = node->input(2)->type()->cast<TensorType>();
+    if (!batch1 || !batch2 || batch1->dim() != batch2->dim()) {
+      return false;
+    }
+
+    if (batch1->dim().value_or(0) < 3) {
+      return false;
+    }
+
+    return true;
+  };
+
+  SubgraphRewriter rewriter_add_v1;
+  rewriter_add_v1.RegisterRewritePattern(bmm_add_rstring_v1, bmm_add_fused);
+  rewriter_add_v1.runOnGraph(graph, fusion_filter);
+}
+
 void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
   std::string aten_concat_bn_relu = R"(
     graph(%input : Tensor[], %dim:int, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled):
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h b/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h
index 7c7761d2b..e345612ce 100644
--- a/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h
@@ -27,6 +27,8 @@ void FuseShuffle(std::shared_ptr<Graph>& graph);
 void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph);
 void FuseLinearSwishCustomized(std::shared_ptr<Graph>& graph);
 void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);
+void fuseBmmAdd(std::shared_ptr<Graph>& graph);
+
 void replaceOpsWithAtenInplaceOps(std::shared_ptr<Graph>& graph);
 void replaceAtenOpsWithIpexInplaceOps(std::shared_ptr<Graph>& graph);
 void replaceAtenSoftmaxWithIpexSoftmax(std::shared_ptr<Graph>& graph);
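Note on the rewrite above: tracing a bare bmm-plus-add module reproduces the aten::bmm -> aten::add chain that bmm_add_rstring_v1 matches; the SubgraphRewriter then collapses those two nodes into a single ipex::bmm_add call, subject to fusion_filter. A sketch for inspecting the unfused pattern (plain PyTorch, no IPEX needed):

import torch
import torch.nn as nn

class BmmAdd(nn.Module):
    def forward(self, input, batch1, batch2):
        return torch.add(torch.bmm(batch1, batch2), input)

traced = torch.jit.trace(
    BmmAdd(),
    (torch.randn(10, 3, 5), torch.randn(10, 3, 4), torch.randn(10, 4, 5)))
# The printed graph contains the two-node chain the pattern matches:
#   %x = aten::bmm(%batch1, %batch2)
#   %res = aten::add(%x, %input, %alpha)
print(traced.graph)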
diff --git a/intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp b/intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp
index c858971b6..8683a6283 100644
--- a/intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp
+++ b/intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp
@@ -550,6 +550,23 @@ RegisterOperators op({
         },
         aliasAnalysisFromSchema()),
 
+    Operator(
+        "ipex::bmm_add(Tensor input, Tensor batch1, Tensor batch2, Scalar alpha) -> "
+        "Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = dil_bmm_add(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toTensor(),
+                (std::move(peek(stack, 2, 4))).toTensor(),
+                (std::move(peek(stack, 3, 4))).toScalar());
+            drop(stack, 4);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex::mha_scores_calc(Tensor q, Tensor k, Tensor rel_qk, Scalar "
         "alpha, "
diff --git a/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp b/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
index f17885784..4bbebec27 100644
--- a/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
+++ b/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
@@ -355,6 +355,9 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
   // Multi-Head-Attention
   graph_rewrite::FuseMHAScoreCalc(graph);
 
+  // Fuse bmm + add into a single ipex::bmm_add
+  graph_rewrite::fuseBmmAdd(graph);
+
   // Replace _convolution with conv2d or conv3d
   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
 
diff --git a/tests/cpu/test_jit.py b/tests/cpu/test_jit.py
index 45dc9cbe8..fa75f4a9b 100644
--- a/tests/cpu/test_jit.py
+++ b/tests/cpu/test_jit.py
@@ -621,6 +621,15 @@ def forward(self, x):
         else:
             return mm_res.div(torch.ones(mm_res_shape,dtype=x.dtype)+1)
 
+class BmmAdd(nn.Module):
+    def __init__(self):
+        super(BmmAdd, self).__init__()
+
+    def forward(self, input, batch1, batch2):
+        bmm_res = torch.bmm(batch1, batch2)
+        res = torch.add(bmm_res, input)
+        return res
+
 class MHAScoresCalculation(nn.Module):
     def __init__(self, dim_per_head, softmax_dim=-1):
         super(MHAScoresCalculation, self).__init__()
@@ -2476,6 +2485,17 @@ def test_matmul_div(self):
             kind_not_in_graph=None,
             prec=5e-3)
 
+    def test_bmm_add(self):
+        M = torch.randn(10, 3, 5)
+        batch1 = torch.randn(10, 3, 4)
+        batch2 = torch.randn(10, 4, 5)
+        mod = BmmAdd()
+        traced_mod = torch.jit.trace(mod, (M, batch1, batch2))
+        fused_mod = traced_mod.graph_for(M, batch1, batch2)
+        out = traced_mod(M, batch1, batch2)
+        expected = torch.baddbmm(M, batch1, batch2)
+        self.assertTrue(torch.allclose(out, expected))
+
     def test_ipex_softmax(self):
         self._test_output(
             AtenSoftmaxRepalce(),
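An end-to-end sketch of exercising the new pass, assuming an IPEX build that contains this patch (importing intel_extension_for_pytorch is what registers the ipex::bmm_add operator and the fusion pass; whether the fused node shows up in graph_for depends on the build and on profiling having run):

import torch
import torch.nn as nn
import intel_extension_for_pytorch  # noqa: F401, registers ipex::* ops and passes

class BmmAdd(nn.Module):
    def forward(self, input, batch1, batch2):
        return torch.add(torch.bmm(batch1, batch2), input)

M = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
traced = torch.jit.trace(BmmAdd(), (M, batch1, batch2))
graph = traced.graph_for(M, batch1, batch2)  # graph after profiling and fusion
print(any(n.kind() == "ipex::bmm_add" for n in graph.nodes()))
# Numerics must match baddbmm whether or not the fusion fired.
assert torch.allclose(traced(M, batch1, batch2), torch.baddbmm(M, batch1, batch2))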