diff --git a/paddle/fluid/operators/matrix_power_op.cc b/paddle/fluid/operators/matrix_power_op.cc
index c65af3129f3646..cdf204628b638f 100644
--- a/paddle/fluid/operators/matrix_power_op.cc
+++ b/paddle/fluid/operators/matrix_power_op.cc
@@ -12,7 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/matrix_power_op.h"
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor_util.h"
 
 namespace paddle {
 namespace operators {
@@ -119,13 +122,3 @@ REGISTER_OPERATOR(matrix_power, ops::MatrixPowerOp, ops::MatrixPowerOpMaker,
                   ops::MatrixPowerGradOpMaker<paddle::framework::OpDesc>,
                   ops::MatrixPowerGradOpMaker<paddle::imperative::OpBase>);
 
 REGISTER_OPERATOR(matrix_power_grad, ops::MatrixPowerGradOp);
-
-REGISTER_OP_CPU_KERNEL(
-    matrix_power,
-    ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MatrixPowerKernel<paddle::platform::CPUDeviceContext, double>);
-
-REGISTER_OP_CPU_KERNEL(
-    matrix_power_grad,
-    ops::MatrixPowerGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MatrixPowerGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/matrix_power_op.cu b/paddle/fluid/operators/matrix_power_op.cu
deleted file mode 100644
index d972e9499dc884..00000000000000
--- a/paddle/fluid/operators/matrix_power_op.cu
+++ /dev/null
@@ -1,27 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/matrix_power_op.h"
-
-namespace ops = paddle::operators;
-namespace plf = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(matrix_power,
-                        ops::MatrixPowerKernel<plf::CUDADeviceContext, float>,
-                        ops::MatrixPowerKernel<plf::CUDADeviceContext, double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    matrix_power_grad,
-    ops::MatrixPowerGradKernel<plf::CUDADeviceContext, float>,
-    ops::MatrixPowerGradKernel<plf::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/matrix_power_op.h b/paddle/fluid/operators/matrix_power_op.h
deleted file mode 100644
index 8eb9c58513df62..00000000000000
--- a/paddle/fluid/operators/matrix_power_op.h
+++ /dev/null
@@ -1,277 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <memory>
-#include <vector>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-#include "paddle/phi/kernels/funcs/matrix_inverse.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-struct IdentityMatrixFunctor {
-  IdentityMatrixFunctor(const int m, T* output) : m_(m), output_(output) {}
-
-  HOSTDEVICE void operator()(size_t index) const {
-    const int row = index / m_ % m_;
-    const int col = index % m_;
-    output_[index] = col == row ? static_cast<T>(1) : static_cast<T>(0);
-  }
-
-  const int m_;
-  T* output_;
-};
-
-template <typename DeviceContext, typename T>
-void MatrixPowerFunction(const Tensor* X, const int n, Tensor* Out,
-                         const paddle::framework::ExecutionContext& ctx) {
-  const auto& x_dims = X->dims();
-  const int x_ndim = x_dims.size();
-  T* out_data = Out->mutable_data<T>(ctx.GetPlace());
-
-  auto& dev_ctx = ctx.template device_context<DeviceContext>();
-  platform::ForRange<DeviceContext> for_range(dev_ctx, X->numel());
-
-  if (n == 0) {
-    // Out = Identity Matrix
-    IdentityMatrixFunctor<T> functor(x_dims[x_ndim - 1], out_data);
-    for_range(functor);
-    return;
-  }
-
-  auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
-
-  Tensor new_x = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
-  int new_n = n;
-  if (n > 0) {
-    // newX = X
-    framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x);
-  } else {
-    // newX = X^{-1}, n = -n
-    phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(dev_ctx, *X, &new_x);
-    new_n = -n;
-  }
-
-  if (new_n == 1) {
-    framework::TensorCopy(new_x, ctx.GetPlace(), dev_ctx, Out);
-    return;
-  }
-
-  auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false);
-
-  if (new_n == 2) {
-    // Out = newX * newX
-    Out->mutable_data<T>(ctx.GetPlace());
-    blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
-                Out, static_cast<T>(0));
-    return;
-  } else if (new_n == 3) {
-    // Out = (newX * newX) * newX
-    // Note: C[i] matrices in MatMul must not overlap, i.e. the individual
-    // gemm operations must be computable independently; otherwise,
-    // undefined behavior is expected.
-    Tensor temp = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
-    blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
-                &temp, static_cast<T>(0));
-    blas.MatMul(temp, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
-                Out, static_cast<T>(0));
-    return;
-  } else if (new_n == 4) {
-    // Out = (newX * newX) * (newX * newX)
-    Tensor temp = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
-    blas.MatMul(new_x, no_trans_desc, new_x, no_trans_desc, static_cast<T>(1),
-                &temp, static_cast<T>(0));
-    blas.MatMul(temp, no_trans_desc, temp, no_trans_desc, static_cast<T>(1),
-                Out, static_cast<T>(0));
-    return;
-  }
-
-  // Calculate Out = newX^{n} for abs(n) > 4 with time complexity O(log n)
-  int bit = 0;
-  Tensor z = Tensor(X->dtype());
-  bool out_inited = false;
-  Tensor temp_out = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
-  Tensor temp_z = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
-  while (new_n > 0) {
-    bit = new_n & 0x1;
-    new_n >>= 1;
-    if (z.IsInitialized()) {
-      blas.MatMul(z, no_trans_desc, z, no_trans_desc, static_cast<T>(1),
-                  &temp_z, static_cast<T>(0));
-      framework::TensorCopy(temp_z, ctx.GetPlace(), dev_ctx, &z);
-    } else {
-      z = ctx.AllocateTmpTensor<T, DeviceContext>(X->dims(), dev_ctx);
-      framework::TensorCopy(new_x, ctx.GetPlace(), dev_ctx, &z);
-    }
-    if (bit == 1) {
-      if (out_inited == true) {
-        blas.MatMul(*Out, no_trans_desc, z, no_trans_desc, static_cast<T>(1),
-                    &temp_out, static_cast<T>(0));
-        framework::TensorCopy(temp_out, ctx.GetPlace(), dev_ctx, Out);
-      } else {
-        framework::TensorCopy(z, ctx.GetPlace(), dev_ctx, Out);
-        out_inited = true;
-      }
-    }
-  }
-  return;
-}
-
-template <typename DeviceContext, typename T>
-class MatrixPowerKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    const Tensor* X = ctx.Input<Tensor>("X");
-    Tensor* Out = ctx.Output<Tensor>("Out");
-    int n = ctx.Attr<int>("n");
-
-    const auto& x_dims = X->dims();
-    const int x_ndim = x_dims.size();
-    PADDLE_ENFORCE_EQ(
-        x_dims[x_ndim - 2], x_dims[x_ndim - 1],
-        platform::errors::InvalidArgument(
-            "The inner-most 2 dimensions of Input(X) should be equal. "
- "X's shape[-2] = %d and shape[-1] = %d.", - x_dims[x_ndim - 2], x_dims[x_ndim - 1])); - - MatrixPowerFunction(X, n, Out, ctx); - } -}; - -template -void MatrixPowerGradFunction(const Tensor* X, const Tensor* Out, - const Tensor* dOut, const int n, Tensor* dX, - const paddle::framework::ExecutionContext& ctx) { - dX->mutable_data(ctx.GetPlace()); - const auto& x_dims = X->dims(); - - auto& dev_ctx = ctx.template device_context(); - auto blas = phi::funcs::GetBlas(dev_ctx); - - if (n == 0) { - // \nabla X = O - phi::funcs::SetConstant zero; - zero(dev_ctx, dX, static_cast(0)); - return; - } else if (n == 1) { - // \nabla X = \nabla Out - framework::TensorCopy(*dOut, ctx.GetPlace(), dev_ctx, dX); - return; - } - - auto trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, true); - auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false); - - if (n == -1) { - // \nabla X = Out^{T} * \nabla Out * Out^{T} - Tensor temp_dx = - ctx.AllocateTmpTensor(X->dims(), dev_ctx); - blas.MatMul(*Out, trans_desc, *dOut, no_trans_desc, static_cast(-1), - &temp_dx, static_cast(0)); - blas.MatMul(temp_dx, no_trans_desc, *Out, trans_desc, static_cast(1), dX, - static_cast(0)); - return; - } - - Tensor new_x = ctx.AllocateTmpTensor(X->dims(), dev_ctx); - int new_n = n; - if (n > 0) { - // newX = X - framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x); - } else { - // newX = X^{-1}, n = -n - phi::funcs::MatrixInverseFunctor mat_inv; - mat_inv(dev_ctx, *X, &new_x); - new_n = -n; - } - - // Use chain rule blow to compute \nabla newX^{n} - // First, Get newX^{0}, newX^{1}, ..., newX^{n - 1}, - // Note that newX^{0} can be omitted - std::vector> tensor_list(new_n - 1); - tensor_list[0] = std::make_shared(new_x); - int index = 1; - while (index < new_n - 1) { - tensor_list[index] = std::make_shared( - ctx.AllocateTmpTensor(X->dims(), dev_ctx)); - blas.MatMul(*tensor_list[index - 1], no_trans_desc, new_x, no_trans_desc, - static_cast(1), tensor_list[index].get(), static_cast(0)); - index++; - } - - // Second, \nabla newX = \sum_{i = 0}^{n - 1} (newX^{T}^{i} - // * \nabla Out - // * (newX^{T}^{n - i - 1}) - Tensor dx_new = ctx.AllocateTmpTensor(X->dims(), dev_ctx); - blas.MatMul(*tensor_list[new_n - 2], trans_desc, *dOut, no_trans_desc, - static_cast(1), &dx_new, static_cast(0)); - Tensor da_an_minus1 = - ctx.AllocateTmpTensor(X->dims(), dev_ctx); - blas.MatMul(*dOut, no_trans_desc, *tensor_list[new_n - 2], trans_desc, - static_cast(1), &da_an_minus1, static_cast(0)); - blas.AXPY(X->numel(), static_cast(1), da_an_minus1.data(), - dx_new.data()); - int start = 0; - while (start < new_n - 2) { - Tensor a_da = ctx.AllocateTmpTensor(X->dims(), dev_ctx); - Tensor a_da_a = ctx.AllocateTmpTensor(X->dims(), dev_ctx); - blas.MatMul(*tensor_list[start], trans_desc, *dOut, no_trans_desc, - static_cast(1), &a_da, static_cast(0)); - blas.MatMul(a_da, no_trans_desc, *tensor_list[new_n - 3 - start], - trans_desc, static_cast(1), &a_da_a, static_cast(0)); - blas.AXPY(X->numel(), static_cast(1), a_da_a.data(), - dx_new.data()); - start++; - } - - if (n > 0) { - // \nabla X = \nabla newX - framework::TensorCopy(dx_new, ctx.GetPlace(), dev_ctx, dX); - } else { - // \nabla X = newX^{T} * \nabla newX * newX^{T} - Tensor temp_dx = - ctx.AllocateTmpTensor(X->dims(), dev_ctx); - blas.MatMul(new_x, trans_desc, dx_new, no_trans_desc, static_cast(-1), - &temp_dx, static_cast(0)); - blas.MatMul(temp_dx, no_trans_desc, new_x, trans_desc, static_cast(1), - dX, static_cast(0)); - } - return; -} - -template -class 
-class MatrixPowerGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const Tensor* X = ctx.Input<Tensor>("X");
-    const Tensor* Out = ctx.Input<Tensor>("Out");
-    const Tensor* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    const int n = ctx.Attr<int>("n");
-    Tensor* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-
-    MatrixPowerGradFunction<DeviceContext, T>(X, Out, dOut, n, dX, ctx);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
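Both the deleted header above and its phi replacement below compute Out = newX^n for |n| > 4 by square-and-multiply. A minimal standalone sketch of that scheme, with a naive row-major matmul standing in for the BLAS calls (illustrative names, not Paddle API):

```cpp
#include <cassert>
#include <vector>

using Mat = std::vector<double>;  // row-major m x m matrix

// Naive O(m^3) matmul standing in for blas.MatMul in the kernel.
Mat MatMulNaive(const Mat& a, const Mat& b, int m) {
  Mat c(m * m, 0.0);
  for (int i = 0; i < m; ++i)
    for (int k = 0; k < m; ++k)
      for (int j = 0; j < m; ++j) c[i * m + j] += a[i * m + k] * b[k * m + j];
  return c;
}

// Square-and-multiply, mirroring the kernel's |n| > 4 loop: z walks through
// x^(2^k), and out multiplies in each z whose bit is set in n.
Mat MatrixPowerSketch(const Mat& x, int n, int m) {
  assert(n > 4);  // smaller exponents take the specialized branches above
  Mat z, out;
  bool z_inited = false, out_inited = false;
  while (n > 0) {
    const int bit = n & 0x1;
    n >>= 1;
    if (z_inited) {
      z = MatMulNaive(z, z, m);  // z <- z * z, the next power-of-two factor
    } else {
      z = x;  // first pass: z = x^1
      z_inited = true;
    }
    if (bit == 1) {
      if (out_inited) {
        out = MatMulNaive(out, z, m);  // fold this factor into the result
      } else {
        out = z;
        out_inited = true;
      }
    }
  }
  return out;
}
```

For n = 5 (binary 101), the loop visits z = x, x^2, x^4 and multiplies x and x^4 into the accumulator: three matmuls instead of the four a naive product would need.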
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index 9b4b14bf51ed96..093cb6549797d1 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
 # Some kernels depend on some targets that are not commonly used.
 # These targets are not suitable for common dependencies.
 # In this case, you need to manually generate them here.
-set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel)
+set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel matrix_power_kernel matrix_power_grad_kernel)
 kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
 kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
 kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
@@ -38,6 +38,8 @@ kernel_library(put_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_k
 kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
+kernel_library(matrix_power_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
+kernel_library(matrix_power_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
 kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
 kernel_library(segment_pool_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
 kernel_library(segment_pool_grad_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
diff --git a/paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc b/paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc
new file mode 100644
index 00000000000000..ae3b4d2b45582b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/matrix_power_grad_kernel.cc
@@ -0,0 +1,26 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/matrix_power_grad_kernel.h"
+#include "paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(matrix_power_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MatrixPowerGradKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/cpu/matrix_power_kernel.cc b/paddle/phi/kernels/cpu/matrix_power_kernel.cc
new file mode 100644
index 00000000000000..f40e1e616f5262
--- /dev/null
+++ b/paddle/phi/kernels/cpu/matrix_power_kernel.cc
@@ -0,0 +1,22 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/matrix_power_kernel.h"
+#include "paddle/phi/kernels/impl/matrix_power_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(
+    matrix_power, CPU, ALL_LAYOUT, phi::MatrixPowerKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu b/paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu
new file mode 100644
index 00000000000000..25a9de8f8bed42
--- /dev/null
+++ b/paddle/phi/kernels/gpu/matrix_power_grad_kernel.cu
@@ -0,0 +1,26 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/matrix_power_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(matrix_power_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MatrixPowerGradKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/matrix_power_kernel.cu b/paddle/phi/kernels/gpu/matrix_power_kernel.cu
new file mode 100644
index 00000000000000..d7ae7d8a3f745c
--- /dev/null
+++ b/paddle/phi/kernels/gpu/matrix_power_kernel.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/matrix_power_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/matrix_power_kernel_impl.h"
+
+PD_REGISTER_KERNEL(
+    matrix_power, GPU, ALL_LAYOUT, phi::MatrixPowerKernel, float, double) {}
diff --git a/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
new file mode 100644
index 00000000000000..e797b27071caca
--- /dev/null
+++ b/paddle/phi/kernels/impl/matrix_power_grad_kernel_impl.h
@@ -0,0 +1,200 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/matrix_inverse.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatrixPowerGradFunction(const DenseTensor* X,
+                             const DenseTensor* Out,
+                             const DenseTensor* dOut,
+                             const int n,
+                             DenseTensor* dX,
+                             const Context& ctx) {
+  ctx.template Alloc<T>(dX);
+  const auto& x_dims = X->dims();
+
+  auto blas = phi::funcs::GetBlas<Context, T>(ctx);
+
+  if (n == 0) {
+    // \nabla X = O
+    phi::funcs::SetConstant<Context, T> zero;
+    zero(ctx, dX, static_cast<T>(0));
+    return;
+  } else if (n == 1) {
+    // \nabla X = \nabla Out
+    paddle::framework::TensorCopy(*dOut, ctx.GetPlace(), ctx, dX);
+    return;
+  }
+
+  auto trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, true);
+  auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false);
+
+  if (n == -1) {
+    // \nabla X = -Out^{T} * \nabla Out * Out^{T}
+    DenseTensor temp_dx;
+    temp_dx.Resize(X->dims());
+    ctx.template Alloc<T>(&temp_dx);
+    blas.MatMul(*Out,
+                trans_desc,
+                *dOut,
+                no_trans_desc,
+                static_cast<T>(-1),
+                &temp_dx,
+                static_cast<T>(0));
+    blas.MatMul(temp_dx,
+                no_trans_desc,
+                *Out,
+                trans_desc,
+                static_cast<T>(1),
+                dX,
+                static_cast<T>(0));
+    return;
+  }
+
+  DenseTensor new_x;
+  new_x.Resize(X->dims());
+  ctx.template Alloc<T>(&new_x);
+  int new_n = n;
+  if (n > 0) {
+    // newX = X
+    paddle::framework::TensorCopy(*X, ctx.GetPlace(), ctx, &new_x);
+  } else {
+    // newX = X^{-1}, n = -n
+    phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
+    mat_inv(ctx, *X, &new_x);
+    new_n = -n;
+  }
+
+  // Use the chain rule below to compute \nabla newX^{n}.
+  // First, get newX^{0}, newX^{1}, ..., newX^{n - 1};
+  // note that newX^{0} can be omitted.
+  std::vector<std::shared_ptr<DenseTensor>> tensor_list(new_n - 1);
+  tensor_list[0] = std::make_shared<DenseTensor>(new_x);
+  int index = 1;
+  while (index < new_n - 1) {
+    DenseTensor tensor_list_index;
+    tensor_list_index.Resize(X->dims());
+    ctx.template Alloc<T>(&tensor_list_index);
+    tensor_list[index] = std::make_shared<DenseTensor>(tensor_list_index);
+
+    blas.MatMul(*tensor_list[index - 1],
+                no_trans_desc,
+                new_x,
+                no_trans_desc,
+                static_cast<T>(1),
+                tensor_list[index].get(),
+                static_cast<T>(0));
+    index++;
+  }
+
+  // Second, \nabla newX = \sum_{i = 0}^{n - 1} (newX^{T})^{i}
+  //                           * \nabla Out
+  //                           * (newX^{T})^{n - i - 1}
+  DenseTensor dx_new;
+  dx_new.Resize(X->dims());
+  ctx.template Alloc<T>(&dx_new);
+  blas.MatMul(*tensor_list[new_n - 2],
+              trans_desc,
+              *dOut,
+              no_trans_desc,
+              static_cast<T>(1),
+              &dx_new,
+              static_cast<T>(0));
+  DenseTensor da_an_minus1;
+  da_an_minus1.Resize(X->dims());
+  ctx.template Alloc<T>(&da_an_minus1);
+  blas.MatMul(*dOut,
+              no_trans_desc,
+              *tensor_list[new_n - 2],
+              trans_desc,
+              static_cast<T>(1),
+              &da_an_minus1,
+              static_cast<T>(0));
+  blas.AXPY(
+      X->numel(), static_cast<T>(1), da_an_minus1.data<T>(), dx_new.data<T>());
+  int start = 0;
+  while (start < new_n - 2) {
+    DenseTensor a_da;
+    a_da.Resize(X->dims());
+    ctx.template Alloc<T>(&a_da);
+    DenseTensor a_da_a;
+    a_da_a.Resize(X->dims());
+    ctx.template Alloc<T>(&a_da_a);
+    blas.MatMul(*tensor_list[start],
+                trans_desc,
+                *dOut,
+                no_trans_desc,
+                static_cast<T>(1),
+                &a_da,
+                static_cast<T>(0));
+    blas.MatMul(a_da,
+                no_trans_desc,
+                *tensor_list[new_n - 3 - start],
+                trans_desc,
+                static_cast<T>(1),
+                &a_da_a,
+                static_cast<T>(0));
+    blas.AXPY(
+        X->numel(), static_cast<T>(1), a_da_a.data<T>(), dx_new.data<T>());
+    start++;
+  }
+
+  if (n > 0) {
+    // \nabla X = \nabla newX
+    paddle::framework::TensorCopy(dx_new, ctx.GetPlace(), ctx, dX);
+  } else {
+    // \nabla X = -newX^{T} * \nabla newX * newX^{T}
+    DenseTensor temp_dx;
+    temp_dx.Resize(X->dims());
+    ctx.template Alloc<T>(&temp_dx);
+    blas.MatMul(new_x,
+                trans_desc,
+                dx_new,
+                no_trans_desc,
+                static_cast<T>(-1),
+                &temp_dx,
+                static_cast<T>(0));
+    blas.MatMul(temp_dx,
+                no_trans_desc,
+                new_x,
+                trans_desc,
+                static_cast<T>(1),
+                dX,
+                static_cast<T>(0));
+  }
+  return;
+}
+
+template <typename T, typename Context>
+void MatrixPowerGradKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const DenseTensor& out,
+                           const DenseTensor& out_grad,
+                           int n,
+                           DenseTensor* x_grad) {
+  auto X = &x;
+  auto Out = &out;
+  auto dOut = &out_grad;
+  auto dX = x_grad;
+
+  MatrixPowerGradFunction<T, Context>(X, Out, dOut, n, dX, ctx);
+}
+
+}  // namespace phi
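The chain-rule comment above compresses the identity that MatrixPowerGradFunction actually computes. With A standing for newX (A = X when n > 0, and A = X^{-1} with n := -n otherwise), the backward pass is

    \nabla A = \sum_{i=0}^{n-1} \left(A^{T}\right)^{i} \,\nabla \mathrm{Out}\, \left(A^{T}\right)^{n-1-i}

The two MatMul/AXPY pairs ahead of the `start` loop contribute the i = n-1 and i = 0 terms; the loop adds the remaining n-2 middle terms from the powers cached in tensor_list. For negative exponents the result is then mapped back through \nabla X = -X^{-T} \,\nabla A\, X^{-T}, which is exactly the final MatMul pair with alpha = -1.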
diff --git a/paddle/phi/kernels/impl/matrix_power_kernel_impl.h b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h
new file mode 100644
index 00000000000000..ccc5e8757e8766
--- /dev/null
+++ b/paddle/phi/kernels/impl/matrix_power_kernel_impl.h
@@ -0,0 +1,203 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/matrix_inverse.h"
+
+namespace phi {
+
+template <typename T>
+struct IdentityMatrixFunctor {
+  IdentityMatrixFunctor(const int m, T* output) : m_(m), output_(output) {}
+
+  HOSTDEVICE void operator()(size_t index) const {
+    const int row = index / m_ % m_;
+    const int col = index % m_;
+    output_[index] = col == row ? static_cast<T>(1) : static_cast<T>(0);
+  }
+
+  const int m_;
+  T* output_;
+};
+
+template <typename T, typename Context>
+void MatrixPowerFunction(const DenseTensor* X,
+                         const int n,
+                         DenseTensor* Out,
+                         const Context& ctx) {
+  const auto& x_dims = X->dims();
+  const int x_ndim = x_dims.size();
+  T* out_data = ctx.template Alloc<T>(Out);
+
+  phi::funcs::ForRange<Context> for_range(ctx, X->numel());
+
+  if (n == 0) {
+    // Out = Identity Matrix
+    IdentityMatrixFunctor<T> functor(x_dims[x_ndim - 1], out_data);
+    for_range(functor);
+    return;
+  }
+
+  auto blas = phi::funcs::GetBlas<Context, T>(ctx);
+
+  DenseTensor new_x;
+  new_x.Resize(X->dims());
+  ctx.template Alloc<T>(&new_x);
+  int new_n = n;
+  if (n > 0) {
+    // newX = X
+    paddle::framework::TensorCopy(*X, ctx.GetPlace(), ctx, &new_x);
+  } else {
+    // newX = X^{-1}, n = -n
+    phi::funcs::MatrixInverseFunctor<Context, T> mat_inv;
+    mat_inv(ctx, *X, &new_x);
+    new_n = -n;
+  }
+
+  if (new_n == 1) {
+    paddle::framework::TensorCopy(new_x, ctx.GetPlace(), ctx, Out);
+    return;
+  }
+
+  auto no_trans_desc = phi::funcs::CreateMatrixDescriptor(x_dims, 0, false);
+
+  if (new_n == 2) {
+    // Out = newX * newX
+    ctx.template Alloc<T>(Out);
+    blas.MatMul(new_x,
+                no_trans_desc,
+                new_x,
+                no_trans_desc,
+                static_cast<T>(1),
+                Out,
+                static_cast<T>(0));
+    return;
+  } else if (new_n == 3) {
+    // Out = (newX * newX) * newX
+    // Note: C[i] matrices in MatMul must not overlap, i.e. the individual
+    // gemm operations must be computable independently; otherwise,
+    // undefined behavior is expected.
+    DenseTensor temp;
+    temp.Resize(X->dims());
+    ctx.template Alloc<T>(&temp);
+    blas.MatMul(new_x,
+                no_trans_desc,
+                new_x,
+                no_trans_desc,
+                static_cast<T>(1),
+                &temp,
+                static_cast<T>(0));
+    blas.MatMul(temp,
+                no_trans_desc,
+                new_x,
+                no_trans_desc,
+                static_cast<T>(1),
+                Out,
+                static_cast<T>(0));
+    return;
+  } else if (new_n == 4) {
+    // Out = (newX * newX) * (newX * newX)
+    DenseTensor temp;
+    temp.Resize(X->dims());
+    ctx.template Alloc<T>(&temp);
+    blas.MatMul(new_x,
+                no_trans_desc,
+                new_x,
+                no_trans_desc,
+                static_cast<T>(1),
+                &temp,
+                static_cast<T>(0));
+    blas.MatMul(temp,
+                no_trans_desc,
+                temp,
+                no_trans_desc,
+                static_cast<T>(1),
+                Out,
+                static_cast<T>(0));
+    return;
+  }
+
+  // Calculate Out = newX^{n} for abs(n) > 4 with time complexity O(log n)
+  int bit = 0;
+  DenseTensor z = DenseTensor(X->dtype());
+  bool out_inited = false;
+  DenseTensor temp_out;
+  temp_out.Resize(X->dims());
+  ctx.template Alloc<T>(&temp_out);
+  DenseTensor temp_z;
+  temp_z.Resize(X->dims());
+  ctx.template Alloc<T>(&temp_z);
+  while (new_n > 0) {
+    bit = new_n & 0x1;
+    new_n >>= 1;
+    if (z.IsInitialized()) {
+      blas.MatMul(z,
+                  no_trans_desc,
+                  z,
+                  no_trans_desc,
+                  static_cast<T>(1),
+                  &temp_z,
+                  static_cast<T>(0));
+      paddle::framework::TensorCopy(temp_z, ctx.GetPlace(), ctx, &z);
+    } else {
+      z.Resize(X->dims());
+      ctx.template Alloc<T>(&z);
+      paddle::framework::TensorCopy(new_x, ctx.GetPlace(), ctx, &z);
+    }
+    if (bit == 1) {
+      if (out_inited == true) {
+        blas.MatMul(*Out,
+                    no_trans_desc,
+                    z,
+                    no_trans_desc,
+                    static_cast<T>(1),
+                    &temp_out,
+                    static_cast<T>(0));
+        paddle::framework::TensorCopy(temp_out, ctx.GetPlace(), ctx, Out);
+      } else {
+        paddle::framework::TensorCopy(z, ctx.GetPlace(), ctx, Out);
+        out_inited = true;
+      }
+    }
+  }
+  return;
+}
+
+template <typename T, typename Context>
+void MatrixPowerKernel(const Context& ctx,
+                       const DenseTensor& x,
+                       int n,
+                       DenseTensor* out) {
+  const DenseTensor* X = &x;
+  auto Out = out;
+
+  const auto& x_dims = X->dims();
+  const int x_ndim = x_dims.size();
+  PADDLE_ENFORCE_EQ(
+      x_dims[x_ndim - 2],
+      x_dims[x_ndim - 1],
+      errors::InvalidArgument(
+          "The inner-most 2 dimensions of Input(X) should be equal. "
+          "X's shape[-2] = %d and shape[-1] = %d.",
+          x_dims[x_ndim - 2],
+          x_dims[x_ndim - 1]));
+
+  MatrixPowerFunction<T, Context>(X, n, Out, ctx);
+}
+
+}  // namespace phi
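IdentityMatrixFunctor above iterates over the flat output buffer: `index / m % m` is the row and `index % m` the column inside the current m x m slice, so for batched input every slice independently becomes an identity matrix. A tiny hypothetical driver (not part of the patch) that replays the index math on the CPU:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int b = 2, m = 3;  // two 3 x 3 matrices in one flat buffer
  std::vector<float> out(b * m * m);
  for (size_t index = 0; index < out.size(); ++index) {
    const int row = index / m % m;  // row within the current m x m slice
    const int col = index % m;      // column within the current slice
    out[index] = (col == row) ? 1.0f : 0.0f;
  }
  // The second slice is also an identity matrix:
  for (int r = 0; r < m; ++r) {
    for (int c = 0; c < m; ++c) printf("%.0f ", out[m * m + r * m + c]);
    printf("\n");
  }
  return 0;
}
```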
diff --git a/paddle/phi/kernels/matrix_power_grad_kernel.h b/paddle/phi/kernels/matrix_power_grad_kernel.h
new file mode 100644
index 00000000000000..4f70cf6e34d491
--- /dev/null
+++ b/paddle/phi/kernels/matrix_power_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatrixPowerGradKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const DenseTensor& out,
+                           const DenseTensor& out_grad,
+                           int n,
+                           DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/matrix_power_kernel.h b/paddle/phi/kernels/matrix_power_kernel.h
new file mode 100644
index 00000000000000..39a1bc85e3fe77
--- /dev/null
+++ b/paddle/phi/kernels/matrix_power_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatrixPowerKernel(const Context& ctx,
+                       const DenseTensor& x,
+                       int n,
+                       DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/matrix_power_sig.cc b/paddle/phi/ops/compat/matrix_power_sig.cc
new file mode 100644
index 00000000000000..4c9ad4e74ab460
--- /dev/null
+++ b/paddle/phi/ops/compat/matrix_power_sig.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature MatrixPowerGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("matrix_power_grad", + {"X", "Out", GradVarName("Out")}, + {"n"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(matrix_power_grad, + phi::MatrixPowerGradOpArgumentMapping);