No.1 complex kernel support for addmm #56480

Status: Closed · wants to merge 3 commits
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/addmm_grad_kernel.cc
@@ -18,5 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"

-PD_REGISTER_KERNEL(
-    addmm_grad, CPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
+PD_REGISTER_KERNEL(addmm_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::AddmmGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
9 changes: 8 additions & 1 deletion paddle/phi/kernels/cpu/addmm_kernel.cc
@@ -18,4 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"

-PD_REGISTER_KERNEL(addmm, CPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
+PD_REGISTER_KERNEL(addmm,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::AddmmKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/addmm_grad_kernel.cu
@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(addmm_grad,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/addmm_kernel.cu
@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(addmm,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
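
Taken together, the four registrations above add complex64/complex128 dispatch for addmm and its gradient on both CPU and GPU. A minimal usage sketch (assuming a Paddle build that contains these kernels; shapes and seed values are illustrative only):

    import numpy as np
    import paddle

    # Build complex64 operands: input broadcasts against the (2, 4) matmul result.
    rng = np.random.default_rng(0)
    input_np = (rng.random((2, 4)) + 1j * rng.random((2, 4))).astype(np.complex64)
    x_np = (rng.random((2, 3)) + 1j * rng.random((2, 3))).astype(np.complex64)
    y_np = (rng.random((3, 4)) + 1j * rng.random((3, 4))).astype(np.complex64)

    # out = beta * input + alpha * (x @ y), now also for complex dtypes.
    out = paddle.addmm(
        paddle.to_tensor(input_np),
        paddle.to_tensor(x_np),
        paddle.to_tensor(y_np),
        beta=1.0,
        alpha=1.0,
    )

    np.testing.assert_allclose(out.numpy(), input_np + x_np @ y_np, rtol=1e-5)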
163 changes: 152 additions & 11 deletions paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
@@ -55,17 +55,144 @@ using PhiEigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;

-template <typename T, typename Context>
-void AddmmGradKernel(const Context& dev_ctx,
-                     const DenseTensor& input,
-                     const DenseTensor& x,
-                     const DenseTensor& y,
-                     const DenseTensor& out_grad,
-                     float alpha,
-                     float beta,
-                     DenseTensor* input_grad,
-                     DenseTensor* x_grad,
-                     DenseTensor* y_grad) {
+template <
+    typename T,
+    typename Context,
+    std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
+                         std::is_same<T, phi::dtype::complex<double>>::value,
+                     bool> = true>
+void AddmmGradExtKernel(const Context& dev_ctx,
+                        const DenseTensor& input,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
+                        const DenseTensor& out_grad,
+                        float alpha,
+                        float beta,
+                        DenseTensor* input_grad,
+                        DenseTensor* x_grad,
+                        DenseTensor* y_grad) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  auto in_dims = input.dims();
+  if (input.dims().size() == 1) {
+    in_dims = {1, input.dims()[0]};
+    input_grad->Resize(in_dims);
+  }
+  int total_elems = 0;
+
+  VLOG(3) << "alpha: " << alpha << " beta: " << beta;
+
+  if (input_grad != nullptr) {
+    input_grad->set_lod(out_grad.lod());
+  }
+  if (x_grad != nullptr) {
+    x_grad->set_lod(x.lod());
+  }
+  if (y_grad != nullptr) {
+    y_grad->set_lod(y.lod());
+  }
+
+  auto blas = funcs::GetBlas<Context, T>(dev_ctx);
+  // auto mt_blas = funcs::GetBlas<Context, MPType>(dev_ctx);
+  if (input_grad) {
+    dev_ctx.template Alloc<T>(input_grad);
+    total_elems = in_dims[0] * in_dims[1];
+    auto& place = *dev_ctx.eigen_device();
+    auto eigen_dout = PhiEigenTensor<T, 2>::From(out_grad);
+    auto eigen_dinput = PhiEigenTensor<T, 2>::From(*input_grad);
+
+    bool row_compress = in_dims[0] != out_grad.dims()[0];
+    bool col_compress = in_dims[1] != out_grad.dims()[1];
+    auto eigen_dinput_shape =
+        Array2(input_grad->dims()[0], input_grad->dims()[1]);
+
+    if (row_compress && col_compress) {
+      {
+        eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
+                                         .sum()
+                                         .eval()
+                                         .reshape(eigen_dinput_shape)
+                                         .template cast<T>();
+      }
+    } else if (row_compress) {
+      {
+        eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
+                                         .sum(Array1(0))
+                                         .eval()
+                                         .reshape(eigen_dinput_shape)
+                                         .template cast<T>();
+      }
+    } else if (col_compress) {
+      {
+        eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
+                                         .sum(Array1(1))
+                                         .eval()
+                                         .reshape(eigen_dinput_shape)
+                                         .template cast<T>();
+      }
+    } else {
+      {
+        phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+        CopyOrScaleFunctor<T> functor(
+            1, out_grad.data<T>(), input_grad->data<T>(), total_elems);
+        for_range(functor);
+      }
+    }
+
+    // The SCAL does not support the float16, bfloat16
+    {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          beta, input_grad->data<T>(), input_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
+
+    if (input.dims().size() == 1) {
+      input_grad->Resize(input.dims());
+    }
+  }
+  if (x_grad) {
+    dev_ctx.template Alloc<T>(x_grad);
+    total_elems = x.dims()[0] * x.dims()[1];
+    // x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
+    blas.MatMul(out_grad, false, y, true, x_grad);
+    {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
+  }
+  if (y_grad) {
+    dev_ctx.template Alloc<T>(y_grad);
+    total_elems = x.dims()[1] * y.dims()[1];
+    // y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
+    blas.MatMul(x, true, out_grad, false, y_grad);
+    {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
+  }
+}
+
+template <
+    typename T,
+    typename Context,
+    std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
+                         !std::is_same<T, phi::dtype::complex<double>>::value,
+                     bool> = true>
+void AddmmGradExtKernel(const Context& dev_ctx,
+                        const DenseTensor& input,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
+                        const DenseTensor& out_grad,
+                        float alpha,
+                        float beta,
+                        DenseTensor* input_grad,
+                        DenseTensor* x_grad,
+                        DenseTensor* y_grad) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
bool is_float16_or_bfloat16 = false;
if (std::is_same<T, phi::dtype::float16>::value ||
@@ -196,4 +323,18 @@ void AddmmGradKernel(const Context& dev_ctx,
}
}

+template <typename T, typename Context>
+void AddmmGradKernel(const Context& dev_ctx,
+                     const DenseTensor& input,
+                     const DenseTensor& x,
+                     const DenseTensor& y,
+                     const DenseTensor& out_grad,
+                     float alpha,
+                     float beta,
+                     DenseTensor* input_grad,
+                     DenseTensor* x_grad,
+                     DenseTensor* y_grad) {
+  AddmmGradExtKernel<T, Context>(
+      dev_ctx, input, x, y, out_grad, alpha, beta, input_grad, x_grad, y_grad);
+}
} // namespace phi
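
For reference, the complex specialization above computes the same three gradients as the pre-existing real-valued path: the input gradient is out_grad scaled by beta and sum-reduced back to input's shape wherever input was broadcast, while x_grad and y_grad come from the two MatMul calls with plain (non-conjugated) transposes, scaled by alpha. A rough NumPy mirror of that logic (a sketch only; addmm_grad_reference is a hypothetical helper, not part of the PR):

    import numpy as np

    def addmm_grad_reference(input_shape, x, y, out_grad, alpha, beta):
        # Normalize a 1-D input shape to (1, N), as the kernel does via Resize.
        in_dims = (1, input_shape[0]) if len(input_shape) == 1 else tuple(input_shape)

        # d_input: beta * out_grad, reduced over broadcast rows and/or columns.
        d_input = beta * out_grad
        if in_dims[0] != out_grad.shape[0]:
            d_input = d_input.sum(axis=0, keepdims=True)
        if in_dims[1] != out_grad.shape[1]:
            d_input = d_input.sum(axis=1, keepdims=True)
        d_input = d_input.reshape(input_shape)

        # x_grad = alpha * out_grad @ y^T  (M x K), plain transpose as in MatMul(y, true)
        d_x = alpha * (out_grad @ y.T)
        # y_grad = alpha * x^T @ out_grad  (K x N), plain transpose as in MatMul(x, true)
        d_y = alpha * (x.T @ out_grad)
        return d_input, d_x, d_y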
9 changes: 3 additions & 6 deletions paddle/phi/kernels/impl/addmm_kernel_impl.h
@@ -114,19 +114,16 @@ void AddmmKernel(const Context& dev_ctx,

T t_alpha = static_cast<T>(alpha);
T t_beta = static_cast<T>(beta);
-  blas.GEMM(false,
-            false,
+  blas.GEMM(CblasNoTrans,
+            CblasNoTrans,
             x_dims[0],
             y_dims[1],
             x_dims[1],
             t_alpha,
             x.data<T>(),
-            x_dims[1],
             y.data<T>(),
-            y_dims[1],
             t_beta,
-            out->data<T>(),
-            y_dims[1]);
+            out->data<T>());
}

} // namespace phi
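
The switch from the bool overload to the CblasNoTrans/CblasNoTrans overload drops the explicit lda/ldb/ldc arguments; either way the call is a row-major GEMM with M = x_dims[0], N = y_dims[1], K = x_dims[1]. Assuming out has already been filled with the broadcast input when this GEMM runs (that code sits outside this hunk), the computation for every registered dtype, including the new complex ones, reduces to the following NumPy sketch:

    import numpy as np

    def addmm_gemm_reference(out_in, x, y, alpha, beta):
        # GEMM: C = alpha * A @ B + beta * C, with A = x (M x K), B = y (K x N),
        # and C = out pre-filled with the broadcast input (M x N).
        return alpha * (x @ y) + beta * out_in

    # Works identically for complex operands (illustrative shapes).
    M, K, N = 2, 3, 4
    rng = np.random.default_rng(1)
    x = (rng.random((M, K)) + 1j * rng.random((M, K))).astype(np.complex128)
    y = (rng.random((K, N)) + 1j * rng.random((K, N))).astype(np.complex128)
    out_in = (rng.random((M, N)) + 1j * rng.random((M, N))).astype(np.complex128)
    print(addmm_gemm_reference(out_in, x, y, alpha=2.0, beta=0.5).shape)  # (2, 4)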
24 changes: 24 additions & 0 deletions test/legacy_test/test_addmm_op.py
@@ -38,6 +38,20 @@ def setUp(self):
+ np.dot(self.inputs['X'], self.inputs['Y'])
}

+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            self.inputs['Input'] += 1j * np.random.random((100, 1)).astype(
+                self.dtype
+            )
+            self.inputs['X'] += 1j * np.random.random((100, 10)).astype(
+                self.dtype
+            )
+            self.inputs['Y'] += 1j * np.random.random((10, 20)).astype(
+                self.dtype
+            )
+            self.outputs['Out'] = self.inputs['Input'] + np.dot(
+                self.inputs['X'], self.inputs['Y']
+            )

def init_dtype_type(self):
self.dtype = np.float64

@@ -333,6 +347,16 @@ def test_api_with_dygraph(self):
)


+class TestAddMMOp6(TestAddMMOp):
+    def init_dtype_type(self):
+        self.dtype = np.complex64
+
+
+class TestAddMMOp7(TestAddMMOp):
+    def init_dtype_type(self):
+        self.dtype = np.complex128


class TestAddMMAPI(unittest.TestCase):
def test_api_error(self):
data_x = np.ones((2, 2)).astype(np.float32)
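
The new OpTest cases exercise the kernels through the standard check machinery. For a quick manual check of the same kernels in dygraph mode, something like the sketch below could be used (an assumption-laden example: it presumes a build with this PR and uses a paddle.real(...).sum() reduction purely to obtain a real scalar loss for backward()):

    import numpy as np
    import paddle

    def make_complex(shape, seed):
        rng = np.random.default_rng(seed)
        return paddle.to_tensor(
            (rng.random(shape) + 1j * rng.random(shape)).astype(np.complex64)
        )

    # Same shapes as the OpTest inputs: Input (100, 1), X (100, 10), Y (10, 20).
    inp, x, y = make_complex((100, 1), 0), make_complex((100, 10), 1), make_complex((10, 20), 2)
    for t in (inp, x, y):
        t.stop_gradient = False

    out = paddle.addmm(inp, x, y, beta=1.0, alpha=1.0)
    loss = paddle.real(out).sum()  # real scalar so backward() is well defined
    loss.backward()
    print(inp.grad.shape, x.grad.shape, y.grad.shape)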