diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index 947defcea4a61..f962c1332093a 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -79,6 +79,9 @@ function(kernel_library TARGET) endif() list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h) + list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h) + endif() list(APPEND all_srcs ${common_srcs}) list(APPEND all_srcs ${cpu_srcs}) list(APPEND all_srcs ${gpu_srcs}) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2d2e198ef40ec..f208fa7b8a4aa 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1878,16 +1878,32 @@ void OperatorWithKernel::BuildPtenKernelContext( // Otherwise,we will create new storage. for (size_t offset = 0; offset < outs_vector.size(); ++offset) { if (current_vector_size > start_idx + offset) { - experimental::ReMakePtenDenseTensorFromVar( - outs_vector[offset], out_def, + auto* buffer_tensor = pt_kernel_context_->MutableOutputAt(start_idx + - offset)); + offset); + if (buffer_tensor) { + experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset], + out_def, buffer_tensor); + } } else { pt_kernel_context_->EmplaceBackOutputWithoutSetRange( experimental::MakePtenTensorBaseFromVar(outs_vector[offset], out_def)); } } + + // Deal with the case that some outputs are NULL when run the kernel. + // For example : the outputs of matmul_grad are dx and dy, + // sometimes dx or dy may be NULL. + if (outs_vector.empty()) { + if (current_vector_size > start_idx) { + pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr}); + } else { + pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr}); + } + end_idx = start_idx + 1; + } + pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx), i); } @@ -2000,7 +2016,9 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const { range_pair.first, range_pair.second); for (size_t j = 0; j < pten_outs.size(); ++j) { - experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]); + if (pten_outs[j]) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]); + } } } } diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index b8aedcce3e3fa..b2fccbe43ab7a 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -99,7 +99,7 @@ KernelSignatureMap& KernelSignatureMap::Instance() { const auto& op_type = pair.first; const auto* op_proto = pair.second.proto_; if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) && - op_proto != nullptr) { + op_proto) { KernelArgsNameMakerByOpProto maker(op_proto); VLOG(10) << "Register kernel signature for " << op_type; auto success = kernel_signature_map_->map_ diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index c5623a8f4f243..f27faaceceed7 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext( for (size_t i = 0; i < output_names.size(); ++i) { auto& out_def = output_defs.at(i); - auto& outs_vector = outs.at(output_names[i]); size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); - size_t end_idx = start_idx + outs_vector.size(); auto current_vector_size = kernel_ctx->OutputsSize(); + + auto iter = outs.find(output_names[i]); + if (iter == outs.end()) { + if (current_vector_size > start_idx) { + kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr}); + } else { + kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); + } + kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1), + i); + continue; + } + + auto& outs_vector = iter->second; + size_t end_idx = start_idx + outs_vector.size(); + // If the memory needed is less than the current memory allocated, we will // reuse the current memory by using ReMakePtenDenseTensorFromVar. // Otherwise,we will create new storage. for (size_t offset = 0; offset < outs_vector.size(); ++offset) { if (current_vector_size > start_idx + offset) { - experimental::ReMakePtenDenseTensorFromVar( - outs_vector[offset]->MutableVar(), out_def, - kernel_ctx->MutableOutputAt(start_idx + offset)); + auto* buffer_tensor = + kernel_ctx->MutableOutputAt(start_idx + offset); + if (buffer_tensor) { + experimental::ReMakePtenDenseTensorFromVar( + outs_vector[offset]->MutableVar(), out_def, buffer_tensor); + } else { + kernel_ctx->SetOutputWithoutSetRange( + start_idx + offset, + experimental::MakePtenTensorBaseFromVar( + outs_vector[offset]->MutableVar(), out_def)); + } } else { kernel_ctx->EmplaceBackOutputWithoutSetRange( experimental::MakePtenTensorBaseFromVar( @@ -465,15 +487,18 @@ static void WriteBackToOutputs( auto& output_names = std::get<2>(pt_kernel_signature.args); for (size_t i = 0; i < output_names.size(); ++i) { - auto& outs_vector = outs.at(output_names[i]); + auto iter = outs.find(output_names[i]); + if (iter != outs.end()) { + auto& outs_vector = iter->second; - auto& range_pair = kernel_ctx->OutputRangeAt(i); - auto pten_outs = kernel_ctx->MutableOutputBetween( - range_pair.first, range_pair.second); + auto& range_pair = kernel_ctx->OutputRangeAt(i); + auto pten_outs = kernel_ctx->MutableOutputBetween( + range_pair.first, range_pair.second); - for (size_t j = 0; j < pten_outs.size(); ++j) { - experimental::MakeVariableFromPtenTensor(pten_outs[j], - outs_vector[j]->MutableVar()); + for (size_t j = 0; j < pten_outs.size(); ++j) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], + outs_vector[j]->MutableVar()); + } } } } @@ -530,6 +555,7 @@ static void PreparedOpRunImpl( template static void PreparedOpRunPtImpl( const framework::OperatorBase& op, + const framework::OpKernelType& kernel_type, const framework::KernelSignature& pt_kernel_signature, const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context, platform::DeviceContext* dev_ctx, const NameVarMap& ins, @@ -560,7 +586,9 @@ static void PreparedOpRunPtImpl( pt_kernel_context->ClearData(); // TODO(chenweihang): add debug flags later - // TODO(chenweihang): deal with complex cases later + if (framework::IsComplexType(kernel_type.data_type_)) { + HandleComplexGradToRealGrad(outs); + } } void PreparedOp::Run(const NameVarMap& ins, @@ -568,9 +596,9 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pten_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, - pt_kernel_context_, dev_ctx_, ins, outs, attrs, - default_attrs); + PreparedOpRunPtImpl(op_, kernel_type_, pt_kernel_signature_, + pt_kernel_, pt_kernel_context_, dev_ctx_, ins, + outs, attrs, default_attrs); } else { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, default_attrs); @@ -582,9 +610,9 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pten_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, - pt_kernel_context_, dev_ctx_, ins, - outs, attrs, default_attrs); + PreparedOpRunPtImpl( + op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_, + dev_ctx_, ins, outs, attrs, default_attrs); } else { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, default_attrs); diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h index 1012e9383f607..381f4cb66b3cd 100644 --- a/paddle/fluid/operators/conj_op.h +++ b/paddle/fluid/operators/conj_op.h @@ -39,7 +39,7 @@ class ConjKernel : public framework::OpKernel { auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); // call new kernel - pten::Conj(dev_ctx, *pt_x.get(), pt_out.get()); + pten::ConjKernel(dev_ctx, *pt_x.get(), pt_out.get()); } }; diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index 31acd9718115c..e1463c8ccb58e 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel { ctx, framework::GradVarName("Out")), ctx.GetPlace()); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {}, + {framework::GradVarName("X"), framework::GradVarName("Y")}); + } }; template diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index f6877c57a5c18..02ba57ef8d495 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -22,217 +22,14 @@ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/linalg.h" +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/dot_kernel.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -struct P { - void operator()(T a, R b); -}; - -template -struct DotGradFunction { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx); -}; - -template -struct DotGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); - - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), - tensor_dx->data()); - for_range(functor); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - - dx.device(dev) = dx * dout.broadcast(size); - } - - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), - tensor_dy->data()); - for_range(functor); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - - dy.device(dev) = dy * dout.broadcast(size); - } - } else { - auto dout = framework::EigenMatrix::From(*tensor_dout); - - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = framework::EigenMatrix::From(*tensor_y); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), - tensor_dx->data()); - for_range(functor); - auto dx = framework::EigenMatrix::From(*tensor_dx); - - dx.device(dev) = dx * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = framework::EigenMatrix::From(*tensor_x); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - - paddle::platform::ForRange for_range(dev_raw, - tensor_x->numel()); - math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), - tensor_dy->data()); - for_range(functor); - - auto dy = framework::EigenMatrix::From(*tensor_dy); - - dy.device(dev) = dy * dout.broadcast(size); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_y = tensor_y->data(); - const framework::DDim& dim = tensor_x->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_x = tensor_x->data(); - const framework::DDim& dim = tensor_y->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; - } - } -#endif - } -}; - -template -struct DotGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); - - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - dy.device(dev) = x * dout.broadcast(size); - } - } else { - auto dout = framework::EigenMatrix::From(*tensor_dout); - - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = framework::EigenMatrix::From(*tensor_y); - auto dx = framework::EigenMatrix::From(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = framework::EigenMatrix::From(*tensor_x); - auto dy = framework::EigenMatrix::From(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - dy.device(dev) = x * dout.broadcast(size); - } - } -#else - auto const *x = tensor_x->data(), *y = tensor_y->data(), - *dz = tensor_dout->data(); - auto&& d = tensor_x->dims(); - auto const N = tensor_x->numel(); - auto const B = d[d.size() - 1]; - - if (tensor_dx) { - auto* dx = tensor_dx->mutable_data(ctx.GetPlace()); - for (auto j = 0; j < N / B; ++j) { - auto const ss = dz[j]; - for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss; - } - } - - if (tensor_dy) { - auto* dy = tensor_dy->mutable_data(ctx.GetPlace()); - for (auto j = 0; j < N / B; ++j) { - auto const ss = dz[j]; - for (auto i = 0; i < B; i++) *dy++ = *x++ * ss; - } - } -#endif - } -}; - // See Note [ Why still keep the original kernel implementation? ] template class DotKernel : public framework::OpKernel { @@ -249,7 +46,7 @@ class DotKernel : public framework::OpKernel { auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); // call new kernel - pten::Dot(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get()); + pten::DotKernel(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get()); } }; @@ -266,8 +63,17 @@ class DotGradKernel : public framework::OpKernel { if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); - DotGradFunction()(tensor_x, tensor_y, tensor_dout, - tensor_dx, tensor_dy, ctx); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout); + auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx); + auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy); + + auto& dev_ctx = ctx.device_context(); + + // call new kernel + pten::DotGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(), + pt_dy.get()); } }; diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index f245bad01aa4c..2be7695e6a8c4 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -225,6 +225,10 @@ class Blas { const framework::Tensor& mat_b, const MatDescriptor& dim_b, T alpha, framework::Tensor* mat_out, T beta) const; + template + void MatMul(const T* mat_a, const MatDescriptor& dim_a, const T* mat_b, + const MatDescriptor& dim_b, T alpha, T* mat_out, T beta) const; + template void VINV(int n, const T* a, T* y) const; diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 4bcf3baa64932..be9cf1e3448b6 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -1249,6 +1249,15 @@ void Blas::MatMul(const framework::Tensor &mat_a, const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha, framework::Tensor *mat_out, T beta) const { + MatMul(mat_a.data(), dim_a, mat_b.data(), dim_b, alpha, + mat_out->data(), beta); +} + +template +template +void Blas::MatMul(const T *mat_a, const MatDescriptor &dim_a, + const T *mat_b, const MatDescriptor &dim_b, + T alpha, T *mat_out, T beta) const { PADDLE_ENFORCE_EQ( dim_a.width_, dim_b.height_, platform::errors::InvalidArgument( @@ -1261,8 +1270,7 @@ void Blas::MatMul(const framework::Tensor &mat_a, CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, - dim_a.width_, alpha, mat_a.data(), - mat_b.data(), beta, mat_out->data()); + dim_a.width_, alpha, mat_a, mat_b, beta, mat_out); } else { PADDLE_ENFORCE_EQ( dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || @@ -1273,8 +1281,8 @@ void Blas::MatMul(const framework::Tensor &mat_a, "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", dim_a.batch_size_, dim_b.batch_size_)); this->template BatchedGEMM( - transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, - mat_a.data(), mat_b.data(), beta, mat_out->data(), + transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a, + mat_b, beta, mat_out, dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, dim_a.stride_, dim_b.stride_); } diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 5add86f5b3c74..a5eca7b225558 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -389,6 +389,14 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_grad", {"X", "Y", framework::GradVarName("Out")}, + {"trans_x", "trans_y"}, + {framework::GradVarName("X"), framework::GradVarName("Y")}); + } }; template @@ -431,6 +439,13 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel { context->ShareDim("DOut", "DDOut"); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"}, + {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"}); + } }; template @@ -500,6 +515,15 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel { context->ShareDim("Y", "D_DDY_out"); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_triple_grad", + {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"}, + {"trans_x", "trans_y"}, + {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"}); + } }; template diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index b257f345eaf36..e93bd212868fd 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -28,6 +28,7 @@ limitations under the License. */ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/matmul_grad_kernel.h" #include "paddle/pten/kernels/matmul_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) @@ -39,333 +40,6 @@ namespace operators { using framework::Tensor; -template -void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, - const std::vector& reduce_dims, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - auto stream = ctx.cuda_device_context().stream(); - TensorReduceFunctorImpl>( - *input, output, kps::IdentityFunctor(), reduce_dims, stream); -#else - ReduceKernelFunctor( - input, output, reduce_dims, true, false, ctx) - .template apply(); -#endif -} - -static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims, - const int y_ndim, const std::int64_t* y_dims, - std::int64_t* x_bd_dims, - std::int64_t* y_bd_dims, - std::int64_t* out_bd_dims) { - const int ndim = (std::max)(x_ndim, y_ndim); - std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); - std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); - std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); - std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); - - for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ( - x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, - true, - platform::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim." - "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s]," - "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1," - "But received X_broadcast's shape[%s] = [%s]" - "received Y_broadcast's shape[%s] = [%s]", - i, i, i, i, i, x_bd_dims[i], i, y_bd_dims[i])); - if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { - out_bd_dims[i] = 0; - } else { - out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); - } - } -} - -static int64_t GetIndexMessage(const int n, const int64_t* dims, - const int64_t* index) { - int64_t sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } - } - return sum; -} - -static void IndexIncreaseFromDims(const int ndim, const int64_t* dims, - int64_t* index) { - for (int i = ndim - 1; i >= 0; --i) { - ++index[i]; - if (index[i] >= dims[i]) { - index[i] -= dims[i]; - } else { - break; - } - } -} - -template -void MatMulFunction(const Tensor* X, const Tensor* Y, - const std::vector& x_dims, - const std::vector& y_dims, Tensor* Out, - bool trans_x, bool trans_y, - const paddle::framework::ExecutionContext& ctx, - bool flag = false) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - - // Get data ptr - const T* x_data = X->data(); - const T* y_data = Y->data(); - - if (x_ndim == 1 && y_ndim == 1) { - PADDLE_ENFORCE_EQ( - X->numel(), Y->numel(), - platform::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - X->numel(), Y->numel())); - VLOG(3) << "MatMul's case 1"; - Out->Resize({1}); - Out->mutable_data(ctx.GetPlace()); - auto out_eigen = framework::EigenScalar::From(*Out); - auto x_eigen = framework::EigenVector::Flatten(*X); - auto y_eigen = framework::EigenVector::Flatten(*Y); - - auto& dev = *ctx.template device_context().eigen_device(); - if (flag) { - out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; - } else { - out_eigen.device(dev) = (x_eigen * y_eigen).sum(); - } - return; - } - - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); - - if (x_ndim == 1) { - const int N = X->numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], N, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, N, y_ndim - 1, y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], N, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, N, y_ndim - 2, y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->Resize(framework::make_ddim(out_dims)); - Out->mutable_data(ctx.GetPlace()); - if (trans_y) { - const int M = Y->numel() / N; - VLOG(3) << "MatMul's case 2"; - blas.GEMV(false, M, N, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y->numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 3"; - blas.GEMV(true, N, M, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 4"; - blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - y_data, x_data, static_cast(flag), Out->data(), - batch_size, M * N, 0); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y->numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], N, - platform::errors::InvalidArgument( - "Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, N, x_ndim - 2, x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], N, - platform::errors::InvalidArgument( - "Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, N, x_ndim - 1, x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->Resize(framework::make_ddim(out_dims)); - Out->mutable_data(ctx.GetPlace()); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X->numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 5"; - blas.GEMV(true, N, M, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 6"; - blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - x_data, y_data, static_cast(flag), Out->data(), - batch_size, M * N, 0); - } - } else { - const int M = X->numel() / N; - VLOG(3) << "MatMul's case 7"; - blas.GEMV(false, M, N, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, x_dims.data(), y_ndim - 2, y_dims.data(), - x_broadcast_dims.data(), y_broadcast_dims.data(), - out_broadcast_dims.data()); - - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->Resize(framework::make_ddim(out_broadcast_dims)); - Out->mutable_data(ctx.GetPlace()); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = !std::equal( - x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = std::accumulate( - x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - const std::int64_t y_batch_size = std::accumulate( - y_broadcast_dims.cbegin(), y_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - const std::int64_t out_batch_size = std::accumulate( - out_broadcast_dims.cbegin(), out_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul's case 8"; - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast(1), - x_data, y_data, static_cast(flag), Out->data()); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul's case 9"; - blas.GEMV(false, y_batch_size * N, K, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, 0, K * N); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul's case 11"; - blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, - x_batch_size * M, N, K, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 12"; - blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, M * K, 0); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul's case 13"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, M * K, K * N); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = Out->data() + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul's case 14"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_ptr.data(), y_ptr.data(), - static_cast(flag), out_ptr.data(), out_batch_size); - } -} - -template -void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, - bool trans_y, - const paddle::framework::ExecutionContext& ctx, - bool flag = false) { - const std::vector x_dims = vectorize(X->dims()); - const std::vector y_dims = vectorize(Y->dims()); - MatMulFunction(X, Y, x_dims, y_dims, Out, trans_x, trans_y, - ctx, flag); -} - template class MatMulV2Kernel : public framework::OpKernel { public: @@ -400,26 +74,6 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) { return output; } -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. -template -static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, - const framework::Tensor& input) { - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - framework::Tensor output; - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - output.mutable_data(context.GetPlace()); - std::vector axis = {1, 0, 2}; - math::Transpose trans; - trans(context, input, &output, axis); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - return output; -} - /** * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the * original x_dim is returned. @@ -482,1000 +136,45 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } -template -struct ConjHelper { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - dst.set_layout(src.layout()); - dst.ShareDataWith(src); - return; - } - - const framework::ExecutionContext& ctx_; -}; - -template -struct ConjHelper> { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - auto* src_data = src.data>(); - auto* dst_data = dst.mutable_data>( - ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex))); - - platform::ForRange for_range( - ctx_.template device_context(), src.numel()); - math::ConjFunctor> functor( - src_data, src.numel(), dst_data); - for_range(functor); - return; - } - const framework::ExecutionContext& ctx_; -}; - -template -struct ConjHelper> { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - auto* src_data = src.data>(); - auto* dst_data = dst.mutable_data>( - ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex))); - - platform::ForRange for_range( - ctx_.template device_context(), src.numel()); - math::ConjFunctor> functor( - src_data, src.numel(), dst_data); - for_range(functor); - return; - } - const framework::ExecutionContext& ctx_; -}; - -template -struct DotDoubleGradFunction { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx); -}; - -template -struct DotDoubleGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - framework::Tensor tensor_dout_help; - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - if (tensor_dx || tensor_dy) { - tensor_dout_help.Resize(tensor_dout->dims()); - tensor_dout_help.mutable_data(ctx.GetPlace()); - paddle::platform::ForRange for_range( - dev_raw, tensor_dout->numel()); - math::ConjFunctor functor(tensor_dout->data(), - tensor_dout->numel(), - tensor_dout_help.data()); - for_range(functor); - } - if (tensor_dx) { - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - Eigen::DSizes size(tensor_ddy->numel()); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto dout = framework::EigenVector::Flatten(tensor_dout_help); - dx.device(dev) = ddy * dout.broadcast(size); - } - - if (tensor_dy) { - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - Eigen::DSizes size(tensor_ddx->numel()); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto dout = framework::EigenVector::Flatten(tensor_dout_help); - dy.device(dev) = ddx * dout.broadcast(size); - } - - if (tensor_ddout) { - framework::Tensor tensor_x_help, tensor_y_help; - tensor_x_help.Resize(tensor_x->dims()); - tensor_x_help.mutable_data(ctx.GetPlace()); - tensor_y_help.Resize(tensor_y->dims()); - tensor_y_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range(dev_raw, - tensor_x->numel()); - math::ConjFunctor functor_x(tensor_x->data(), tensor_x->numel(), - tensor_x_help.data()); - for_range(functor_x); - math::ConjFunctor functor_y(tensor_y->data(), tensor_y->numel(), - tensor_y_help.data()); - for_range(functor_y); - auto x = framework::EigenVector::Flatten(tensor_x_help); - auto y = framework::EigenVector::Flatten(tensor_y_help); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - auto ddout = framework::EigenVector::Flatten(*tensor_ddout); - ddout.device(dev) = (x * ddy + y * ddx).sum(); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_ddy = tensor_ddy->data(); - const framework::DDim& dim = tensor_dx->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_ddx = tensor_ddx->data(); - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i]; - } - } - - if (tensor_ddout) { - auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); - auto* data_x = tensor_x->data(); - auto* data_y = tensor_y->data(); - auto* data_ddx = tensor_ddx->data(); - auto* data_ddy = tensor_ddy->data(); - - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + - T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; - } else { - data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + - T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; - } - new_s = false; - } - } -#endif - } -}; - -template -struct DotDoubleGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - auto dout = framework::EigenVector::Flatten(*tensor_dout); - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - Eigen::DSizes size(tensor_ddy->numel()); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - dx.device(dev) = ddy * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - Eigen::DSizes size(tensor_ddx->numel()); - - auto dy = framework::EigenVector::Flatten(*tensor_dy); - dy.device(dev) = ddx * dout.broadcast(size); - } - - if (tensor_ddout) { - tensor_ddout->mutable_data(ctx.GetPlace()); - auto x = framework::EigenVector::Flatten(*tensor_x); - auto y = framework::EigenVector::Flatten(*tensor_y); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - auto ddout = framework::EigenVector::Flatten(*tensor_ddout); - ddout.device(dev) = (x * ddy + y * ddx).sum(); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_ddy = tensor_ddy->data(); - const framework::DDim& dim = tensor_dx->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = data_dout[s] * data_ddy[i]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_ddx = tensor_ddx->data(); - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = data_dout[s] * data_ddx[i]; - } - } - - if (tensor_ddout) { - auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); - auto* data_x = tensor_x->data(); - auto* data_y = tensor_y->data(); - auto* data_ddx = tensor_ddx->data(); - auto* data_ddy = tensor_ddy->data(); - - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; - } else { - data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; - } - new_s = false; - } - } -#endif - } -}; - -template -struct DotTripleGradFunction { - void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, - const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, - const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, - const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, - Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, - Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, - Tensor* out_tensor_d_ddy, - const paddle::framework::ExecutionContext& ctx); -}; - -// TODO(wuweilong): enable this function when the unittests framewark for multi -// grad is ok (dtype: complex64 or complex128). -template -struct DotTripleGradFunction> { - void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, - const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, - const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, - const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, - Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, - Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, - Tensor* out_tensor_d_ddy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == in_tensor_d_ddout->dims().size()) { - framework::Tensor in_tensor_d_ddout_help; - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - if (out_tensor_d_x || out_tensor_d_y) { - in_tensor_d_ddout_help.Resize(in_tensor_d_ddout->dims()); - in_tensor_d_ddout_help.mutable_data(ctx.GetPlace()); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_d_ddout->numel()); - math::ConjFunctor functor(in_tensor_d_ddout->data(), - in_tensor_d_ddout->numel(), - in_tensor_d_ddout_help.data()); - for_range(functor); - } - if (out_tensor_d_x) { - auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); - Eigen::DSizes size(in_tensor_ddy->numel()); - auto d_x = framework::EigenVector::Flatten(*out_tensor_d_x); - auto d_ddout = - framework::EigenVector::Flatten(in_tensor_d_ddout_help); - d_x.device(dev) = ddy * d_ddout.broadcast(size); - } - - if (out_tensor_d_y) { - auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); - Eigen::DSizes size(in_tensor_ddx->numel()); - auto d_y = framework::EigenVector::Flatten(*out_tensor_d_y); - auto d_ddout = - framework::EigenVector::Flatten(in_tensor_d_ddout_help); - d_y.device(dev) = ddx * d_ddout.broadcast(size); - } - - if (out_tensor_d_dout) { - framework::Tensor in_tensor_ddx_help, in_tensor_ddy_help; - in_tensor_ddx_help.Resize(in_tensor_ddx->dims()); - in_tensor_ddx_help.mutable_data(ctx.GetPlace()); - in_tensor_ddy_help.Resize(in_tensor_ddy->dims()); - in_tensor_ddy_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_ddx->numel()); - math::ConjFunctor functor_ddx(in_tensor_ddx->data(), - in_tensor_ddx->numel(), - in_tensor_ddx_help.data()); - for_range(functor_ddx); - math::ConjFunctor functor_ddy(in_tensor_ddy->data(), - in_tensor_ddy->numel(), - in_tensor_ddy_help.data()); - for_range(functor_ddy); - auto ddx = framework::EigenVector::Flatten(in_tensor_ddx_help); - auto ddy = framework::EigenVector::Flatten(in_tensor_ddy_help); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_dout = framework::EigenVector::Flatten(*out_tensor_d_dout); - d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); - } - if (out_tensor_d_ddx) { - framework::Tensor in_tensor_dout_help, in_tensor_y_help; - in_tensor_dout_help.Resize(in_tensor_dout->dims()); - in_tensor_dout_help.mutable_data(ctx.GetPlace()); - in_tensor_y_help.Resize(in_tensor_y->dims()); - in_tensor_y_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_dout->numel()); - math::ConjFunctor functor_dout(in_tensor_dout->data(), - in_tensor_dout->numel(), - in_tensor_dout_help.data()); - for_range(functor_dout); - math::ConjFunctor functor_y(in_tensor_y->data(), - in_tensor_y->numel(), - in_tensor_y_help.data()); - for_range(functor_y); - auto dout = framework::EigenVector::Flatten(in_tensor_dout_help); - auto y = framework::EigenVector::Flatten(in_tensor_y_help); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_ddx = framework::EigenVector::Flatten(*out_tensor_d_ddx); - Eigen::DSizes size(in_tensor_y->numel()); - d_ddx.device(dev) = - (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); - } - if (out_tensor_d_ddy) { - framework::Tensor in_tensor_dout_help, in_tensor_x_help; - in_tensor_dout_help.Resize(in_tensor_dout->dims()); - in_tensor_dout_help.mutable_data(ctx.GetPlace()); - in_tensor_x_help.Resize(in_tensor_x->dims()); - in_tensor_x_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_dout->numel()); - math::ConjFunctor functor_dout(in_tensor_dout->data(), - in_tensor_dout->numel(), - in_tensor_dout_help.data()); - for_range(functor_dout); - math::ConjFunctor functor_x(in_tensor_x->data(), - in_tensor_x->numel(), - in_tensor_x_help.data()); - for_range(functor_x); - auto dout = framework::EigenVector::Flatten(in_tensor_dout_help); - auto x = framework::EigenVector::Flatten(in_tensor_x_help); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_ddy = framework::EigenVector::Flatten(*out_tensor_d_ddy); - Eigen::DSizes size(in_tensor_x->numel()); - d_ddy.device(dev) = - (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); - } - } -#else - const auto* data_d_ddout = in_tensor_d_ddout->data(); - - if (out_tensor_d_x) { - auto* data_d_x = out_tensor_d_x->mutable_data(ctx.GetPlace()); - const auto* data_ddy = in_tensor_ddy->data(); - - const framework::DDim& dim = out_tensor_d_x->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s]; - } - } - - if (out_tensor_d_y) { - auto* data_d_y = out_tensor_d_y->mutable_data(ctx.GetPlace()); - const auto* data_ddx = in_tensor_ddx->data(); - - const framework::DDim& dim = out_tensor_d_y->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s]; - } - } - - if (out_tensor_d_dout) { - auto* data_d_dout = out_tensor_d_dout->mutable_data(ctx.GetPlace()); - auto* data_ddx = in_tensor_ddx->data(); - auto* data_ddy = in_tensor_ddy->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - - const framework::DDim& dim = out_tensor_d_dout->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_d_dout[s] = - T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + - T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; - } else { - data_d_dout[s] += - T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + - T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; - } - new_s = false; - } - } - - if (out_tensor_d_ddx) { - auto* data_d_ddx = out_tensor_d_ddx->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - auto* data_y = in_tensor_y->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddx->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddx[i] = - T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] + - T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s]; - } - } - - if (out_tensor_d_ddy) { - auto* data_d_ddy = out_tensor_d_ddy->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_x = in_tensor_x->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddy[i] = - T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] + - T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s]; - } - } -#endif - } -}; - -template -struct DotTripleGradFunction> { - void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, - const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, - const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, - const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, - Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, - Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, - Tensor* out_tensor_d_ddy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == in_tensor_d_ddout->dims().size()) { - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - if (out_tensor_d_x) { - out_tensor_d_x->mutable_data(ctx.GetPlace()); - auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); - Eigen::DSizes size(in_tensor_ddy->numel()); - auto d_x = framework::EigenVector::Flatten(*out_tensor_d_x); - d_x.device(dev) = ddy * d_ddout.broadcast(size); - } - - if (out_tensor_d_y) { - out_tensor_d_y->mutable_data(ctx.GetPlace()); - auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); - Eigen::DSizes size(in_tensor_ddx->numel()); - - auto d_y = framework::EigenVector::Flatten(*out_tensor_d_y); - d_y.device(dev) = ddx * d_ddout.broadcast(size); - } - - if (out_tensor_d_dout) { - out_tensor_d_dout->mutable_data(ctx.GetPlace()); - auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_dout = framework::EigenVector::Flatten(*out_tensor_d_dout); - d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); - } - - if (out_tensor_d_ddx) { - out_tensor_d_ddx->mutable_data(ctx.GetPlace()); - auto dout = framework::EigenVector::Flatten(*in_tensor_dout); - auto y = framework::EigenVector::Flatten(*in_tensor_y); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_ddx = framework::EigenVector::Flatten(*out_tensor_d_ddx); - Eigen::DSizes size(in_tensor_y->numel()); - d_ddx.device(dev) = - (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); - } - - if (out_tensor_d_ddy) { - out_tensor_d_ddy->mutable_data(ctx.GetPlace()); - auto dout = framework::EigenVector::Flatten(*in_tensor_dout); - auto x = framework::EigenVector::Flatten(*in_tensor_x); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_ddy = framework::EigenVector::Flatten(*out_tensor_d_ddy); - Eigen::DSizes size(in_tensor_x->numel()); - d_ddy.device(dev) = - (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); - } - } -#else - const auto* data_d_ddout = in_tensor_d_ddout->data(); - - if (out_tensor_d_x) { - auto* data_d_x = out_tensor_d_x->mutable_data(ctx.GetPlace()); - const auto* data_ddy = in_tensor_ddy->data(); - - const framework::DDim& dim = out_tensor_d_x->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_x[i] = data_ddy[i] * data_d_ddout[s]; - } - } - - if (out_tensor_d_y) { - auto* data_d_y = out_tensor_d_y->mutable_data(ctx.GetPlace()); - const auto* data_ddx = in_tensor_ddx->data(); - - const framework::DDim& dim = out_tensor_d_y->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_y[i] = data_ddx[i] * data_d_ddout[s]; - } - } - - if (out_tensor_d_dout) { - auto* data_d_dout = out_tensor_d_dout->mutable_data(ctx.GetPlace()); - auto* data_ddx = in_tensor_ddx->data(); - auto* data_ddy = in_tensor_ddy->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - - const framework::DDim& dim = in_tensor_ddx->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_d_dout[s] = - data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; - } else { - data_d_dout[s] += - data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; - } - new_s = false; - } - } - - if (out_tensor_d_ddx) { - auto* data_d_ddx = out_tensor_d_ddx->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - auto* data_y = in_tensor_y->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddx->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddx[i] = - data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s]; - } - } - - if (out_tensor_d_ddy) { - auto* data_d_ddy = out_tensor_d_ddy->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_x = in_tensor_x->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddy[i] = - data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s]; - } - } -#endif - } -}; - template class MatMulV2GradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, - framework::Tensor* out) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(0)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out); - } - } - void Compute(const framework::ExecutionContext& ctx) const override { bool transpose_x = ctx.Attr("trans_x"); bool transpose_y = ctx.Attr("trans_y"); - auto x = *ctx.Input("X"); - auto y = *ctx.Input("Y"); - auto dout = *ctx.Input(framework::GradVarName("Out")); - - framework::Tensor y_conj(y.type()); - framework::Tensor x_conj(y.type()); - - // get dims - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - if (dx) dx->mutable_data(ctx.GetPlace()); - if (dy) dy->mutable_data(ctx.GetPlace()); - if (dout.numel() == 1) { - DotGradFunction()(&x, &y, &dout, dx, dy, ctx); - return; - } - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - // Case2: no broadcast or no batch size, it aims to speed and it is same as - // matmul in old version. - if (!is_broadcast) { - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } + if (dx) dx->mutable_data(ctx.GetPlace()); + if (dy) dy->mutable_data(ctx.GetPlace()); - // for complex - ConjHelper conj_helper(ctx); - conj_helper(y, y_conj); - } + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_dx = dx ? paddle::experimental::MakePtenDenseTensor(*dx) + : std::unique_ptr(nullptr); + auto pt_dy = dy ? paddle::experimental::MakePtenDenseTensor(*dy) + : std::unique_ptr(nullptr); - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - - // for complex - ConjHelper conj_helper(ctx); - conj_helper(x, x_conj); - } - if (transpose_x && transpose_y) { - CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx); - CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy); - } else if (transpose_x) { - CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx); - CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy); - } else if (transpose_y) { - CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx); - CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy); - } else { - CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx); - CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy); - } - - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - Tensor dx_help, dy_help; - - ConjHelper conj_helper(ctx); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - if (transpose_x) { - if (transpose_y) { - // X'Y': dA = Y'G', dB = G'X' - if (dx) - MatMulFunction(&y_conj, &dout, y_dims, dout_dims, - &dx_help, true, true, ctx); - if (dy) - MatMulFunction(&dout, &x_conj, dout_dims, x_dims, - &dy_help, true, true, ctx); - } else { - // X'Y: dX = YG', dY = XG - if (dx) - MatMulFunction(&y_conj, &dout, y_dims, dout_dims, - &dx_help, false, true, ctx); - if (dy) - MatMulFunction(&x_conj, &dout, x_dims, dout_dims, - &dy_help, false, false, ctx); - } - } else { - if (transpose_y) { - // XY': dX = GY, dY = G'X - if (dx) - MatMulFunction(&dout, &y_conj, dout_dims, y_dims, - &dx_help, false, false, ctx); - if (dy) - MatMulFunction(&dout, &x_conj, dout_dims, x_dims, - &dy_help, true, false, ctx); - } else { - // XY: dX = GY', dY = X'G - if (dx) - MatMulFunction(&dout, &y_conj, dout_dims, y_dims, - &dx_help, false, true, ctx); - if (dy) - MatMulFunction(&x_conj, &dout, x_dims, dout_dims, - &dy_help, true, false, ctx); - } - } - - // get help dims - const std::vector dx_help_dims = vectorize(dx_help.dims()); - const std::vector dy_help_dims = vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, - ctx); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, - ctx); - } - dy->Resize(y.dims()); - } + auto& dev_ctx = ctx.device_context(); - // Get the OutputGrad(out) - } + // call new kernel + pten::MatmulGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, transpose_x, + transpose_y, pt_dx.get(), pt_dy.get()); } }; template class MatMulV2DoubleGradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, framework::Tensor* out, - bool flag) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(flag)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out, bool flag) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out, flag); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out, flag); - } - } - void Compute(const framework::ExecutionContext& context) const override { - auto x = *context.Input("X"); - auto y = *context.Input("Y"); - auto dout = *context.Input("DOut"); + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* dout = context.Input("DOut"); auto* ddx = context.Input("DDX"); auto* ddy = context.Input("DDY"); @@ -1486,316 +185,38 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel { bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); - // Get dims from the input x, y, output_grad - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - framework::Tensor x_conj(x.type()); - framework::Tensor y_conj(y.type()); - framework::Tensor dout_conj(dout.type()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - DotDoubleGradFunction()(&x, &y, dx, dy, &dout, ddx, ddy, - ddout, context); - return; - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - - ConjHelper conj_helper(context); - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - } - - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - framework::DDim ddout_dims; - if (ddout) { - ddout_dims = ddout->dims(); - if (ddout_dims != dout.dims()) { - ddout->Resize(dout.dims()); - } - } - - if (ddx || ddy) { - ConjHelper conj_helper(context); - conj_helper(dout, dout_conj); - } - if (ddout) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - } - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = *ddx; - if (ddx_mat.dims() != x.dims()) { - ddx_mat.Resize(x.dims()); - } - if (dy) { - if (transpose_x && transpose_y) { - // dy = dout' * ddx' - CalcInputGrad(context, dout_conj, true, true, ddx_mat, true, false, - dy, false); - } else if (transpose_x) { - // dy = ddx * dout - CalcInputGrad(context, ddx_mat, false, false, dout_conj, false, - true, dy, false); - } else if (transpose_y) { - // dy = dout' * ddx - CalcInputGrad(context, dout_conj, true, true, ddx_mat, false, true, - dy, false); - } else { - // dy = ddx' * dout - CalcInputGrad(context, ddx_mat, true, true, dout_conj, false, true, - dy, false); - } - } - - if (ddout) { - CalcInputGrad(context, ddx_mat, transpose_x, true, y_conj, - transpose_y, false, ddout, ddout_flag); - ddout_flag = true; - } - } - - if (ddy) { - auto ddy_mat = *ddy; - if (ddy_mat.dims() != y.dims()) { - ddy_mat.Resize(y.dims()); - } - if (dx) { - if (transpose_x && transpose_y) { - // dx = ddy' * dout' - CalcInputGrad(context, ddy_mat, true, true, dout_conj, true, false, - dx, false); - } else if (transpose_x) { - // dx = ddy * dout' - CalcInputGrad(context, ddy_mat, false, false, dout_conj, true, - false, dx, false); - } else if (transpose_y) { - // dx = dout * ddy - CalcInputGrad(context, dout_conj, false, false, ddy_mat, false, - true, dx, false); - } else { - // dx = dout * ddy' - CalcInputGrad(context, dout_conj, false, false, ddy_mat, true, - false, dx, false); - } - } + if (dx) dx->mutable_data(context.GetPlace()); + if (dy) dy->mutable_data(context.GetPlace()); + if (ddout) ddout->mutable_data(context.GetPlace()); - if (ddout) { - CalcInputGrad(context, x_conj, transpose_x, true, ddy_mat, - transpose_y, false, ddout, ddout_flag); - } - } + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx); + auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy); + auto pt_dx = paddle::experimental::MakePtenDenseTensor(*dx); + auto pt_dy = paddle::experimental::MakePtenDenseTensor(*dy); + auto pt_ddout = paddle::experimental::MakePtenDenseTensor(*ddout); - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } + auto& dev_ctx = context.device_context(); - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - - if (ddout) { - if (ddout_dims != dout.dims()) { - ddout->Resize(ddout_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - framework::Tensor ddy_conj(ddx->type()); - framework::Tensor ddx_conj(ddy->type()); - - Tensor dx_help, dy_help; - if (dx || dy) { - ConjHelper conj_helper(context); - conj_helper(dout, dout_conj); - } - if (ddout) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - } - if (transpose_x) { - if (transpose_y) { - if (dx) - MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, - &dx_help, true, true, context); - if (dy) - MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, - &dy_help, true, true, context); - } else { - if (dx) - MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, - &dx_help, false, true, context); - if (dy) - MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, - &dy_help, false, false, context); - } - } else { - if (transpose_y) { - if (dx) - MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, - &dx_help, false, false, context); - if (dy) - MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, - &dy_help, true, false, context); - } else { - if (dx) - MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, - &dx_help, false, true, context); - if (dy) - MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, - &dy_help, true, false, context); - } - } - - // get help dims - const std::vector dx_help_dims = vectorize(dx_help.dims()); - const std::vector dy_help_dims = vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // Reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, - context); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, - context); - } - dy->Resize(y.dims()); - } - - if (ddout) { - // Calculate the gradient of OutputGrad(Out) - MatMulFunction(ddx, &y_conj, x_dims, y_dims, ddout, - transpose_x, transpose_y, context); - MatMulFunction(&x_conj, ddy, x_dims, y_dims, ddout, - transpose_x, transpose_y, context, - true); - } - } + // call new kernel + pten::MatmulDoubleGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx, + *pt_ddy, transpose_x, transpose_y, + pt_dx.get(), pt_dy.get(), pt_ddout.get()); } }; template class MatMulV2TripleGradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, framework::Tensor* out, - bool flag) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(flag)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out, bool flag) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out, flag); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out, flag); - } - } - void Compute(const framework::ExecutionContext& context) const override { // get input - auto x = *context.Input("X"); - auto y = *context.Input("Y"); - auto dout = *context.Input("DOut"); - auto ddx = *context.Input("DDX"); - auto ddy = *context.Input("DDY"); + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* dout = context.Input("DOut"); + auto* ddx = context.Input("DDX"); + auto* ddy = context.Input("DDY"); auto* d_dx = context.Input("D_DX"); auto* d_dy = context.Input("D_DY"); @@ -1812,539 +233,34 @@ class MatMulV2TripleGradKernel : public framework::OpKernel { bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); - // Get dims from the input x, y, output_grad - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - framework::Tensor x_conj(x.type()); - framework::Tensor y_conj(y.type()); - framework::Tensor dout_conj(dout.type()); - framework::Tensor ddx_conj(ddx.type()); - framework::Tensor ddy_conj(ddy.type()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's and y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; - - DotTripleGradFunction()( - &x, &y, &ddx, &ddy, d_dx, d_dy, &dout, d_ddout, out_d_x, out_d_y, - out_d_dout, out_d_ddx, out_d_ddy, context); - return; - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - - if (ddx.dims() != x.dims()) { - ddx.Resize(x.dims()); - } - - if (ddy.dims() != y.dims()) { - ddy.Resize(y.dims()); - } - - ConjHelper conj_helper(context); - - framework::DDim out_dx_dims; - if (out_d_x) { - out_dx_dims = out_d_x->dims(); - if (out_dx_dims != x.dims()) { - out_d_x->Resize(x.dims()); - } - } - - framework::DDim out_dy_dims; - if (out_d_y) { - out_dy_dims = out_d_y->dims(); - if (out_dy_dims != y.dims()) { - out_d_y->Resize(y.dims()); - } - } - - framework::DDim out_d_dout_dims; - if (out_d_dout) { - out_d_dout_dims = out_d_dout->dims(); - if (out_d_dout_dims != dout.dims()) { - out_d_dout->Resize(dout.dims()); - } - } - - framework::DDim out_d_ddx_dims; - if (out_d_ddx) { - out_d_ddx_dims = out_d_ddx->dims(); - if (out_d_ddx_dims != x.dims()) { - out_d_ddx->Resize(x.dims()); - } - } - - framework::DDim out_d_ddy_dims; - if (out_d_ddy) { - out_d_ddy_dims = out_d_ddy->dims(); - if (out_d_ddy_dims != y.dims()) { - out_d_ddy->Resize(y.dims()); - } - } - - if (out_d_dout) { - ConjHelper conj_helper(context); - conj_helper(ddx, ddx_conj); - conj_helper(ddy, ddy_conj); - } - - if (out_d_ddx || out_d_ddy) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - conj_helper(dout, dout_conj); - } - - bool d_dout_flag = false; - bool d_ddx_flag = false; - bool d_ddy_flag = false; - - if (d_ddout) { - auto d_ddout_mat = *d_ddout; - if (d_ddout_mat.dims() != dout.dims()) { - d_ddout_mat.Resize(dout.dims()); - } - - if (out_d_y) { - if (transpose_x && transpose_y) { - // out_d_y = d_ddout' * ddx' - CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, true, - false, out_d_y, false); - } else if (transpose_x) { - // out_d_y = ddx * d_ddout - CalcInputGrad(context, ddx_conj, false, false, d_ddout_mat, false, - true, out_d_y, false); - } else if (transpose_y) { - // out_d_y = d_ddout' * ddx - CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, false, - true, out_d_y, false); - } else { - // out_d_y = ddx' * d_ddout - CalcInputGrad(context, ddx_conj, true, true, d_ddout_mat, false, - true, out_d_y, false); - } - } - - if (out_d_x) { - if (transpose_x && transpose_y) { - // out_d_x = ddy' * d_ddout' - CalcInputGrad(context, ddy_conj, true, true, d_ddout_mat, true, - false, out_d_x, false); - } else if (transpose_x) { - // out_d_x = ddy * d_ddout' - CalcInputGrad(context, ddy_conj, false, false, d_ddout_mat, true, - false, out_d_x, false); - } else if (transpose_y) { - // out_d_x = d_ddout * ddy - CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, false, - true, out_d_x, false); - } else { - // out_d_x = d_ddout * ddy' - CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, true, - false, out_d_x, false); - } - } - - // equations: - // d_ddx = DOut * D_DY + Y * D_DDOut - // Let: d_ddx1 = Y * D_DDOut - // Let: d_ddx2 = DOut * D_DY - - // d_ddy = DOut * D_DX + X * D_DDOut - // Let: d_ddy1 = X * D_DDOut - // Let: d_ddy2 = DOut * D_DX - - // d_dout = DDY * D_DX + DDX * D_DY - // Let: d_dout1 = DDX * D_DY - // Let: d_dout2 = DDY * D_DX - - // compute d_ddx1 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - CalcInputGrad(context, y_conj, true, true, d_ddout_mat, true, false, - out_d_ddx, d_ddx_flag); - } else if (transpose_x) { - // out_d_ddx1 = y * d_ddout' - CalcInputGrad(context, y_conj, false, false, d_ddout_mat, true, - false, out_d_ddx, d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - CalcInputGrad(context, d_ddout_mat, false, false, y_conj, false, - true, out_d_ddx, d_ddx_flag); - } else { - // out_d_ddx1 = d_ddout * y' - CalcInputGrad(context, d_ddout_mat, false, false, y_conj, true, - false, out_d_ddx, d_ddx_flag); - } - d_ddx_flag = true; - } - - // compute d_ddy1 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - CalcInputGrad(context, d_ddout_mat, true, true, x_conj, true, false, - out_d_ddy, false); - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - CalcInputGrad(context, x_conj, false, false, d_ddout_mat, false, - true, out_d_ddy, false); - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - CalcInputGrad(context, d_ddout_mat, true, true, x_conj, false, true, - out_d_ddy, false); - } else { - // out_d_ddy1 = x' * d_ddout - CalcInputGrad(context, x_conj, true, true, d_ddout_mat, false, true, - out_d_ddy, false); - } - d_ddy_flag = true; - } - } - - if (d_dy) { - auto d_dy_mat = *d_dy; - if (d_dy_mat.dims() != y.dims()) { - d_dy_mat.Resize(y.dims()); - } - - // compute d_dout1 - if (out_d_dout) { - CalcInputGrad(context, ddx_conj, transpose_x, true, d_dy_mat, - transpose_y, false, out_d_dout, d_dout_flag); - d_dout_flag = true; - } - - // compute d_ddx2 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx2 = D_DY' * DOut' - CalcInputGrad(context, d_dy_mat, true, true, dout_conj, true, false, - out_d_ddx, d_ddx_flag); - } else if (transpose_x) { - // out_d_ddx2 = D_DY * Dout' - CalcInputGrad(context, d_dy_mat, false, false, dout_conj, true, - false, out_d_ddx, d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx2 = Dout * D_DY - CalcInputGrad(context, dout_conj, false, false, d_dy_mat, false, - true, out_d_ddx, d_ddx_flag); - } else { - // out_d_ddx2 = Dout * D_DY' - CalcInputGrad(context, dout_conj, false, false, d_dy_mat, true, - false, out_d_ddx, d_ddx_flag); - } - } - } - - if (d_dx) { - auto d_dx_mat = *d_dx; - if (d_dx_mat.dims() != x.dims()) { - d_dx_mat.Resize(x.dims()); - } - - // compute d_dout2 - if (out_d_dout) { - CalcInputGrad(context, d_dx_mat, transpose_x, true, ddy_conj, - transpose_y, false, out_d_dout, d_dout_flag); - } - - // compute d_ddy2 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy2 = dout' * d_dx' - CalcInputGrad(context, dout_conj, true, true, d_dx_mat, true, false, - out_d_ddy, d_ddy_flag); - } else if (transpose_x) { - // out_d_ddy2 = d_dx * dout - CalcInputGrad(context, d_dx_mat, false, false, dout_conj, false, - true, out_d_ddy, d_ddy_flag); - } else if (transpose_y) { - // out_d_ddy2 = dout' * d_dx - CalcInputGrad(context, dout_conj, true, true, d_dx_mat, false, true, - out_d_ddy, d_ddy_flag); - } else { - // out_d_ddy2 = d_dx' * dout - CalcInputGrad(context, d_dx_mat, true, true, dout_conj, false, true, - out_d_ddy, d_ddy_flag); - } - } - } - - if (out_d_x) { - if (out_dx_dims != x.dims()) { - out_d_x->Resize(out_dx_dims); - } - } - - if (out_d_y) { - if (out_dy_dims != y.dims()) { - out_d_y->Resize(out_dy_dims); - } - } - - if (out_d_dout) { - if (out_d_dout_dims != dout.dims()) { - out_d_dout->Resize(out_d_dout_dims); - } - } - - if (out_d_ddx) { - if (out_d_ddx_dims != x.dims()) { - out_d_ddx->Resize(out_d_ddx_dims); - } - } - - if (out_d_ddy) { - if (out_d_ddy_dims != x.dims()) { - out_d_ddy->Resize(out_d_ddy_dims); - } - } - - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - - Tensor out_dx_help, out_dy_help; - Tensor out_d_ddx_help, out_d_ddy_help; - if (out_d_dout) { - ConjHelper conj_helper(context); - conj_helper(ddx, ddx_conj); - conj_helper(ddy, ddy_conj); - } - if (out_d_ddx || out_d_ddy) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - conj_helper(dout, dout_conj); - } - - if (transpose_x) { - if (transpose_y) { - // dX = ddY' d_ddout’, dY = d_ddout’ ddX' - if (out_d_x) - MatMulFunction(&ddy_conj, d_ddout, y_dims, - dout_dims, &out_dx_help, true, - true, context); - if (out_d_y) - MatMulFunction(d_ddout, &ddx_conj, dout_dims, - x_dims, &out_dy_help, true, true, - context); - } else { - // dX = ddY d_ddout', dY = ddX d_ddout - if (out_d_x) - MatMulFunction(&ddy_conj, d_ddout, y_dims, - dout_dims, &out_dx_help, false, - true, context); - if (out_d_y) - MatMulFunction(&ddx_conj, d_ddout, x_dims, - dout_dims, &out_dy_help, false, - false, context); - } - } else { - if (transpose_y) { - // dX = d_ddout ddY, dY = d_ddout’ ddX - if (out_d_x) - MatMulFunction(d_ddout, &ddy_conj, dout_dims, - y_dims, &out_dx_help, false, false, - context); - if (out_d_y) - MatMulFunction(d_ddout, &ddx_conj, dout_dims, - x_dims, &out_dy_help, true, false, - context); - } else { - // dX = d_ddout ddY', dY = ddX' d_ddout - if (out_d_x) - MatMulFunction(d_ddout, &ddy_conj, dout_dims, - y_dims, &out_dx_help, false, true, - context); - if (out_d_y) - MatMulFunction(&ddx_conj, d_ddout, x_dims, - dout_dims, &out_dy_help, true, - false, context); - } - } - - // get help dims - const std::vector dx_help_dims = - vectorize(out_dx_help.dims()); - const std::vector dy_help_dims = - vectorize(out_dx_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // Reduce sum to get grad by ReduceSum - if (out_d_x) { - if (dx_reduce_dims.empty()) { - *out_d_x = std::move(out_dx_help); - } else { - ReduceSumForMatmulGrad(&out_dx_help, out_d_x, - dx_reduce_dims, context); - } - out_d_x->Resize(x.dims()); - } - - if (out_d_y) { - if (dy_reduce_dims.empty()) { - *out_d_y = std::move(out_dy_help); - } else { - ReduceSumForMatmulGrad(&out_dy_help, out_d_y, - dy_reduce_dims, context); - } - out_d_y->Resize(y.dims()); - } - - // compute d_dout - if (out_d_dout) { - MatMulFunction(d_dx, &ddy_conj, x_dims, y_dims, - out_d_dout, transpose_x, transpose_y, - context); - MatMulFunction(&ddx_conj, d_dy, x_dims, y_dims, - out_d_dout, transpose_x, transpose_y, - context, true); - } - - // compute d_ddx - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - MatMulFunction(&y_conj, d_ddout, y_dims, dout_dims, - &out_d_ddx_help, true, true, - context); - // out_d_ddx2 = D_DY' * DOut' - MatMulFunction(d_dy, &dout_conj, y_dims, dout_dims, - &out_d_ddx_help, true, true, context, - true); - } else if (transpose_x) { - // out_d_ddx1 = y * d_ddout' - MatMulFunction(&y_conj, d_ddout, y_dims, dout_dims, - &out_d_ddx_help, false, true, - context); - // out_d_ddx2 = D_DY * Dout' - MatMulFunction(d_dy, &dout_conj, y_dims, dout_dims, - &out_d_ddx_help, false, true, - context, true); - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - MatMulFunction(d_ddout, &y_conj, dout_dims, y_dims, - &out_d_ddx_help, false, false, - context); - // out_d_ddx2 = Dout * D_DY - MatMulFunction(&dout_conj, d_dy, dout_dims, y_dims, - &out_d_ddx_help, false, false, - context, true); - } else { - // out_d_ddx1 = d_ddout * y' - MatMulFunction(d_ddout, &y_conj, dout_dims, y_dims, - &out_d_ddx_help, false, true, - context); - // out_d_ddx2 = Dout * D_DY' - MatMulFunction(&dout_conj, d_dy, dout_dims, y_dims, - &out_d_ddx_help, false, true, - context, true); - } - if (dx_reduce_dims.empty()) { - *out_d_ddx = std::move(out_d_ddx_help); - } else { - ReduceSumForMatmulGrad(&out_d_ddx_help, out_d_ddx, - dx_reduce_dims, context); - } - out_d_ddx->Resize(x.dims()); - } - - // compute d_ddy - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - MatMulFunction(d_ddout, &x_conj, dout_dims, x_dims, - &out_d_ddy_help, true, true, - context); - // out_d_ddy2 = dout' * d_dx' - MatMulFunction(&dout_conj, d_dx, dout_dims, x_dims, - &out_d_ddy_help, true, true, context, - true); - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - MatMulFunction(&x_conj, d_ddout, x_dims, dout_dims, - &out_d_ddy_help, false, false, - context); - // out_d_ddy2 = d_dx * dout - MatMulFunction(d_dx, &dout_conj, x_dims, dout_dims, - &out_d_ddy_help, false, false, - context, true); - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - MatMulFunction(d_ddout, &x_conj, dout_dims, x_dims, - &out_d_ddy_help, true, false, - context); - // out_d_ddy2 = dout' * d_dx - MatMulFunction(&dout_conj, d_dx, dout_dims, x_dims, - &out_d_ddy_help, true, false, - context, true); - } else { - // out_d_ddy1 = x' * d_ddout - MatMulFunction(&x_conj, d_ddout, x_dims, dout_dims, - &out_d_ddy_help, true, false, - context); - // out_d_ddy2 = d_dx' * dout - MatMulFunction(d_dx, &dout_conj, x_dims, dout_dims, - &out_d_ddy_help, true, false, - context, true); - } - - if (dy_reduce_dims.empty()) { - *out_d_ddy = std::move(out_d_ddy_help); - } else { - ReduceSumForMatmulGrad(&out_d_ddy_help, out_d_ddy, - dy_reduce_dims, context); - } - out_d_ddy->Resize(y.dims()); - } - } + if (out_d_x) out_d_x->mutable_data(context.GetPlace()); + if (out_d_y) out_d_y->mutable_data(context.GetPlace()); + if (out_d_dout) out_d_dout->mutable_data(context.GetPlace()); + if (out_d_ddx) out_d_ddx->mutable_data(context.GetPlace()); + if (out_d_ddy) out_d_ddy->mutable_data(context.GetPlace()); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx); + auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy); + auto pt_d_dx = paddle::experimental::MakePtenDenseTensor(*d_dx); + auto pt_d_dy = paddle::experimental::MakePtenDenseTensor(*d_dy); + auto pt_d_ddout = paddle::experimental::MakePtenDenseTensor(*d_ddout); + + auto pt_out_d_x = paddle::experimental::MakePtenDenseTensor(*out_d_x); + auto pt_out_d_y = paddle::experimental::MakePtenDenseTensor(*out_d_y); + auto pt_out_d_dout = paddle::experimental::MakePtenDenseTensor(*out_d_dout); + auto pt_out_d_ddx = paddle::experimental::MakePtenDenseTensor(*out_d_ddx); + auto pt_out_d_ddy = paddle::experimental::MakePtenDenseTensor(*out_d_ddy); + + auto& dev_ctx = context.device_context(); + // call new kernel + pten::MatmulTripleGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx, + *pt_ddy, *pt_d_dx, *pt_d_dy, *pt_d_ddout, + transpose_x, transpose_y, pt_out_d_x.get(), + pt_out_d_y.get(), pt_out_d_dout.get(), + pt_out_d_ddx.get(), pt_out_d_ddy.get()); } }; diff --git a/paddle/pten/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc index d8d83c575c4cf..58f56eeeef9a1 100644 --- a/paddle/pten/core/dense_tensor.cc +++ b/paddle/pten/core/dense_tensor.cc @@ -50,6 +50,12 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) { return *this; } +DenseTensor& DenseTensor::operator=(DenseTensor&& other) { + meta_ = std::move(other.meta_); + storage_.swap(other.storage_); + return *this; +} + int64_t DenseTensor::numel() const { if (meta_.is_scalar) { return 1; diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index eb149220f942d..d1d583f2e2dbe 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -97,6 +97,8 @@ class DenseTensor : public TensorBase, /// \brief DenseTensor shallow copy assignment. DenseTensor& operator=(const DenseTensor& other); + DenseTensor& operator=(DenseTensor&& other); + /// \brief Destroy the tensor object and release exclusive resources. virtual ~DenseTensor() = default; diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 56f7eea7ea802..46fa6dd376ee3 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -29,6 +29,9 @@ const std::unordered_map kernel_alias_name_map = { {"flatten_contiguous_range", "flatten"}, {"flatten_contiguous_range_grad", "flatten_grad"}, {"matmul_v2", "matmul"}, + {"matmul_v2_grad", "matmul_grad"}, + {"matmul_v2_grad_grad", "matmul_double_grad"}, + {"matmul_v2_triple_grad", "matmul_triple_grad"}, {"reduce_mean", "mean"}, {"reduce_sum", "sum"}, {"reshape2", "reshape"}, @@ -36,6 +39,8 @@ const std::unordered_map kernel_alias_name_map = { {"flatten", "deprecated"}, {"flatten_grad", "deprecated"}, {"matmul", "deprecated"}, + {"matmul_grad", "deprecated"}, + {"matmul_grad_grad", "deprecated"}, {"mean", "deprecated"}, {"reshape", "deprecated"}, {"sum", "deprecated"}}; diff --git a/paddle/pten/core/kernel_context.cc b/paddle/pten/core/kernel_context.cc index b2c84807951a5..74bd6d17f066a 100644 --- a/paddle/pten/core/kernel_context.cc +++ b/paddle/pten/core/kernel_context.cc @@ -50,6 +50,11 @@ void KernelContext::EmplaceBackOutputWithoutSetRange( outputs_.emplace_back(std::move(output)); } +void KernelContext::SetOutputWithoutSetRange( + int index, std::shared_ptr output) { + outputs_.at(index) = std::move(output); +} + void KernelContext::EmplaceBackOutputs( paddle::SmallVector> outputs) { int index = outputs_.size(); @@ -119,8 +124,10 @@ void KernelContext::ClearData() { } } for (auto& out : outputs_) { - CompatibleDenseTensorUtils::ClearStorage( - static_cast(out.get())); + if (out) { + CompatibleDenseTensorUtils::ClearStorage( + static_cast(out.get())); + } } attrs_.clear(); } diff --git a/paddle/pten/core/kernel_context.h b/paddle/pten/core/kernel_context.h index 6c695987096cb..b6cc15c084ac0 100644 --- a/paddle/pten/core/kernel_context.h +++ b/paddle/pten/core/kernel_context.h @@ -62,6 +62,8 @@ class KernelContext { void EmplaceBackOutputWithoutSetRange(std::shared_ptr output); + void SetOutputWithoutSetRange(int index, std::shared_ptr output); + void EmplaceBackOutputs( paddle::SmallVector> outputs); @@ -80,6 +82,14 @@ class KernelContext { return static_cast(*(inputs_.at(idx))); } + template + paddle::optional OptionalInputAt(size_t idx) const { + const auto& input = inputs_.at(idx); + return input ? paddle::optional{static_cast< + const TensorType&>(*input)} + : paddle::optional{paddle::none}; + } + std::shared_ptr& MutableInputPtrAt(size_t idx) { return inputs_.at(idx); } diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index bd4687c6e7f4e..f08ef4acfd9ce 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -65,6 +65,10 @@ struct KernelArgsParseFunctor { } else if (arg_type == std::type_index(typeid(const DenseTensor&))) { args_def->AppendInput( default_key.backend(), default_tensor_layout, default_key.dtype()); + } else if (arg_type == std::type_index(typeid( + paddle::optional))) { + args_def->AppendInput( + default_key.backend(), default_tensor_layout, default_key.dtype()); } else if (arg_type == std::type_index(typeid(const std::vector&))) { args_def->AppendInput( diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h index 5087d912ed525..60201151c62a2 100644 --- a/paddle/pten/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -77,6 +77,27 @@ namespace pten { } \ } +#define PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type) \ + template \ + struct KernelCallHelper, Tail...> { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + static_assert(attr_idx == 0, \ + "Kernel's Input should appear before Attributes."); \ + static_assert(out_idx == 0, \ + "Kernel's Input should appear before Outputs."); \ + const std::pair range = ctx->InputRangeAt(in_idx); \ + auto arg = ctx->OptionalInputAt(range.first); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + #define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type) \ template \ struct KernelCallHelper&, Tail...> { \ @@ -190,6 +211,7 @@ struct KernelImpl { /* Input Helpers */ PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); + PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); // TODO(chenweihang): adapt SelectedRows // PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor); diff --git a/paddle/pten/include/linalg.h b/paddle/pten/include/linalg.h index 22f287468e673..71bc518aa89f8 100644 --- a/paddle/pten/include/linalg.h +++ b/paddle/pten/include/linalg.h @@ -30,7 +30,7 @@ DenseTensor Dot(const ContextT& dev_ctx, pten::make_intrusive( dev_ctx.GetPlace()), std::move(out_meta)); - Dot(dev_ctx, x, y, &dense_out); + DotKernel(dev_ctx, x, y, &dense_out); return dense_out; } diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index faa4c8db8dac3..5070d0d4e0e5a 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -48,15 +48,4 @@ DenseTensor Scale(const ContextT& dev_ctx, return dense_out; } -template -DenseTensor Conj(const ContextT& dev_ctx, const DenseTensor& x) { - auto out_meta = UnchangedInferMeta(x.meta()); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - Conj(dev_ctx, x, &dense_out); - return dense_out; -} - } // namespace pten diff --git a/paddle/pten/kernels/complex_kernel.h b/paddle/pten/kernels/complex_kernel.h index dfe8fff43e6ef..e9f717152a458 100644 --- a/paddle/pten/kernels/complex_kernel.h +++ b/paddle/pten/kernels/complex_kernel.h @@ -16,9 +16,20 @@ limitations under the License. */ #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/empty_kernel.h" + namespace pten { template -void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); + +template +DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { + auto out_meta = UnchangedInferMeta(x.meta()); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + ConjKernel(dev_ctx, x, &dense_out); + return dense_out; +} } // namespace pten diff --git a/paddle/pten/kernels/cpu/complex_kernel.cc b/paddle/pten/kernels/cpu/complex_kernel.cc index 9bf27ef22dcd7..10e7e684db3c1 100644 --- a/paddle/pten/kernels/cpu/complex_kernel.cc +++ b/paddle/pten/kernels/cpu/complex_kernel.cc @@ -24,7 +24,7 @@ PT_REGISTER_CTX_KERNEL(conj, CPU, ALL_LAYOUT, - pten::Conj, + pten::ConjKernel, paddle::platform::complex, paddle::platform::complex, float, diff --git a/paddle/pten/kernels/cpu/dot_grad_kernel.cc b/paddle/pten/kernels/cpu/dot_grad_kernel.cc new file mode 100644 index 0000000000000..c9d5c35e134c8 --- /dev/null +++ b/paddle/pten/kernels/cpu/dot_grad_kernel.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(dot_grad, + CPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/dot_kernel.cc b/paddle/pten/kernels/cpu/dot_kernel.cc index 247ad1216a266..72e9e28907f90 100644 --- a/paddle/pten/kernels/cpu/dot_kernel.cc +++ b/paddle/pten/kernels/cpu/dot_kernel.cc @@ -23,10 +23,10 @@ namespace pten { template -void Dot(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { auto const *x_ptr = x.data(), *x_ptr_ = &x_ptr[0]; auto const *y_ptr = y.data(), *y_ptr_ = &y_ptr[0]; auto* z = out->mutable_data(); @@ -52,7 +52,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_CTX_KERNEL(dot, CPU, ALL_LAYOUT, - pten::Dot, + pten::DotKernel, float, double, int, diff --git a/paddle/pten/kernels/cpu/matmul_grad_kernel.cc b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc new file mode 100644 index 0000000000000..5a8abb6701b0e --- /dev/null +++ b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_grad_kernel.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul_grad, + CPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_double_grad, + CPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_triple_grad, + CPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/dot_grad_kernel.h b/paddle/pten/kernels/dot_grad_kernel.h new file mode 100644 index 0000000000000..b0940e5b16a33 --- /dev/null +++ b/paddle/pten/kernels/dot_grad_kernel.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void DotGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy); + +template +void DotDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout); + +template +void DotTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + const DenseTensor& dout, + const DenseTensor& d_ddout, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_ddx, + DenseTensor* d_ddy, + DenseTensor* d_dout); + +} // namespace pten diff --git a/paddle/pten/kernels/dot_kernel.h b/paddle/pten/kernels/dot_kernel.h index 9924749cd2141..5ef660265333e 100644 --- a/paddle/pten/kernels/dot_kernel.h +++ b/paddle/pten/kernels/dot_kernel.h @@ -19,9 +19,9 @@ namespace pten { template -void Dot(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out); +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/empty_kernel.cc b/paddle/pten/kernels/empty_kernel.cc index 94886806bccf3..2dd55a13e38e5 100644 --- a/paddle/pten/kernels/empty_kernel.cc +++ b/paddle/pten/kernels/empty_kernel.cc @@ -1,33 +1,34 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #include "paddle/pten/kernels/empty_kernel.h" #include "paddle/pten/backends/all_context.h" #include "paddle/pten/core/kernel_registry.h" +#include "paddle/fluid/platform/complex.h" + namespace pten { -template -void EmptyKernel(const ContextT& dev_ctx, +template +void EmptyKernel(const Context& dev_ctx, const ScalarArray& shape, DenseTensor* out) { out->Resize(paddle::framework::make_ddim(shape.GetData())); } -template -void EmptyLikeKernel(const ContextT& dev_ctx, DenseTensor* out) { +template +void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) { out->mutable_data(); } @@ -37,44 +38,62 @@ PT_REGISTER_CTX_KERNEL(empty, CPU, ALL_LAYOUT, pten::EmptyKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} PT_REGISTER_CTX_KERNEL(empty_like, CPU, ALL_LAYOUT, pten::EmptyLikeKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PT_REGISTER_CTX_KERNEL(empty, GPU, ALL_LAYOUT, pten::EmptyKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} PT_REGISTER_CTX_KERNEL(empty_like, GPU, ALL_LAYOUT, pten::EmptyLikeKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} #endif diff --git a/paddle/pten/kernels/empty_kernel.h b/paddle/pten/kernels/empty_kernel.h index d71ee0b1266f2..d283ef5c1e41e 100644 --- a/paddle/pten/kernels/empty_kernel.h +++ b/paddle/pten/kernels/empty_kernel.h @@ -41,6 +41,14 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) { return dense_out; } +template +DenseTensor Empty(const Context& dev_ctx) { + return Empty(dev_ctx, + {paddle::experimental::CppTypeToDataType::Type(), + {-1}, + DataLayout::NCHW}); +} + template DenseTensor Empty(const Context& dev_ctx, const ScalarArray& shape, diff --git a/paddle/pten/kernels/gpu/complex_kernel.cu b/paddle/pten/kernels/gpu/complex_kernel.cu index 5a3c14de4036a..02f050f5bc838 100644 --- a/paddle/pten/kernels/gpu/complex_kernel.cu +++ b/paddle/pten/kernels/gpu/complex_kernel.cu @@ -24,7 +24,8 @@ PT_REGISTER_CTX_KERNEL(conj, GPU, ALL_LAYOUT, - pten::Conj, + pten::ConjKernel, + paddle::platform::float16, paddle::platform::complex, paddle::platform::complex, float, diff --git a/paddle/pten/kernels/gpu/dot_grad_kernel.cu b/paddle/pten/kernels/gpu/dot_grad_kernel.cu new file mode 100644 index 0000000000000..42af96f7c7265 --- /dev/null +++ b/paddle/pten/kernels/gpu/dot_grad_kernel.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(dot_grad, + GPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/dot_kernel.cu b/paddle/pten/kernels/gpu/dot_kernel.cu index 6b66d45b7dd48..1f9e7aa3f1cfd 100644 --- a/paddle/pten/kernels/gpu/dot_kernel.cu +++ b/paddle/pten/kernels/gpu/dot_kernel.cu @@ -25,10 +25,10 @@ namespace pten { template -void Dot(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { out->mutable_data(); if (1 == out->dims().size()) { auto eigen_out = pten::EigenScalar::From(*out); @@ -55,7 +55,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_CTX_KERNEL(dot, GPU, ALL_LAYOUT, - pten::Dot, + pten::DotKernel, float, double, int, diff --git a/paddle/pten/kernels/gpu/matmul_grad_kernel.cu b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu new file mode 100644 index 0000000000000..f20c3f82c9262 --- /dev/null +++ b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu @@ -0,0 +1,50 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_grad_kernel.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul_grad, + GPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_double_grad, + GPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_triple_grad, + GPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/hybird/transpose.h b/paddle/pten/kernels/hybird/transpose.h index 459fed6b9fa04..17f52c74a1344 100644 --- a/paddle/pten/kernels/hybird/transpose.h +++ b/paddle/pten/kernels/hybird/transpose.h @@ -17,6 +17,9 @@ #include "paddle/fluid/framework/ddim.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" + namespace pten { namespace math { @@ -30,5 +33,30 @@ struct TransposeNormal { const std::vector& axis); }; +template +struct Transpose { + void operator()(const DeviceContext& dev_ctx, + const DenseTensor& in, + DenseTensor* out, + const std::vector& axis) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { + permute[i] = axis[i]; + } + auto eigen_in = pten::EigenTensor::From(in); + auto eigen_out = pten::EigenTensor::From(*out); + auto* dev = dev_ctx.eigen_device(); + // use 32bit index to speed up computation + bool use_32bit_index = eigen_out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); + if (use_32bit_index && is_gpu_place) { + To32BitIndex(eigen_out).device(*dev) = + To32BitIndex(eigen_in).shuffle(permute); + } else { + eigen_out.device(*dev) = eigen_in.shuffle(permute); + } + } +}; + } // namespace math } // namespace pten diff --git a/paddle/pten/kernels/impl/complex_kernel_impl.h b/paddle/pten/kernels/impl/complex_kernel_impl.h index 6f3a6049faa9a..e0c6825a78a53 100644 --- a/paddle/pten/kernels/impl/complex_kernel_impl.h +++ b/paddle/pten/kernels/impl/complex_kernel_impl.h @@ -21,12 +21,14 @@ namespace pten { template -void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { +void ConjKernel(const Context& context, + const DenseTensor& x, + DenseTensor* out) { auto numel = x.numel(); auto* x_data = x.data(); auto* out_data = out->mutable_data(); - paddle::platform::ForRange for_range(dev_ctx, numel); + paddle::platform::ForRange for_range(context, numel); paddle::operators::math::ConjFunctor functor(x_data, numel, out_data); for_range(functor); } diff --git a/paddle/pten/kernels/impl/dot_grad_kernel_impl.h b/paddle/pten/kernels/impl/dot_grad_kernel_impl.h new file mode 100644 index 0000000000000..16c87bbab474a --- /dev/null +++ b/paddle/pten/kernels/impl/dot_grad_kernel_impl.h @@ -0,0 +1,919 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" + +#include "paddle/pten/kernels/complex_kernel.h" + +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/operators/math/complex_functors.h" + +namespace pten { + +template +struct DotGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy); +}; + +template +struct DotGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto dout = EigenVector::Flatten(*tensor_dout); + + if (tensor_dx) { + auto y = EigenVector::Flatten(*tensor_y); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + + ConjKernel(ctx, *tensor_y, tensor_dx); + + auto dx = EigenVector::Flatten(*tensor_dx); + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = EigenVector::Flatten(*tensor_x); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + + ConjKernel(ctx, *tensor_x, tensor_dy); + + auto dy = EigenVector::Flatten(*tensor_dy); + dy.device(dev) = dy * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(); + auto y = EigenMatrix::From(*tensor_y); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + + ConjKernel(ctx, *tensor_y, tensor_dx); + + auto dx = EigenMatrix::From(*tensor_dx); + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto x = EigenMatrix::From(*tensor_x); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + + ConjKernel(ctx, *tensor_x, tensor_dy); + + auto dy = EigenMatrix::From(*tensor_dy); + dy.device(dev) = dy * dout.broadcast(size); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_y = tensor_y->data(); + const DDim& dim = tensor_x->dims(); + size_t N = static_cast(paddle::framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_x = tensor_x->data(); + const DDim& dim = tensor_y->dims(); + size_t N = static_cast(paddle::framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; + } + } +#endif + } +}; + +template +struct DotGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto dout = EigenVector::Flatten(*tensor_dout); + if (tensor_dx) { + auto y = EigenVector::Flatten(*tensor_y); + auto dx = EigenVector::Flatten(*tensor_dx); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = EigenVector::Flatten(*tensor_x); + auto dy = EigenVector::Flatten(*tensor_dy); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + dy.device(dev) = x * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(); + auto y = EigenMatrix::From(*tensor_y); + auto dx = EigenMatrix::From(*tensor_dx); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto x = EigenMatrix::From(*tensor_x); + auto dy = EigenMatrix::From(*tensor_dy); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + dy.device(dev) = x * dout.broadcast(size); + } + } +#else + auto const *x = tensor_x->data(), *y = tensor_y->data(), + *dz = tensor_dout->data(); + auto&& d = tensor_x->dims(); + auto const N = tensor_x->numel(); + auto const B = d[d.size() - 1]; + + if (tensor_dx) { + auto* dx = tensor_dx->mutable_data(); + for (auto j = 0; j < N / B; ++j) { + auto const ss = dz[j]; + for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss; + } + } + + if (tensor_dy) { + auto* dy = tensor_dy->mutable_data(); + for (auto j = 0; j < N / B; ++j) { + auto const ss = dz[j]; + for (auto i = 0; i < B; i++) *dy++ = *x++ * ss; + } + } +#endif + } +}; + +template +struct DotDoubleGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout); +}; + +template +struct DotDoubleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + DenseTensor tensor_dout_help; + auto& dev = *ctx.eigen_device(); + if (tensor_dx || tensor_dy) { + tensor_dout_help = Conj(ctx, *tensor_dout); + } + if (tensor_dx) { + auto ddy = EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = EigenVector::Flatten(*tensor_dx); + auto dout = EigenVector::Flatten(tensor_dout_help); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if (tensor_dy) { + auto ddx = EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + auto dy = EigenVector::Flatten(*tensor_dy); + auto dout = EigenVector::Flatten(tensor_dout_help); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + DenseTensor tensor_x_help = Conj(ctx, *tensor_x); + DenseTensor tensor_y_help = Conj(ctx, *tensor_y); + + auto x = EigenVector::Flatten(tensor_x_help); + auto y = EigenVector::Flatten(tensor_y_help); + auto ddx = EigenVector::Flatten(*tensor_ddx); + auto ddy = EigenVector::Flatten(*tensor_ddy); + auto ddout = EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_ddy = tensor_ddy->data(); + const DDim& dim = tensor_dx->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_ddx = tensor_ddx->data(); + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } else { + data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + +template +struct DotDoubleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto& dev = *ctx.eigen_device(); + auto dout = EigenVector::Flatten(*tensor_dout); + if (tensor_dx) { + tensor_dx->mutable_data(); + auto ddy = EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = EigenVector::Flatten(*tensor_dx); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto ddx = EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + + auto dy = EigenVector::Flatten(*tensor_dy); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + tensor_ddout->mutable_data(); + auto x = EigenVector::Flatten(*tensor_x); + auto y = EigenVector::Flatten(*tensor_y); + auto ddx = EigenVector::Flatten(*tensor_ddx); + auto ddy = EigenVector::Flatten(*tensor_ddy); + auto ddout = EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_ddy = tensor_ddy->data(); + const DDim& dim = tensor_dx->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = data_dout[s] * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_ddx = tensor_ddx->data(); + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_dout[s] * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } else { + data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + +template +struct DotTripleGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy); +}; + +// TODO(wuweilong): enable this function when the unittests framewark for multi +// grad is ok (dtype: complex64 or complex128). +template +struct DotTripleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + DenseTensor in_tensor_d_ddout_help; + auto& dev = *ctx.eigen_device(); + if (out_tensor_d_x || out_tensor_d_y) { + in_tensor_d_ddout_help = + Conj(ctx, *in_tensor_d_ddout); + } + if (out_tensor_d_x) { + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = EigenVector::Flatten(*out_tensor_d_x); + auto d_ddout = EigenVector::Flatten(in_tensor_d_ddout_help); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes size(in_tensor_ddx->numel()); + auto d_y = EigenVector::Flatten(*out_tensor_d_y); + auto d_ddout = EigenVector::Flatten(in_tensor_d_ddout_help); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + DenseTensor in_tensor_ddx_help = + Conj(ctx, *in_tensor_ddx); + DenseTensor in_tensor_ddy_help = + Conj(ctx, *in_tensor_ddy); + + auto ddx = EigenVector::Flatten(in_tensor_ddx_help); + auto ddy = EigenVector::Flatten(in_tensor_ddy_help); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + + if (out_tensor_d_ddx) { + DenseTensor in_tensor_dout_help = + Conj(ctx, *in_tensor_dout); + DenseTensor in_tensor_y_help = + Conj(ctx, *in_tensor_y); + + auto dout = EigenVector::Flatten(in_tensor_dout_help); + auto y = EigenVector::Flatten(in_tensor_y_help); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + + if (out_tensor_d_ddy) { + DenseTensor in_tensor_dout_help = + Conj(ctx, *in_tensor_dout); + DenseTensor in_tensor_x_help = + Conj(ctx, *in_tensor_x); + + auto dout = EigenVector::Flatten(in_tensor_dout_help); + auto x = EigenVector::Flatten(in_tensor_x_help); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(); + const auto* data_ddy = in_tensor_ddy->data(); + + const DDim& dim = out_tensor_d_x->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = out_tensor_d_y->mutable_data(); + const auto* data_ddx = in_tensor_ddx->data(); + + const DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const DDim& dim = out_tensor_d_dout->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } else { + data_d_dout[s] += + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] + + T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] + + T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s]; + } + } +#endif + } +}; + +template +struct DotTripleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + auto& dev = *ctx.eigen_device(); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + if (out_tensor_d_x) { + out_tensor_d_x->mutable_data(); + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = EigenVector::Flatten(*out_tensor_d_x); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + out_tensor_d_y->mutable_data(); + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes size(in_tensor_ddx->numel()); + + auto d_y = EigenVector::Flatten(*out_tensor_d_y); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + out_tensor_d_dout->mutable_data(); + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + + if (out_tensor_d_ddx) { + out_tensor_d_ddx->mutable_data(); + auto dout = EigenVector::Flatten(*in_tensor_dout); + auto y = EigenVector::Flatten(*in_tensor_y); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + + if (out_tensor_d_ddy) { + out_tensor_d_ddy->mutable_data(); + auto dout = EigenVector::Flatten(*in_tensor_dout); + auto x = EigenVector::Flatten(*in_tensor_x); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(); + const auto* data_ddy = in_tensor_ddy->data(); + + const DDim& dim = out_tensor_d_x->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = data_ddy[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = out_tensor_d_y->mutable_data(); + const auto* data_ddx = in_tensor_ddx->data(); + + const DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = data_ddx[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const DDim& dim = in_tensor_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } else { + data_d_dout[s] += + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s]; + } + } +#endif + } +}; + +template +void DotGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy) { + if (dx) { + dx->mutable_data(); + } + if (dy) { + dy->mutable_data(); + } + DotGradFunction()(dev_ctx, &x, &y, &dout, dx, dy); +} + +template +void DotDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + if (dx) { + dx->mutable_data(); + } + if (dy) { + dy->mutable_data(); + } + if (ddout) { + ddout->mutable_data(); + } + DotDoubleGradFunction()( + dev_ctx, &x, &y, &dout, ddx, ddy, dx, dy, ddout); +} + +template +void DotTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + const DenseTensor& dout, + const DenseTensor& d_ddout, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_ddx, + DenseTensor* d_ddy, + DenseTensor* d_dout) { + if (d_x) { + d_x->mutable_data(); + } + if (d_y) { + d_y->mutable_data(); + } + if (d_ddx) { + d_ddx->mutable_data(); + } + if (d_ddy) { + d_ddy->mutable_data(); + } + if (d_dout) { + d_dout->mutable_data(); + } + + DotTripleGradFunction()(dev_ctx, + &x, + &y, + ddx, + ddy, + d_dx, + d_dy, + dout, + d_ddout, + d_x, + d_y, + d_dout, + d_ddx, + d_ddy); +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h new file mode 100644 index 0000000000000..802cc019d78c5 --- /dev/null +++ b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h @@ -0,0 +1,1742 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// #include "paddle/pten/kernels/complex_kernel.h" +#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/empty_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" +#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" + +#include "paddle/pten/kernels/cpu/reduce.h" +#include "paddle/pten/kernels/funcs/reduce_functor.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/backends/gpu/gpu_context.h" + +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/pten/kernels/gpu/reduce.h" +#endif + +namespace pten { + +template +struct ReduceSumForMatmulGrad { + void operator()(const Context& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims); +}; + +template +struct ReduceSumForMatmulGrad { + void operator()(const CPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims) { + std::vector reduce_dims_tmp(reduce_dims.begin(), + reduce_dims.end()); + ReduceKernelImpl( + dev_ctx, input, output, reduce_dims_tmp, true, false); + } +}; + +#if defined(__NVCC__) || defined(__HIPCC__) +template +struct ReduceSumForMatmulGrad { + void operator()(const GPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims) { + auto stream = dev_ctx.stream(); + kernels:: + TensorReduceFunctorImpl>( + input, output, kps::IdentityFunctor(), reduce_dims, stream); + } +}; +#endif + +// Reshape a rank-3 tensor from P x M x N to (P * M) x N. +// Identity op if the tensor is not of rank 3. +static DenseTensor FoldInitDims(const DenseTensor& input) { + DenseTensor output = input; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + } + return output; +} + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. +template +static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx, + const DenseTensor& input) { + auto in_dims = input.dims(); + if (in_dims.size() != 3) { + return input; + } + DenseTensor output = EmptyLike(dev_ctx, input); + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(dev_ctx, input, &output, axis); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); + return output; +} + +template +void MatMul(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + const DenseTensor& b, + bool trans_b, + DenseTensor* out, + bool flag = false) { + out->mutable_data(); + auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto mat_dim_a = + paddle::operators::math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = + paddle::operators::math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a.data(), + mat_dim_a, + b.data(), + mat_dim_b, + static_cast(1), + out->mutable_data(), + static_cast(flag)); +} + +/** + * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the + * original x_dim is returned. + */ +static DDim RowMatrixFromVector(const DDim& x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return paddle::framework::make_ddim({1, x_dim[0]}); +} + +/** + * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the + * original y_dim is returned. + */ +static DDim ColumnMatrixFromVector(const DDim& y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return paddle::framework::make_ddim({y_dim[0], 1}); +} + +/** + * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. + * + * The shape would be [BatchSize, H, W] or [H, W]. + * If transposed, `H,W` will be swapped. + */ +static void ReshapeTensorIntoMatrixSequence( + DenseTensor* x, const paddle::operators::math::MatDescriptor& descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +static void ReshapeXYOutIntoMatrixSequence(DenseTensor* x, + DenseTensor* y, + DenseTensor* out, + bool trans_x, + bool trans_y) { + auto x_dim = RowMatrixFromVector(x->dims()); + auto y_dim = ColumnMatrixFromVector(y->dims()); + auto mat_dim_x = + paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = + paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, + mat_dim_y.width_}); + } + + ReshapeTensorIntoMatrixSequence(x, mat_dim_x); + ReshapeTensorIntoMatrixSequence(y, mat_dim_y); +} + +template +void CalcInputGrad(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + bool is_fold_init_dims_a, + const DenseTensor& b, + bool trans_b, + bool is_fold_init_dims_b, + DenseTensor* out, + bool flag = false) { + if (out == nullptr) return; + bool need_combine = + (a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2; + if (!need_combine) { + MatMul(dev_ctx, a, trans_a, b, trans_b, out, flag); + } else { + MatMul( + dev_ctx, + is_fold_init_dims_a ? FoldInitDims(a) + : FoldHeadAndLastDims(dev_ctx, a), + trans_a, + is_fold_init_dims_b ? FoldInitDims(b) + : FoldHeadAndLastDims(dev_ctx, b), + trans_b, + out, + flag); + } +} + +template +void MatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy) { + // get dims + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(out_grad.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + if (dx) dx->mutable_data(); + if (dy) dy->mutable_data(); + if (out_grad.numel() == 1) { + DotGradFunction()(dev_ctx, &x, &y, &out_grad, dx, dy); + return; + } + } + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + // for complex + DenseTensor x_conj; + DenseTensor y_conj; + + // Case2: no broadcast or no batch size, it aims to speed and it is same as + // matmul in old version. + if (!is_broadcast) { + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor out_grad_help = out_grad; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &out_grad_help, transpose_x, transpose_y); + + DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x_help.dims()) { + dx->Resize(x_help.dims()); + } + + y_conj = Conj(dev_ctx, y_help); + } + + DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y_help.dims()) { + dy->Resize(y_help.dims()); + } + + x_conj = Conj(dev_ctx, x_help); + } + + if (transpose_x && transpose_y) { + CalcInputGrad( + dev_ctx, y_conj, true, true, out_grad_help, true, false, dx); + CalcInputGrad( + dev_ctx, out_grad_help, true, true, x_conj, true, false, dy); + } else if (transpose_x) { + CalcInputGrad( + dev_ctx, y_conj, false, false, out_grad_help, true, false, dx); + CalcInputGrad( + dev_ctx, x_conj, false, false, out_grad_help, false, true, dy); + } else if (transpose_y) { + CalcInputGrad( + dev_ctx, out_grad_help, false, false, y_conj, false, true, dx); + CalcInputGrad( + dev_ctx, out_grad_help, true, true, x_conj, false, true, dy); + } else { + CalcInputGrad( + dev_ctx, out_grad_help, false, false, y_conj, true, false, dx); + CalcInputGrad( + dev_ctx, x_conj, true, true, out_grad_help, false, true, dy); + } + + if (dx) { + if (dx_dims != x_help.dims()) { + dx->Resize(dx_dims); + } + } + if (dy) { + if (dy_dims != y_help.dims()) { + dy->Resize(dy_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + + DenseTensor dx_help = Empty(dev_ctx); + DenseTensor dy_help = Empty(dev_ctx); + + if (transpose_x) { + if (transpose_y) { + // X'Y': dA = Y'G', dB = G'X' + if (dx) + MatMulFunction(dev_ctx, + y_conj, + out_grad, + y_dims, + dout_dims, + &dx_help, + true, + true); + if (dy) + MatMulFunction(dev_ctx, + out_grad, + x_conj, + dout_dims, + x_dims, + &dy_help, + true, + true); + } else { + // X'Y: dX = YG', dY = XG + if (dx) + MatMulFunction(dev_ctx, + y_conj, + out_grad, + y_dims, + dout_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + x_conj, + out_grad, + x_dims, + dout_dims, + &dy_help, + false, + false); + } + } else { + if (transpose_y) { + // XY': dX = GY, dY = G'X + if (dx) + MatMulFunction(dev_ctx, + out_grad, + y_conj, + dout_dims, + y_dims, + &dx_help, + false, + false); + if (dy) + MatMulFunction(dev_ctx, + out_grad, + x_conj, + dout_dims, + x_dims, + &dy_help, + true, + false); + } else { + // XY: dX = GY', dY = X'G + if (dx) + MatMulFunction(dev_ctx, + out_grad, + y_conj, + dout_dims, + y_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + x_conj, + out_grad, + x_dims, + dout_dims, + &dy_help, + true, + false); + } + } + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // reduce sum to get grad by ReduceSum + if (dx) { + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dx_help, dx, dx_reduce_dims); + } + dx->Resize(x.dims()); + } + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dy_help, dy, dy_reduce_dims); + } + dy->Resize(y.dims()); + } + // Get the OutputGrad(out) + } +} + +template +void MatmulDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + DotDoubleGradFunction()( + dev_ctx, &x, &y, &dout, ddx.get_ptr(), ddy.get_ptr(), dx, dy, ddout); + return; + } + + DenseTensor x_conj; + DenseTensor y_conj; + DenseTensor dout_conj; + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor dout_help = dout; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &dout_help, transpose_x, transpose_y); + DDim dx_dims; + + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x_help.dims()) { + dx->Resize(x_help.dims()); + } + } + + DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y_help.dims()) { + dy->Resize(y_help.dims()); + } + } + + DDim ddout_dims; + if (ddout) { + ddout_dims = ddout->dims(); + if (ddout_dims != dout_help.dims()) { + ddout->Resize(dout_help.dims()); + } + + x_conj = Conj(dev_ctx, x_help); + y_conj = Conj(dev_ctx, y_help); + } + + if (dx || dy) { + dout_conj = Conj(dev_ctx, dout_help); + } + + bool ddout_flag = false; + if (ddx) { + auto ddx_mat = ddx.get(); + if (ddx_mat.dims() != x_help.dims()) { + ddx_mat.Resize(x_help.dims()); + } + if (dy) { + if (transpose_x && transpose_y) { + // dy = dout' * ddx' + CalcInputGrad( + dev_ctx, dout_conj, true, true, ddx_mat, true, false, dy, false); + } else if (transpose_x) { + // dy = ddx * dout + CalcInputGrad(dev_ctx, + ddx_mat, + false, + false, + dout_conj, + false, + true, + dy, + false); + } else if (transpose_y) { + // dy = dout' * ddx + CalcInputGrad( + dev_ctx, dout_conj, true, true, ddx_mat, false, true, dy, false); + } else { + // dy = ddx' * dout + CalcInputGrad( + dev_ctx, ddx_mat, true, true, dout_conj, false, true, dy, false); + } + } + + if (ddout) { + CalcInputGrad(dev_ctx, + ddx_mat, + transpose_x, + true, + y_conj, + transpose_y, + false, + ddout, + ddout_flag); + ddout_flag = true; + } + } + + if (ddy) { + auto ddy_mat = ddy.get(); + if (ddy_mat.dims() != y_help.dims()) { + ddy_mat.Resize(y_help.dims()); + } + if (dx) { + if (transpose_x && transpose_y) { + // dx = ddy' * dout' + CalcInputGrad( + dev_ctx, ddy_mat, true, true, dout_conj, true, false, dx, false); + } else if (transpose_x) { + // dx = ddy * dout' + CalcInputGrad(dev_ctx, + ddy_mat, + false, + false, + dout_conj, + true, + false, + dx, + false); + } else if (transpose_y) { + // dx = dout * ddy + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + ddy_mat, + false, + true, + dx, + false); + } else { + // dx = dout * ddy' + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + ddy_mat, + true, + false, + dx, + false); + } + } + + if (ddout) { + CalcInputGrad(dev_ctx, + x_conj, + transpose_x, + true, + ddy_mat, + transpose_y, + false, + ddout, + ddout_flag); + } + } + + if (dx) { + if (dx_dims != x_help.dims()) { + dx->Resize(dx_dims); + } + } + + if (dy) { + if (dy_dims != y_help.dims()) { + dy->Resize(dy_dims); + } + } + + if (ddout) { + if (ddout_dims != dout_help.dims()) { + ddout->Resize(ddout_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + if (dx || dy) { + dout_conj = Conj(dev_ctx, dout); + } + if (ddout) { + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + } + + DenseTensor dx_help = Empty(dev_ctx); + DenseTensor dy_help = Empty(dev_ctx); + + if (transpose_x) { + if (transpose_y) { + if (dx) { + MatMulFunction(dev_ctx, + ddy.get(), + dout_conj, + y_dims, + dout_dims, + &dx_help, + true, + true); + } + if (dy) { + MatMulFunction(dev_ctx, + dout_conj, + ddx.get(), + dout_dims, + x_dims, + &dy_help, + true, + true); + } + } else { + if (dx) + MatMulFunction(dev_ctx, + ddy.get(), + dout_conj, + y_dims, + dout_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + ddx.get(), + dout_conj, + x_dims, + dout_dims, + &dy_help, + false, + false); + } + } else { + if (transpose_y) { + if (dx) { + MatMulFunction(dev_ctx, + dout_conj, + ddy.get(), + dout_dims, + y_dims, + &dx_help, + false, + false); + } + if (dy) { + MatMulFunction(dev_ctx, + dout_conj, + ddx.get(), + dout_dims, + x_dims, + &dy_help, + true, + false); + } + } else { + if (dx) { + MatMulFunction(dev_ctx, + dout_conj, + ddy.get(), + dout_dims, + y_dims, + &dx_help, + false, + true); + } + if (dy) { + MatMulFunction(dev_ctx, + ddx.get(), + dout_conj, + x_dims, + dout_dims, + &dy_help, + true, + false); + } + } + } + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (dx) { + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dx_help, dx, dx_reduce_dims); + } + dx->Resize(x.dims()); + } + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dy_help, dy, dy_reduce_dims); + } + dy->Resize(y.dims()); + } + + if (ddout) { + // Calculate the gradient of OutputGrad(Out) + MatMulFunction(dev_ctx, + ddx.get(), + y_conj, + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y); + MatMulFunction(dev_ctx, + x_conj, + ddy.get(), + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y, + true); + } + } +} + +template +void MatmulTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + const DenseTensor& ddx, + const DenseTensor& ddy, + paddle::optional d_dx, + paddle::optional d_dy, + paddle::optional d_ddout, + bool transpose_x, + bool transpose_y, + DenseTensor* out_d_x, + DenseTensor* out_d_y, + DenseTensor* out_d_dout, + DenseTensor* out_d_ddx, + DenseTensor* out_d_ddy) { + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's and y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; + DotTripleGradFunction()(dev_ctx, + &x, + &y, + &ddx, + &ddy, + d_dx.get_ptr(), + d_dy.get_ptr(), + &dout, + d_ddout.get_ptr(), + out_d_x, + out_d_y, + out_d_dout, + out_d_ddx, + out_d_ddy); + return; + } + + DenseTensor x_conj; + DenseTensor y_conj; + DenseTensor dout_conj; + DenseTensor ddx_conj; + DenseTensor ddy_conj; + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor dout_help = dout; + DenseTensor ddx_help = ddx; + DenseTensor ddy_help = ddy; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &dout_help, transpose_x, transpose_y); + + if (ddx_help.dims() != x_help.dims()) { + ddx_help.Resize(x_help.dims()); + } + + if (ddy_help.dims() != y_help.dims()) { + ddy_help.Resize(y_help.dims()); + } + + DDim out_dx_dims; + if (out_d_x) { + out_dx_dims = out_d_x->dims(); + if (out_dx_dims != x_help.dims()) { + out_d_x->Resize(x_help.dims()); + } + } + + DDim out_dy_dims; + if (out_d_y) { + out_dy_dims = out_d_y->dims(); + if (out_dy_dims != y_help.dims()) { + out_d_y->Resize(y_help.dims()); + } + } + + DDim out_d_dout_dims; + if (out_d_dout) { + out_d_dout_dims = out_d_dout->dims(); + if (out_d_dout_dims != dout_help.dims()) { + out_d_dout->Resize(dout_help.dims()); + } + + ddx_conj = Conj(dev_ctx, ddx_help); + ddy_conj = Conj(dev_ctx, ddy_help); + } + + DDim out_d_ddx_dims; + if (out_d_ddx) { + out_d_ddx_dims = out_d_ddx->dims(); + if (out_d_ddx_dims != x_help.dims()) { + out_d_ddx->Resize(x_help.dims()); + } + } + + DDim out_d_ddy_dims; + if (out_d_ddy) { + out_d_ddy_dims = out_d_ddy->dims(); + if (out_d_ddy_dims != y_help.dims()) { + out_d_ddy->Resize(y_help.dims()); + } + } + + if (out_d_ddx || out_d_ddy) { + x_conj = Conj(dev_ctx, x_help); + y_conj = Conj(dev_ctx, y_help); + dout_conj = Conj(dev_ctx, dout_help); + } + + bool d_dout_flag = false; + bool d_ddx_flag = false; + bool d_ddy_flag = false; + + if (d_ddout) { + auto d_ddout_mat = d_ddout.get(); + if (d_ddout_mat.dims() != dout_help.dims()) { + d_ddout_mat.Resize(dout_help.dims()); + } + + if (out_d_y) { + if (transpose_x && transpose_y) { + // out_d_y = d_ddout' * ddx' + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + ddx_conj, + true, + false, + out_d_y, + false); + } else if (transpose_x) { + // out_d_y = ddx * d_ddout + CalcInputGrad(dev_ctx, + ddx_conj, + false, + false, + d_ddout_mat, + false, + true, + out_d_y, + false); + } else if (transpose_y) { + // out_d_y = d_ddout' * ddx + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + ddx_conj, + false, + true, + out_d_y, + false); + } else { + // out_d_y = ddx' * d_ddout + CalcInputGrad(dev_ctx, + ddx_conj, + true, + true, + d_ddout_mat, + false, + true, + out_d_y, + false); + } + } + if (out_d_x) { + if (transpose_x && transpose_y) { + // out_d_x = ddy' * d_ddout' + CalcInputGrad(dev_ctx, + ddy_conj, + true, + true, + d_ddout_mat, + true, + false, + out_d_x, + false); + } else if (transpose_x) { + // out_d_x = ddy * d_ddout' + CalcInputGrad(dev_ctx, + ddy_conj, + false, + false, + d_ddout_mat, + true, + false, + out_d_x, + false); + } else if (transpose_y) { + // out_d_x = d_ddout * ddy + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + ddy_conj, + false, + true, + out_d_x, + false); + } else { + // out_d_x = d_ddout * ddy' + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + ddy_conj, + true, + false, + out_d_x, + false); + } + } + + // equations: + // d_ddx = DOut * D_DY + Y * D_DDOut + // Let: d_ddx1 = Y * D_DDOut + // Let: d_ddx2 = DOut * D_DY + + // d_ddy = DOut * D_DX + X * D_DDOut + // Let: d_ddy1 = X * D_DDOut + // Let: d_ddy2 = DOut * D_DX + + // d_dout = DDY * D_DX + DDX * D_DY + // Let: d_dout1 = DDX * D_DY + // Let: d_dout2 = DDY * D_DX + + // compute d_ddx1 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + CalcInputGrad(dev_ctx, + y_conj, + true, + true, + d_ddout_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx1 = y * d_ddout' + CalcInputGrad(dev_ctx, + y_conj, + false, + false, + d_ddout_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + y_conj, + false, + true, + out_d_ddx, + d_ddx_flag); + } else { + // out_d_ddx1 = d_ddout * y' + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + y_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } + d_ddx_flag = true; + } + + // compute d_ddy1 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + x_conj, + true, + false, + out_d_ddy, + false); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + CalcInputGrad(dev_ctx, + x_conj, + false, + false, + d_ddout_mat, + false, + true, + out_d_ddy, + false); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + x_conj, + false, + true, + out_d_ddy, + false); + } else { + // out_d_ddy1 = x' * d_ddout + CalcInputGrad(dev_ctx, + x_conj, + true, + true, + d_ddout_mat, + false, + true, + out_d_ddy, + false); + } + d_ddy_flag = true; + } + } + + if (d_dy) { + auto d_dy_mat = d_dy.get(); + if (d_dy_mat.dims() != y_help.dims()) { + d_dy_mat.Resize(y_help.dims()); + } + + // compute d_dout1 + if (out_d_dout) { + CalcInputGrad(dev_ctx, + ddx_conj, + transpose_x, + true, + d_dy_mat, + transpose_y, + false, + out_d_dout, + d_dout_flag); + d_dout_flag = true; + } + + // compute d_ddx2 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx2 = D_DY' * DOut' + CalcInputGrad(dev_ctx, + d_dy_mat, + true, + true, + dout_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx2 = D_DY * Dout' + CalcInputGrad(dev_ctx, + d_dy_mat, + false, + false, + dout_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx2 = Dout * D_DY + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + d_dy_mat, + false, + true, + out_d_ddx, + d_ddx_flag); + } else { + // out_d_ddx2 = Dout * D_DY' + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + d_dy_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } + } + } + + if (d_dx) { + auto d_dx_mat = d_dx.get(); + if (d_dx_mat.dims() != x_help.dims()) { + d_dx_mat.Resize(x_help.dims()); + } + + // compute d_dout2 + if (out_d_dout) { + CalcInputGrad(dev_ctx, + d_dx_mat, + transpose_x, + true, + ddy_conj, + transpose_y, + false, + out_d_dout, + d_dout_flag); + } + + // compute d_ddy2 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy2 = dout' * d_dx' + CalcInputGrad(dev_ctx, + dout_conj, + true, + true, + d_dx_mat, + true, + false, + out_d_ddy, + d_ddy_flag); + } else if (transpose_x) { + // out_d_ddy2 = d_dx * dout + CalcInputGrad(dev_ctx, + d_dx_mat, + false, + false, + dout_conj, + false, + true, + out_d_ddy, + d_ddy_flag); + } else if (transpose_y) { + // out_d_ddy2 = dout' * d_dx + CalcInputGrad(dev_ctx, + dout_conj, + true, + true, + d_dx_mat, + false, + true, + out_d_ddy, + d_ddy_flag); + } else { + // out_d_ddy2 = d_dx' * dout + CalcInputGrad(dev_ctx, + d_dx_mat, + true, + true, + dout_conj, + false, + true, + out_d_ddy, + d_ddy_flag); + } + } + } + + if (out_d_x) { + if (out_dx_dims != x_help.dims()) { + out_d_x->Resize(out_dx_dims); + } + } + + if (out_d_y) { + if (out_dy_dims != y_help.dims()) { + out_d_y->Resize(out_dy_dims); + } + } + + if (out_d_dout) { + if (out_d_dout_dims != dout_help.dims()) { + out_d_dout->Resize(out_d_dout_dims); + } + } + + if (out_d_ddx) { + if (out_d_ddx_dims != x_help.dims()) { + out_d_ddx->Resize(out_d_ddx_dims); + } + } + + if (out_d_ddy) { + if (out_d_ddy_dims != y_help.dims()) { + out_d_ddy->Resize(out_d_ddy_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + + DenseTensor out_dx_help = Empty(dev_ctx); + DenseTensor out_dy_help = Empty(dev_ctx); + DenseTensor out_d_ddx_help = Empty(dev_ctx); + DenseTensor out_d_ddy_help = Empty(dev_ctx); + + if (out_d_dout) { + ddx_conj = Conj(dev_ctx, ddx); + ddy_conj = Conj(dev_ctx, ddy); + } + if (out_d_ddx || out_d_ddy) { + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + dout_conj = Conj(dev_ctx, dout); + } + + if (transpose_x) { + if (transpose_y) { + // dX = ddY' d_ddout’, dY = d_ddout’ ddX' + if (out_d_x) + MatMulFunction(dev_ctx, + ddy_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_dx_help, + true, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddx_conj, + dout_dims, + x_dims, + &out_dy_help, + true, + true); + } else { + // dX = ddY d_ddout', dY = ddX d_ddout + if (out_d_x) + MatMulFunction(dev_ctx, + ddy_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_dx_help, + false, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + ddx_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_dy_help, + false, + false); + } + } else { + if (transpose_y) { + // dX = d_ddout ddY, dY = d_ddout’ ddX + if (out_d_x) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddy_conj, + dout_dims, + y_dims, + &out_dx_help, + false, + false); + if (out_d_y) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddx_conj, + dout_dims, + x_dims, + &out_dy_help, + true, + false); + } else { + // dX = d_ddout ddY', dY = ddX' d_ddout + if (out_d_x) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddy_conj, + dout_dims, + y_dims, + &out_dx_help, + false, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + ddx_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_dy_help, + true, + false); + } + } + + // get help dims + const std::vector dx_help_dims = + vectorize(out_dx_help.dims()); + const std::vector dy_help_dims = + vectorize(out_dx_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (out_d_x) { + if (dx_reduce_dims.empty()) { + *out_d_x = std::move(out_dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_dx_help, out_d_x, dx_reduce_dims); + } + out_d_x->Resize(x.dims()); + } + + if (out_d_y) { + if (dy_reduce_dims.empty()) { + *out_d_y = std::move(out_dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_dy_help, out_d_y, dy_reduce_dims); + } + out_d_y->Resize(y.dims()); + } + + // compute d_dout + if (out_d_dout) { + MatMulFunction(dev_ctx, + d_dx.get(), + ddy_conj, + x_dims, + y_dims, + out_d_dout, + transpose_x, + transpose_y); + MatMulFunction(dev_ctx, + ddx_conj, + d_dy.get(), + x_dims, + y_dims, + out_d_dout, + transpose_x, + transpose_y, + true); + } + // compute d_ddx + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + MatMulFunction(dev_ctx, + y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true); + // out_d_ddx2 = D_DY' * DOut' + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true, + true); + } else if (transpose_x) { + // out_d_ddx1 = y * d_ddout' + MatMulFunction(dev_ctx, + y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true); + // out_d_ddx2 = D_DY * Dout' + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true, + true); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false); + // out_d_ddx2 = Dout * D_DY + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false, + true); + } else { + // out_d_ddx1 = d_ddout * y' + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true); + // out_d_ddx2 = Dout * D_DY' + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true, + true); + } + if (dx_reduce_dims.empty()) { + *out_d_ddx = std::move(out_d_ddx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_d_ddx_help, out_d_ddx, dx_reduce_dims); + } + out_d_ddx->Resize(x.dims()); + } + + // compute d_ddy + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true); + // out_d_ddy2 = dout' * d_dx' + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true, + true); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false); + // out_d_ddy2 = d_dx * dout + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false, + true); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false); + // out_d_ddy2 = dout' * d_dx + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false, + true); + } else { + // out_d_ddy1 = x' * d_ddout + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false); + // out_d_ddy2 = d_dx' * dout + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false, + true); + } + + if (dy_reduce_dims.empty()) { + *out_d_ddy = std::move(out_d_ddy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_d_ddy_help, out_d_ddy, dy_reduce_dims); + } + out_d_ddy->Resize(y.dims()); + } + } +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/matmul_kernel_impl.h b/paddle/pten/kernels/impl/matmul_kernel_impl.h index e50b2f0641a46..f5f69f327a69f 100644 --- a/paddle/pten/kernels/impl/matmul_kernel_impl.h +++ b/paddle/pten/kernels/impl/matmul_kernel_impl.h @@ -86,7 +86,7 @@ static void IndexIncreaseFromDims(const int ndim, } template -void MatMulFunction(const Context& context, +void MatMulFunction(const Context& dev_ctx, const DenseTensor& X, const DenseTensor& Y, const std::vector& x_dims, @@ -102,7 +102,7 @@ void MatMulFunction(const Context& context, const T* x_data = X.data(); const T* y_data = Y.data(); - auto blas = paddle::operators::math::GetBlas(context); + auto blas = paddle::operators::math::GetBlas(dev_ctx); if (x_ndim == 1 && y_ndim == 1) { const int M = X.numel(); @@ -117,6 +117,8 @@ void MatMulFunction(const Context& context, M, N)); VLOG(3) << "MatMul's case 1"; + Out->Resize({1}); + Out->mutable_data(); blas.GEMM(CblasNoTrans, CblasTrans, 1, @@ -471,7 +473,7 @@ void MatMulFunction(const Context& context, } template -void MatMulFunction(const Context& context, +void MatMulFunction(const Context& dev_ctx, const DenseTensor& X, const DenseTensor& Y, DenseTensor* Out, @@ -481,11 +483,11 @@ void MatMulFunction(const Context& context, const std::vector x_dims = vectorize(X.dims()); const std::vector y_dims = vectorize(Y.dims()); MatMulFunction( - context, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); + dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); } template -void MatmulKernel(const Context& context, +void MatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, @@ -501,7 +503,7 @@ void MatmulKernel(const Context& context, paddle::platform::errors::InvalidArgument( "The Input(Y) dims size must not be equal 0," " but reviced dims size is 0. ")); - MatMulFunction(context, x, y, out, transpose_x, transpose_y); + MatMulFunction(dev_ctx, x, y, out, transpose_x, transpose_y); } } // namespace pten diff --git a/paddle/pten/kernels/matmul_grad_kernel.h b/paddle/pten/kernels/matmul_grad_kernel.h new file mode 100644 index 0000000000000..db485b79d2736 --- /dev/null +++ b/paddle/pten/kernels/matmul_grad_kernel.h @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace pten { + +template +void MatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy); + +template +void MatmulDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout); + +template +void MatmulTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + const DenseTensor& ddx, + const DenseTensor& ddy, + paddle::optional d_dx, + paddle::optional d_dy, + paddle::optional d_ddout, + bool transpose_x, + bool transpose_y, + DenseTensor* out_d_x, + DenseTensor* out_d_y, + DenseTensor* out_d_dout, + DenseTensor* out_d_ddx, + DenseTensor* out_d_ddy); + +} // namespace pten diff --git a/paddle/pten/kernels/matmul_kernel.h b/paddle/pten/kernels/matmul_kernel.h index fb54a5301e61c..f9cb2c3801caa 100644 --- a/paddle/pten/kernels/matmul_kernel.h +++ b/paddle/pten/kernels/matmul_kernel.h @@ -14,14 +14,15 @@ #pragma once -#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/infermeta/binary.h" +#include "paddle/pten/kernels/empty_kernel.h" + namespace pten { template -void MatmulKernel(const Context& context, +void MatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, @@ -29,17 +30,14 @@ void MatmulKernel(const Context& context, DenseTensor* out); template -DenseTensor Matmul(const Context& context, +DenseTensor Matmul(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, bool transpose_y) { auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y); - DenseTensor dense_out( - pten::make_intrusive( - context.GetPlace()), - std::move(out_meta)); - MatmulKernel(context, x, y, transpose_x, transpose_y, &dense_out); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + MatmulKernel(dev_ctx, x, y, transpose_x, transpose_y, &dense_out); return dense_out; }