Skip to content

Commit

Permalink
[Eager] fix lerp grad kernel logic (#44705)
Browse files Browse the repository at this point in the history
  • Loading branch information
veyron95 authored Jul 28, 2022
1 parent e9b9201 commit bd813d3
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,22 @@ static void LerpGradFunction(const Context& ctx,
auto* dy = y_grad;

auto dout_dims = dout.dims();
auto dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D);
auto dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D);
DDim dx_dims;
DDim dy_dims;

auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D);
Eigen::DSizes<int, D> dx_bcast_dims;
Eigen::DSizes<int, D> dy_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
phi::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
phi::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);

if (dx) {
dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D);
phi::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
}
if (dy) {
dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D);
phi::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
}
phi::funcs::GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);

auto eigen_w = phi::EigenTensor<T, D>::From(w, w_dims);
Expand All @@ -50,11 +58,16 @@ static void LerpGradFunction(const Context& ctx,
Eigen::DSizes<int, D * 2> dx_reshape_dims;
Eigen::DSizes<int, D * 2> dy_reshape_dims;
Eigen::DSizes<int, D> reduce_dims;

for (int i = 0; i < dout_dims.size(); ++i) {
dx_reshape_dims[2 * i] = dx_bcast_dims[i];
dx_reshape_dims[2 * i + 1] = dx_dims[i];
dy_reshape_dims[2 * i] = dy_bcast_dims[i];
dy_reshape_dims[2 * i + 1] = dy_dims[i];
if (dx) {
dx_reshape_dims[2 * i] = dx_bcast_dims[i];
dx_reshape_dims[2 * i + 1] = dx_dims[i];
}
if (dy) {
dy_reshape_dims[2 * i] = dy_bcast_dims[i];
dy_reshape_dims[2 * i + 1] = dy_dims[i];
}
reduce_dims[i] = 2 * i;
}

Expand Down

0 comments on commit bd813d3

Please sign in to comment.