diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index f667c9809df04..86395b0465d03 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -489,83 +489,6 @@ class MulMKLDNNKernel : public framework::OpKernel { } }; -template -class MulGradMKLDNNKernel : public MulMKLDNNKernel { - public: - void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } - - private: - void RunKernel(const ExecutionContext &ctx) const { - const auto &dev_ctx = ctx.template device_context(); - const auto &onednn_engine = dev_ctx.GetEngine(); - - const auto *x = ctx.Input("X"); - const auto *y = ctx.Input("Y"); - const auto *dout = - ctx.Input(framework::GradVarName("Out")); - - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *dy = ctx.Output(framework::GradVarName("Y")); - - int x_num_col_dims = ctx.Attr("x_num_col_dims"); - int y_num_col_dims = ctx.Attr("y_num_col_dims"); - - const Tensor x_matrix = x->dims().size() > 2 - ? framework::ReshapeToMatrix(*x, x_num_col_dims) - : static_cast(*x); - const Tensor y_matrix = y->dims().size() > 2 - ? framework::ReshapeToMatrix(*y, y_num_col_dims) - : static_cast(*y); - - Tensor dout_matrix = *dout; - dout_matrix.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0], - phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]}); - - // adding mb dim because MatMulV2 handler needs it - std::vector x_dims(3, 1); - std::vector y_dims(3, 1); - std::vector dout_dims(3, 1); - - x_dims[1] = x_matrix.dims()[0]; - x_dims[2] = x_matrix.dims()[1]; - - y_dims[1] = y_matrix.dims()[0]; - y_dims[2] = y_matrix.dims()[1]; - - dout_dims[1] = dout_matrix.dims()[0]; - dout_dims[2] = dout_matrix.dims()[1]; - - if (dx != nullptr) { - dx->set_lod(x->lod()); - this->ExecuteMatMul(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - &dout_matrix, - dout_dims, - false, - &y_matrix, - y_dims, - true, - static_cast(dx)); - } - if (dy != nullptr) { - dy->set_lod(y->lod()); - this->ExecuteMatMul(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - &x_matrix, - x_dims, - true, - &dout_matrix, - dout_dims, - false, - static_cast(dy)); - } - } -}; - } // namespace operators } // namespace paddle @@ -578,9 +501,3 @@ REGISTER_OP_KERNEL(mul, ops::MulMKLDNNINT8Kernel, ops::MulMKLDNNKernel, ops::MulMKLDNNKernel); - -REGISTER_OP_KERNEL(mul_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::MulGradMKLDNNKernel, - ops::MulGradMKLDNNKernel); diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index bd3d3f30f7a44..bc88fef443df2 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -1912,6 +1912,47 @@ class MatmulOneDNNHandler } }; +template +static void ExecuteMul(const OneDNNContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const std::vector& x_dims, + const std::vector& y_dims, + bool trans_x, + bool trans_y, + DenseTensor* out) { + static const std::vector vec_placeholder; + MatmulOneDNNHandler handler(dev_ctx, + x_dims, + y_dims, + trans_x, + trans_y, + vec_placeholder, + vec_placeholder, + false); + + const auto src_memory_p = handler.AcquireSrcMemory(&x); + const auto weights_memory_p = handler.AcquireWeightsMemory(&y); + const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out); + + auto matmul_p = handler.AcquireForwardPrimitive(); + + std::unordered_map matmul_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto& astream = OneDNNContext::tls().get_stream(); + matmul_p->execute(astream, matmul_args); + astream.wait(); + + // This kernel is flattening dims so then we need to unflattened version + // that should be set in out reshape require plain layout, but + // MatmulV2MKLDNNHanlder enforces one so it should work + out->set_mem_desc( + dst_memory_p->get_desc().reshape(vectorize(out->dims()))); +} + template void ExecuteMatmul(const OneDNNContext& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/onednn/matmul_grad_kernel.cc b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc index 47807f156b18f..ceb752f6d41be 100644 --- a/paddle/phi/kernels/onednn/matmul_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc @@ -153,6 +153,49 @@ void MatmulGradKernel(const Context &dev_ctx, dy->Resize(y.dims()); } +template +void MatmulWithFlattenGradKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out_grad, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor *x_grad, + DenseTensor *y_grad) { + const DenseTensor reshaped_y = + paddle::framework::ReshapeToMatrix(y, y_num_col_dims); + const DenseTensor reshaped_x = + paddle::framework::ReshapeToMatrix(x, x_num_col_dims); + const DenseTensor x_matrix = x.dims().size() > 2 ? reshaped_x : x; + const DenseTensor y_matrix = y.dims().size() > 2 ? reshaped_y : y; + + DenseTensor dout_matrix = out_grad; + dout_matrix.Resize({flatten_to_2d(x.dims(), x_num_col_dims)[0], + flatten_to_2d(y.dims(), y_num_col_dims)[1]}); + + // adding mb dim because MatMulV2 handler needs it + std::vector x_dims(3, 1); + std::vector y_dims(3, 1); + std::vector dout_dims(3, 1); + x_dims[1] = x_matrix.dims()[0]; + x_dims[2] = x_matrix.dims()[1]; + y_dims[1] = y_matrix.dims()[0]; + y_dims[2] = y_matrix.dims()[1]; + dout_dims[1] = dout_matrix.dims()[0]; + dout_dims[2] = dout_matrix.dims()[1]; + + if (x_grad != nullptr) { + x_grad->set_lod(x.lod()); + funcs::ExecuteMul( + dev_ctx, dout_matrix, y_matrix, dout_dims, y_dims, false, true, x_grad); + } + if (y_grad != nullptr) { + y_grad->set_lod(y.lod()); + funcs::ExecuteMul( + dev_ctx, x_matrix, dout_matrix, x_dims, dout_dims, true, false, y_grad); + } +} + } // namespace phi PD_REGISTER_KERNEL(matmul_grad, @@ -161,3 +204,10 @@ PD_REGISTER_KERNEL(matmul_grad, phi::MatmulGradKernel, float, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(matmul_with_flatten_grad, + OneDNN, + ONEDNN, + phi::MatmulWithFlattenGradKernel, + float, + phi::dtype::bfloat16) {}