Skip to content

Commit

Permalink
[Zero-Dim] fix batch_norm op infermeta bug (#47858)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 authored Nov 11, 2022
1 parent 17dffd1 commit 1854941
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 54 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
if (ctx->HasInput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", {-1});
}
}

framework::OpKernelType BatchNormOp::GetExpectedKernelType(
Expand Down
13 changes: 7 additions & 6 deletions paddle/fluid/operators/common_infer_shape_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
Expand Down
13 changes: 7 additions & 6 deletions paddle/fluid/operators/elementwise/elementwise_npu.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));

for (int i = 0; i < x_dims.size(); ++i) {
dst_dims_vec[i + x_axis] =
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x,
if (saved_variance) {
saved_variance->set_dims({C});
}
if (reserve_space) {
reserve_space->set_dims({-1});
}
y->share_lod(x);
y->set_dtype(x.dtype());
}
Expand Down
13 changes: 7 additions & 6 deletions paddle/phi/kernels/funcs/common_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
Expand Down
26 changes: 14 additions & 12 deletions paddle/phi/kernels/funcs/elementwise_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,13 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
Expand Down Expand Up @@ -394,12 +395,13 @@ void ElementwiseCompute(const CPUContext &dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));

int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
Expand Down
26 changes: 14 additions & 12 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,13 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));

int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
Expand Down Expand Up @@ -1725,12 +1726,13 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));

int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
Expand Down
26 changes: 14 additions & 12 deletions paddle/phi/kernels/xpu/elementwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ void XPUElementwise(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
Expand Down Expand Up @@ -121,12 +122,13 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
PADDLE_ENFORCE_LE(
axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
Expand Down

0 comments on commit 1854941

Please sign in to comment.