Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] Support 0D for kron #49847

Merged
merged 9 commits into from
Feb 18, 2023
130 changes: 66 additions & 64 deletions paddle/phi/kernels/impl/kron_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ namespace phi {

template <typename T>
struct KronGradElemFunctor {
KronGradElemFunctor(const T* dout,
const T* A,
const T* B,
T* dout_a,
T* dout_b,
const int64_t* stride_dout,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* shape_b,
KronGradElemFunctor(const T *dout,
const T *A,
const T *B,
T *dout_a,
T *dout_b,
const int64_t *stride_dout,
const int64_t *stride_a,
const int64_t *stride_b,
const int64_t *shape_b,
const int64_t numel_a,
const int64_t numel_b,
const int ndims)
Expand Down Expand Up @@ -69,31 +69,31 @@ struct KronGradElemFunctor {
}

private:
const T* dout_;
const T* A_;
const T* B_;
T* dout_a_;
T* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* shape_b_;
const T *dout_;
const T *A_;
const T *B_;
T *dout_a_;
T *dout_b_;
const int64_t *stride_dout_;
const int64_t *stride_a_;
const int64_t *stride_b_;
const int64_t *shape_b_;
const int64_t numel_a_;
const int64_t numel_b_;
const int ndims_;
};

template <typename T>
struct KronGradElemFunctor<dtype::complex<T>> {
KronGradElemFunctor(const dtype::complex<T>* dout,
const dtype::complex<T>* A,
const dtype::complex<T>* B,
dtype::complex<T>* dout_a,
dtype::complex<T>* dout_b,
const int64_t* stride_dout,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* shape_b,
KronGradElemFunctor(const dtype::complex<T> *dout,
const dtype::complex<T> *A,
const dtype::complex<T> *B,
dtype::complex<T> *dout_a,
dtype::complex<T> *dout_b,
const int64_t *stride_dout,
const int64_t *stride_a,
const int64_t *stride_b,
const int64_t *shape_b,
const int64_t numel_a,
const int64_t numel_b,
const int ndims)
Expand Down Expand Up @@ -136,45 +136,47 @@ struct KronGradElemFunctor<dtype::complex<T>> {
}

private:
const dtype::complex<T>* dout_;
const dtype::complex<T>* A_;
const dtype::complex<T>* B_;
dtype::complex<T>* dout_a_;
dtype::complex<T>* dout_b_;
const int64_t* stride_dout_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* shape_b_;
const dtype::complex<T> *dout_;
const dtype::complex<T> *A_;
const dtype::complex<T> *B_;
dtype::complex<T> *dout_a_;
dtype::complex<T> *dout_b_;
const int64_t *stride_dout_;
const int64_t *stride_a_;
const int64_t *stride_b_;
const int64_t *shape_b_;
const int64_t numel_a_;
const int64_t numel_b_;
const int ndims_;
};

template <typename Context, typename T>
struct KronGradOpFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& dout,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* dx,
DenseTensor* dy) {
void operator()(const Context &dev_ctx,
const DenseTensor &dout,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *dx,
DenseTensor *dy) {
int ndims = dout.dims().size();
int64_t numel = dout.numel();
int64_t numel_x = x.numel();
int64_t numel_y = y.numel();

const phi::DDim& dim_x = x.dims();
const phi::DDim& dim_y = y.dims();
const phi::DDim& dim_dout = dout.dims();
const phi::DDim &dim_x = x.dims();
const phi::DDim &dim_y = y.dims();
const phi::DDim &dim_dout = dout.dims();
const phi::DDim stride_x =
dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x);
const phi::DDim stride_y =
dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y);
const phi::DDim stride_dout =
dim_dout.size() == 0 ? phi::DDim(dim_dout) : phi::stride(dim_dout);

const phi::DDim stride_x = phi::stride(dim_x);
const phi::DDim stride_y = phi::stride(dim_y);
const phi::DDim stride_dout = phi::stride(dim_dout);

const int64_t* p_stride_x = nullptr;
const int64_t* p_stride_y = nullptr;
const int64_t* p_stride_dout = nullptr;
const int64_t* p_shape_y = nullptr;
const int64_t *p_stride_x = nullptr;
const int64_t *p_stride_y = nullptr;
const int64_t *p_stride_dout = nullptr;
const int64_t *p_shape_y = nullptr;
#if defined(__NVCC__) || defined(__HIPCC__)
thrust::device_vector<int64_t> d_stride_x(ndims);
thrust::device_vector<int64_t> d_stride_y(ndims);
Expand All @@ -199,14 +201,14 @@ struct KronGradOpFunctor {
// dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
// dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
DenseTensor dout_x;
T* p_dout_x = nullptr;
T *p_dout_x = nullptr;
if (dx) {
dout_x.Resize({numel_x, numel_y});
dev_ctx.template Alloc<T>(&dout_x);
p_dout_x = dout_x.data<T>();
}
DenseTensor dout_y;
T* p_dout_y = nullptr;
T *p_dout_y = nullptr;
if (dy) {
dout_y.Resize({numel_y, numel_x});
dev_ctx.template Alloc<T>(&dout_y);
Expand Down Expand Up @@ -240,7 +242,7 @@ struct KronGradOpFunctor {
dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1});
}
#else
auto* place = dev_ctx.eigen_device();
auto *place = dev_ctx.eigen_device();
Eigen::array<int, 1> reduce_dim = {1};
if (dx) {
auto eigen_dout_x = EigenMatrix<T>::Reshape(dout_x, 1);
Expand All @@ -257,12 +259,12 @@ struct KronGradOpFunctor {
};

template <typename T, typename Context>
void KronGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
void KronGradKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &out_grad,
DenseTensor *x_grad,
DenseTensor *y_grad) {
if (x_grad) {
ctx.template Alloc<T>(x_grad);
}
Expand All @@ -274,8 +276,8 @@ void KronGradKernel(const Context& ctx,
DenseTensor xx = UnsqueezeTo(x, ndims);
DenseTensor yy = UnsqueezeTo(y, ndims);

DenseTensor* pdxx = nullptr;
DenseTensor* pdyy = nullptr;
DenseTensor *pdxx = nullptr;
DenseTensor *pdyy = nullptr;
DenseTensor dxx;
DenseTensor dyy;
if (x_grad) {
Expand Down
63 changes: 33 additions & 30 deletions paddle/phi/kernels/impl/kron_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

namespace phi {

inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) {
const phi::DDim& shape = src.dims();
inline DenseTensor UnsqueezeTo(const DenseTensor &src, int ndims) {
const phi::DDim &shape = src.dims();
int rank = shape.size();
DenseTensor res;
res.ShareDataWith(src);
Expand All @@ -52,13 +52,13 @@ inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) {

template <typename T>
struct KronElemFunctor {
KronElemFunctor(const T* a,
const T* b,
T* out,
const int64_t* shape_b,
const int64_t* stride_a,
const int64_t* stride_b,
const int64_t* stride_out,
KronElemFunctor(const T *a,
const T *b,
T *out,
const int64_t *shape_b,
const int64_t *stride_a,
const int64_t *stride_b,
const int64_t *stride_out,
int ndims)
: a_(a),
b_(b),
Expand Down Expand Up @@ -86,31 +86,34 @@ struct KronElemFunctor {
}

private:
const T* a_;
const T* b_;
T* out_;
const int64_t* shape_b_;
const int64_t* stride_a_;
const int64_t* stride_b_;
const int64_t* stride_out_;
const T *a_;
const T *b_;
T *out_;
const int64_t *shape_b_;
const int64_t *stride_a_;
const int64_t *stride_b_;
const int64_t *stride_out_;
const int ndims_;
};

template <typename Context, typename T>
struct KronOpFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
void operator()(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *out) {
int ndims = out->dims().size();
int64_t numel = out->numel();

const phi::DDim& dim_x = x.dims();
const phi::DDim& dim_y = y.dims();
const phi::DDim& dim_out = out->dims();
const phi::DDim stride_x = phi::stride(dim_x);
const phi::DDim stride_y = phi::stride(dim_y);
const phi::DDim stride_out = phi::stride(dim_out);
const phi::DDim &dim_x = x.dims();
const phi::DDim &dim_y = y.dims();
const phi::DDim &dim_out = out->dims();
const phi::DDim stride_x =
dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x);
const phi::DDim stride_y =
dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y);
const phi::DDim stride_out =
dim_out.size() == 0 ? phi::DDim(dim_out) : phi::stride(dim_out);

const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
*p_stride_out = nullptr, *p_shape_y = nullptr;
Expand Down Expand Up @@ -150,10 +153,10 @@ struct KronOpFunctor {
};

template <typename T, typename Context>
void KronKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
void KronKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *out) {
ctx.template Alloc<T>(out);

int ndims = out->dims().size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def test_static_reduce(self):
paddle.fmax,
paddle.fmin,
paddle.complex,
paddle.kron,
]

binary_int_api_list = [
Expand Down