Commit 35de47b: [cherry-pick 2.5][Zero-Dim] paddle.nanmedian/count_nonzero/logspace support 0D, add some 0D case (#54649)

* [Zero-Dim] add 0D test case (#54581)

* [Zero-Dim] paddle.nanmedian/nanquantile support 0D Tensor (#54500)

* [Zero-Dim] paddle.nanmedian support 0D Tensor

* fix CI
zhwesky2010 authored Jun 14, 2023
1 parent cf64aa0 commit 35de47b
Showing 14 changed files with 693 additions and 422 deletions.
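
For context, a minimal usage sketch of the 0D behavior this cherry-pick enables (assuming Paddle >= 2.5 with this patch; values are illustrative):

    import paddle

    # nanmedian: full reduction without keepdim now returns a 0D tensor,
    # and a 0D input is accepted.
    x = paddle.to_tensor([[1.0, float('nan')], [3.0, 5.0]])
    print(paddle.nanmedian(x).shape)                      # [] (0D)
    print(paddle.nanmedian(paddle.to_tensor(2.0)).shape)  # [] (0D in, 0D out)

    # count_nonzero: full reduction likewise returns a 0D tensor.
    print(paddle.count_nonzero(x).shape)                  # []

    # logspace: start/stop/num/base may now be size-1 (including 0D) tensors.
    out = paddle.logspace(paddle.to_tensor(0.0),
                          paddle.to_tensor(2.0),
                          paddle.to_tensor(3, dtype='int32'))
    print(out.numpy())  # [  1.  10. 100.]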
38 changes: 19 additions & 19 deletions paddle/phi/infermeta/multiary.cc
@@ -2162,32 +2162,32 @@ void LogspaceInferMeta(const MetaTensor& start,
                        MetaTensor* out) {
   auto s_dims = start.dims();
   PADDLE_ENFORCE_EQ(
-      (s_dims.size() == 1) && (s_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
-                                   "but received input shape is [%s].",
-                                   s_dims));
+      phi::product(s_dims),
+      1,
+      phi::errors::InvalidArgument("The size of Input(Start) must be 1,"
+                                   "but received input size is %s.",
+                                   phi::product(s_dims)));
   auto e_dims = stop.dims();
   PADDLE_ENFORCE_EQ(
-      (e_dims.size() == 1) && (e_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
-                                   "but received input shape is [%s].",
-                                   e_dims));
+      phi::product(e_dims),
+      1,
+      phi::errors::InvalidArgument("The size of Input(Stop) must be 1,"
+                                   "but received input size is %s.",
+                                   phi::product(e_dims)));
   auto num_dims = number.dims();
   PADDLE_ENFORCE_EQ(
-      (num_dims.size() == 1) && (num_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
-                                   "but received input shape is [%s].",
-                                   num_dims));
+      phi::product(num_dims),
+      1,
+      phi::errors::InvalidArgument("The size of Input(Num) must be 1,"
+                                   "but received input size is %s.",
+                                   phi::product(num_dims)));
   auto b_dims = base.dims();
-  PADDLE_ENFORCE_EQ(
-      (b_dims.size() == 1) && (b_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Base) must be [1],"
-                                   "but received input shape is [%s].",
-                                   b_dims));
+  PADDLE_ENFORCE_EQ(phi::product(b_dims),
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The size of Input(Base) must be 1,"
+                        "but received input size is %s.",
+                        phi::product(b_dims)));
   out->set_dims(phi::make_ddim({-1}));
   out->set_dtype(dtype);
 }
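The rewritten checks work because a 0D tensor's element count is the empty product, which is 1, so a size-based check accepts shape [] as well as shape [1]. A quick sketch of that semantics (NumPy stands in for phi::product here; illustrative only):

    import numpy as np

    def passes_size_check(shape):
        # phi::product(dims) == 1 admits any shape with exactly one element
        return int(np.prod(shape, dtype=np.int64)) == 1

    print(passes_size_check([1]))  # True  (1-D, the only shape accepted before)
    print(passes_size_check([]))   # True  (0-D, newly accepted)
    print(passes_size_check([2]))  # False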
50 changes: 30 additions & 20 deletions paddle/phi/infermeta/unary.cc
@@ -2260,37 +2260,47 @@ void NanmedianInferMeta(const MetaTensor& x,
       for (int64_t i = 0; i < x_rank; i++) {
         out_dim.push_back(1);
       }
-    } else {
-      out_dim.push_back(1);
     }
   } else {
-    std::vector<int64_t> cleaned_axis;
+    std::vector<int64_t> formated_axis;
     for (auto& axis : axis_list) {
+      if (x_rank == 0) {
+        PADDLE_ENFORCE_EQ(axis == 0 || axis == -1,
+                          true,
+                          phi::errors::InvalidArgument(
+                              "When input is a 0D Tensor, each element of "
+                              "the axis can only be -1, 0, None"));
+      } else {
+        PADDLE_ENFORCE_LT(axis,
+                          x_rank,
+                          errors::InvalidArgument(
+                              "each element of the axis should be in the "
+                              "range [ -dimension(X), dimension(X) ) "
+                              "which dimension = %d. But received axis = %d.",
+                              x_rank,
+                              axis));
+        PADDLE_ENFORCE_GE(axis,
+                          -x_rank,
+                          errors::InvalidArgument(
+                              "each element of the axis should be in the "
+                              "range [ -dimension(X), dimension(X) ) "
+                              "which dimension = %d. But received axis = %d.",
+                              x_rank,
+                              axis));
+      }
       if (axis < 0) axis += x_rank;
 
-      PADDLE_ENFORCE_LT(
-          axis,
-          x_rank,
-          errors::InvalidArgument(
-              "Attr(axis) value should be in range [-R, R-1], R is "
-              "the rank of Input(X). But received axis: %d, R: %d. "
-              "Current Input(X)'s shape is=[%s].",
-              axis,
-              x_rank,
-              x_dim));
-
       PADDLE_ENFORCE_EQ(
-          std::find(cleaned_axis.begin(), cleaned_axis.end(), axis),
-          cleaned_axis.end(),
+          std::find(formated_axis.begin(), formated_axis.end(), axis),
+          formated_axis.end(),
           errors::InvalidArgument("Attr(axes) has duplicated elements: %d.",
                                   static_cast<int>(axis)));
 
-      cleaned_axis.push_back(axis);
+      formated_axis.push_back(axis);
     }
 
     for (int64_t i = 0; i < x_rank; i++) {
-      if (std::find(cleaned_axis.begin(), cleaned_axis.end(), i) ==
-          cleaned_axis.end()) {
+      if (std::find(formated_axis.begin(), formated_axis.end(), i) ==
+          formated_axis.end()) {
         out_dim.push_back(x_dim[i]);
       } else if (keep_dim) {
         out_dim.push_back(1);
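A minimal Python sketch of the output-shape rule this hunk implements (hypothetical helper mirroring the C++ above; not part of Paddle):

    def nanmedian_out_shape(x_shape, axes, keep_dim):
        rank = len(x_shape)
        if not axes:
            # full reduction: rank-many 1s with keep_dim, otherwise 0D
            return [1] * rank if keep_dim else []
        axes = [a + rank if a < 0 else a for a in axes]
        out = []
        for i in range(rank):
            if i not in axes:
                out.append(x_shape[i])
            elif keep_dim:
                out.append(1)
        return out

    print(nanmedian_out_shape([2, 3], [], False))  # []  (0D; was [1] before)
    print(nanmedian_out_shape([2, 3], [1], True))  # [2, 1]
    print(nanmedian_out_shape([], [-1], False))    # []  (0D in, 0D out)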
73 changes: 35 additions & 38 deletions paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
@@ -17,7 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"

namespace phi {

@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& median_index,
                           const DenseTensor& out_grad,
-                          const IntArray& axes UNUSED,
-                          DenseTensor* x_grad,
-                          T* x_grad_ptr) {
+                          DenseTensor* x_grad) {
+  T* dx_data = dev_ctx.template Alloc<T>(x_grad);
+  if (!dx_data) return;
+
   phi::funcs::SetConstant<Context, T> set_zero;
   set_zero(dev_ctx, x_grad, static_cast<T>(0));
-  if (!x_grad_ptr) return;
 
-  const int64_t* m_ptr = median_index.data<int64_t>();
-  const T* out_grad_ptr = out_grad.data<T>();
+  const int64_t* m_data = median_index.data<int64_t>();
+  const T* dout_data = out_grad.data<T>();
   int64_t numel = x.numel();
   auto x_dim = x.dims();
   int64_t rank = x_dim.size();
   int64_t stride = x_dim[rank - 1];
-
   int64_t pre_dim = numel / stride;
 
   int64_t i = 0;
   int64_t offset = 0;
-  T div_factor = static_cast<T>(2.0);
   for (i = 0; i < pre_dim; i++) {
-    if (m_ptr[2 * i] >= 0) {
-      if (m_ptr[2 * i] == m_ptr[2 * i + 1]) {
-        x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i];
+    if (m_data[2 * i] >= 0) {
+      if (m_data[2 * i] == m_data[2 * i + 1]) {
+        dx_data[offset + m_data[2 * i]] = dout_data[i];
       } else {
-        x_grad_ptr[offset + m_ptr[2 * i]] = out_grad_ptr[i] / div_factor;
-        x_grad_ptr[offset + m_ptr[2 * i + 1]] = out_grad_ptr[i] / div_factor;
+        dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast<T>(2.0);
+        dx_data[offset + m_data[2 * i + 1]] =
+            dout_data[i] / static_cast<T>(2.0);
       }
     }
     offset += stride;
   }
 }
 
-template <typename T, typename Context>
-void BaseMedianGradKernel(const Context& dev_ctx,
-                          const DenseTensor& x,
-                          const DenseTensor& median_index,
-                          const DenseTensor& out_grad,
-                          const IntArray& axes,
-                          DenseTensor* x_grad) {
-  auto rank = x.dims().size();
-  T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
-  if (axes.size() && (rank > 1)) {
-    DenseTensor tmp_x_grad(*x_grad);
-    CalcMedianGradKernel<T, Context>(
-        dev_ctx, x, median_index, out_grad, axes, &tmp_x_grad, x_grad_ptr);
-    PostprocessMedianGradKernel<T, Context>(dev_ctx, &tmp_x_grad, axes, x_grad);
-  } else {
-    CalcMedianGradKernel<T, Context>(
-        dev_ctx, x, median_index, out_grad, axes, x_grad, x_grad_ptr);
-  }
-}
-
 template <typename T, typename Context>
 void NanmedianGradKernel(const Context& dev_ctx,
-                         const DenseTensor& input,
+                         const DenseTensor& x,
                          const DenseTensor& median_index,
                          const DenseTensor& out_grad,
                          const IntArray& axes,
-                         bool keep_dim UNUSED,
+                         bool keepdim UNUSED,
                          DenseTensor* x_grad) {
-  BaseMedianGradKernel<T, Context>(
-      dev_ctx, input, median_index, out_grad, axes, x_grad);
+  DenseTensor tmp_x;
+  auto rank = x.dims().size();
+  if ((axes.size() == 0) || rank <= 1) {
+    tmp_x = x;
+    tmp_x.Resize({x.numel()});
+    CalcMedianGradKernel<T, Context>(
+        dev_ctx, tmp_x, median_index, out_grad, x_grad);
+  } else {
+    funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
+
+    DenseTensor tmp_x_grad;
+    tmp_x_grad.Resize(x_grad->dims());
+    CalcMedianGradKernel<T, Context>(
+        dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad);
+
+    dev_ctx.template Alloc<T>(x_grad);
+    funcs::PostprocessMedianGradKernel<T, Context>(
+        dev_ctx, &tmp_x_grad, axes, x_grad);
+  }
 }
 
 }  // namespace phi
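The gradient rule above is easier to see in scalar form: the upstream gradient flows entirely to the median element, or is split evenly between the two middle elements when the valid count is even. A NumPy sketch for a single row (simplified; median_idx stands for the two positions the forward pass records in median_index):

    import numpy as np

    def nanmedian_grad_row(x, median_idx, dout):
        # Mirrors CalcMedianGradKernel for one row: all-NaN rows record -1
        # and receive no gradient.
        dx = np.zeros_like(x)
        lo, hi = median_idx
        if lo >= 0:
            if lo == hi:
                dx[lo] = dout          # odd count: a single median element
            else:
                dx[lo] = dout / 2.0    # even count: split between the two
                dx[hi] = dout / 2.0    # middle elements
        return dx

    x = np.array([3.0, np.nan, 1.0, 2.0])
    print(nanmedian_grad_row(x, (3, 3), 1.0))    # [0. 0. 0. 1.]
    print(nanmedian_grad_row(x, (-1, -1), 1.0))  # all zeros (all-NaN row)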
69 changes: 28 additions & 41 deletions paddle/phi/kernels/cpu/nanmedian_kernel.cc
@@ -16,7 +16,7 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nanmedian_kernel_impl.h"
#include "paddle/phi/kernels/funcs/nanmedian_utils.h"
#include "paddle/phi/kernels/top_k_kernel.h"

namespace phi {
@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx,
                     int64_t pre_dim,
                     T* o_ptr,
                     int64_t* m_ptr) {
-  bool should_ignore_nan = ignore_nan;
   DenseTensor sort_out;
   DenseTensor sort_indices;
   auto sort_dim = x.dims();
@@ -52,7 +51,7 @@
   int64_t offset = 0;
   int64_t i = 0;
   bool is_ori_odd = stride & 1;
-  if (should_ignore_nan) {
+  if (ignore_nan) {
     for (i = 0; i < pre_dim; i++) {
       offset = i * sort_k;
       if (nan_counts[i] == stride) {
@@ -107,11 +106,11 @@
 template <typename T, typename Context>
 void ProcessMedianKernel(const Context& dev_ctx,
                          const DenseTensor& x,
-                         T* o_ptr,
-                         int64_t* m_ptr,
-                         bool ignore_nan) {
-  bool should_ignore_nan = ignore_nan;
-  const T* x_ptr = x.data<T>();
+                         DenseTensor* out,
+                         DenseTensor* median_index) {
+  const T* x_data = x.data<T>();
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  int64_t* m_data = dev_ctx.template Alloc<int64_t>(median_index);
 
   int64_t numel = x.numel();
   auto x_dim = x.dims();
@@ -122,7 +121,8 @@

   int64_t max_valid_num = 0;
   std::vector<int64_t> nan_counts;
-  if (should_ignore_nan) {
+  bool ignore_nan = true;
+  if (ignore_nan) {
     int64_t total_nan_num = 0;
     std::vector<T> col_vec;
     col_vec.reserve(stride);
@@ -133,7 +133,7 @@
     for (int64_t i = 0; i < pre_dim; i++) {
       col_vec.clear();
       col_vec.insert(
-          col_vec.begin(), x_ptr + i * stride, x_ptr + (i + 1) * stride);
+          col_vec.begin(), x_data + i * stride, x_data + (i + 1) * stride);
       nan_counts[i] =
           std::count_if(col_vec.begin(), col_vec.end(), [&](const T& val) {
             return std::isnan(static_cast<float>(val));
@@ -145,47 +145,25 @@
     // all elems are nan
     if (total_nan_num == numel) {
       for (i = 0; i < pre_dim; i++) {
-        o_ptr[i] = x_ptr[0];
-        m_ptr[2 * i] = -1;
-        m_ptr[2 * i + 1] = -1;
+        out_data[i] = std::numeric_limits<T>::quiet_NaN();
+        m_data[2 * i] = -1;
+        m_data[2 * i + 1] = -1;
       }
       return;
     }
-    should_ignore_nan = total_nan_num > 0;
+    ignore_nan = total_nan_num > 0;
   }
 
-  int64_t sort_k = should_ignore_nan ? max_valid_num : ((stride >> 1) + 1);
+  int64_t sort_k = ignore_nan ? max_valid_num : ((stride >> 1) + 1);
   CalcMedianFunc<T, Context>(dev_ctx,
                              x,
                              nan_counts,
-                             should_ignore_nan,
+                             ignore_nan,
                              sort_k,
                              stride,
                              pre_dim,
-                             o_ptr,
-                             m_ptr);
-}
-
-template <typename T, typename Context>
-void BaseMedianKernel(const Context& dev_ctx,
-                      const DenseTensor& input,
-                      const IntArray& axes,
-                      DenseTensor* out,
-                      DenseTensor* median_index,
-                      bool ignore_nan) {
-  DenseTensor x;
-  auto rank = input.dims().size();
-  if ((axes.size() == 0) || rank <= 1) {
-    x = input;
-    x.Resize({input.numel()});
-  } else {
-    PreprocessMedianKernel<T, Context>(dev_ctx, input, axes, &x);
-  }
-
-  T* o_ptr = dev_ctx.template Alloc<T>(out);
-  int64_t* m_ptr = dev_ctx.template Alloc<int64_t>(median_index);
-  ProcessMedianKernel<T, Context>(dev_ctx, x, o_ptr, m_ptr, ignore_nan);
-  out->Resize(out->dims());
+                             out_data,
+                             m_data);
 }
 
 template <typename T, typename Context>
@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx,
                      bool keepdim UNUSED,
                      DenseTensor* out,
                      DenseTensor* median_index) {
-  BaseMedianKernel<T, Context>(dev_ctx, x, axes, out, median_index, true);
+  DenseTensor tmp_x;
+  auto rank = x.dims().size();
+  if ((axes.size() == 0) || rank <= 1) {
+    tmp_x = x;
+    tmp_x.Resize({x.numel()});
+  } else {
+    funcs::PreprocessMedianKernel<T, Context>(dev_ctx, x, axes, &tmp_x);
+  }
+
+  ProcessMedianKernel<T, Context>(dev_ctx, tmp_x, out, median_index);
 }
 
 }  // namespace phi
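For reference, the per-row forward behavior that ProcessMedianKernel and CalcMedianFunc implement, sketched in NumPy (simplified; the real kernel works via a top-k sort and records the median positions for the backward pass):

    import numpy as np

    def nanmedian_row(row):
        # Drop NaNs, sort, take the middle value, or the mean of the two
        # middle values for an even count. An all-NaN row now yields NaN
        # (before this commit it returned row[0]).
        valid = np.sort(row[~np.isnan(row)])
        n = valid.size
        if n == 0:
            return np.nan
        if n % 2:
            return valid[n // 2]
        return (valid[n // 2 - 1] + valid[n // 2]) / 2.0

    print(nanmedian_row(np.array([3.0, np.nan, 1.0, 2.0])))  # 2.0
    print(nanmedian_row(np.array([np.nan, np.nan])))         # nan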
(Diffs for the remaining 10 changed files are not shown.)
