Skip to content

Commit

Permalink
modify the bf16 process and fix the max op
Browse files Browse the repository at this point in the history
  • Loading branch information
Vvsmile committed Jun 21, 2023
1 parent 7bf3ed9 commit 1b5431b
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 87 deletions.
75 changes: 40 additions & 35 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
Expand Down Expand Up @@ -114,41 +115,43 @@ static void ElemwiseGradBroadcast1CPU(const T *x,
DY_OP dy_op,
T *dx,
T *dy) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

if (is_xsize_larger) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
for (int j = 0; j < w; ++j) {
MPType sum_y = static_cast<MPType>(0);
for (int i = 0; i < h; ++i) {
int x_offset = i * w + j;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
sum_y += static_cast<MPType>(
dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]));
}
}
if (dy != nullptr) {
dy[j] = static_cast<T>(sum_y);
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
} else {
for (int j = 0; j < w; ++j) {
MPType sum_x = static_cast<MPType>(0);
for (int i = 0; i < h; ++i) {
int y_offset = i * w + j;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
sum_x += static_cast<MPType>(
dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]));
}
}
if (dx != nullptr) {
dx[j] = static_cast<T>(sum_x);
}
}
}
}
Expand All @@ -166,45 +169,47 @@ static void ElemwiseGradBroadcast2CPU(const T *x,
DY_OP dy_op,
T *dx,
T *dy) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

if (is_xsize_larger) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int j = 0; j < n; ++j) {
MPType sum_y = static_cast<MPType>(0);
for (int i = 0; i < pre; ++i) {
for (int k = 0; k < post; ++k) {
int x_offset = i * n * post + j * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0 && k == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
sum_y += static_cast<MPType>(
dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]));
}
}
}
if (dy != nullptr) {
dy[j] = static_cast<T>(sum_y);
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
} else {
for (int j = 0; j < n; ++j) {
MPType sum_x = static_cast<MPType>(0);
for (int i = 0; i < pre; ++i) {
for (int k = 0; k < post; ++k) {
int y_offset = i * n * post + j * post + k;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0 && k == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
sum_x += static_cast<MPType>(
dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]));
}
}
}
if (dx != nullptr) {
dx[j] = static_cast<T>(sum_x);
}
}
}
}
Expand Down Expand Up @@ -397,7 +402,7 @@ void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
x.data<T>(),
y.data<T>(),
out.data<Tout>(),
out.data<Tout>(), /* */
dout.data<Tout>(),
dx_op,
dy_op,
Expand Down
Loading

0 comments on commit 1b5431b

Please sign in to comment.