Skip to content

Commit

Permalink
[Zero-Dim] add 0D Tensor UT case for XPU and expand kernel support 0D (
Browse files Browse the repository at this point in the history
…#53555)

* [Zero-Dim] add 0D Tensor UT case for XPU

* fix comment

* remove some unnecessary UT
  • Loading branch information
zhwesky2010 authored May 9, 2023
1 parent a37ef76 commit e588f2d
Show file tree
Hide file tree
Showing 3 changed files with 1,317 additions and 161 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/expand_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ void ExpandGradKernel(const Context& ctx,

// Two zero
if (out_grad_dims.size() == 0 && in_grad_dims.size() == 0) {
return;
out_grad_dims = {1};
in_grad_dims = {1};
}

int r = xpu::expand_grad<XPUType>(
Expand Down
17 changes: 4 additions & 13 deletions paddle/phi/kernels/xpu/expand_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,17 @@ void ExpandKernel(const Context& ctx,
shape_size,
rank));

if (shape_size == 0) {
phi::DDim out_dims = phi::make_ddim(final_expand_shape);
out->Resize(out_dims);
ctx.template Alloc<T>(out);

int r = xpu::copy<XPUType>(ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
DDim out_dims = phi::make_ddim(final_expand_shape);
out->Resize(out_dims);
ctx.template Alloc<T>(out);
auto& x_shape = vec_in_dims;
auto out_shape = phi::vectorize<int>(out_dims);
if (shape_size == 0) {
x_shape = {1};
out_shape = {1};
}

int r = XPU_SUCCESS;

if (std::is_same<T, bool>::value) {
auto x_data = reinterpret_cast<const int8_t*>(x.data<T>());
auto out_data = reinterpret_cast<int8_t*>(out->data<T>());
Expand Down
Loading

0 comments on commit e588f2d

Please sign in to comment.