[XPU][PHI Kernels] add scatter_nd_add_grad kernel & bf16 support for slice OPs #58580
Changes from 5 commits
@@ -34,13 +34,7 @@ void NonZeroKernel(const Context& dev_ctx,
   int* true_num = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
   int true_num_cpu;
   int ret = xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel);
-  PADDLE_ENFORCE_EQ(
-      ret,
-      XPU_SUCCESS,
-      phi::errors::External(
-          "XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex",
-          ret,
-          XPUAPIErrorMsg[ret]));
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "nonzero_count");

   memory_utils::Copy(phi::CPUPlace(),
                      static_cast<void*>(&true_num_cpu),
@@ -58,17 +52,12 @@ void NonZeroKernel(const Context& dev_ctx,
   auto condition_shape = phi::vectorize<int>(dims);
   ret = xpu::where(
       dev_ctx.x_context(), cond_data, out_data, condition_shape, true_num_cpu);
-  PADDLE_ENFORCE_EQ(ret,
-                    XPU_SUCCESS,
-                    phi::errors::External(
-                        "XPU masked_select kernel return wrong value[%d %s]",
-                        ret,
-                        XPUAPIErrorMsg[ret]));
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "where");
 }

 }  // namespace phi

 PD_REGISTER_KERNEL(
-    nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
+    nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float, int64_t) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }

Reviewer: Does the new data type also need to be registered in the op_list here?
Author: In the Op list, nonzero goes by the name where_index, and that entry has already been added.
@@ -50,4 +50,5 @@ void ProdKernel(const Context& dev_ctx,

 }  // namespace phi

-PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
+PD_REGISTER_KERNEL(
+    prod, XPU, ALL_LAYOUT, phi::ProdKernel, float, int, int64_t) {}
Reviewer: Same question as above.
Author: In the Op list, prod goes by the name reduce_prod, and that entry has already been added.
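Both replies refer to the XPU op list, which maps an op's legacy name to the set of dtypes its XPU kernel supports. A minimal, self-contained sketch of that mapping (the dtype enum and entries below are illustrative stand-ins, not the exact Paddle source in paddle/phi/backends/xpu):

#include <map>
#include <set>
#include <string>

// Illustrative stand-ins for the real phi dtype enum and XPUKernelSet.
enum class DataType { BOOL, INT32, INT64, FLOAT32 };
using XPUKernelSet = std::set<DataType>;

// Hedged sketch: nonzero is listed as "where_index", prod as "reduce_prod".
const std::map<std::string, XPUKernelSet> kXPUOpList = {
    {"where_index",
     XPUKernelSet(
         {DataType::INT32, DataType::BOOL, DataType::FLOAT32, DataType::INT64})},
    {"reduce_prod",
     XPUKernelSet({DataType::FLOAT32, DataType::INT32, DataType::INT64})},
};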
@@ -0,0 +1,115 @@ (new file)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/scatter_nd_add_grad_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
template <typename T, typename Context>
void ScatterNdAddGradKernel(const Context &ctx,
                            const DenseTensor &index,
                            const DenseTensor &updates,
                            const DenseTensor &out_grad,
                            DenseTensor *x_grad,
                            DenseTensor *updates_grad) {
  using XPUT = typename XPUTypeTrait<T>::Type;
  int ret = xpu::SUCCESS;
  const T *out_grad_data = out_grad.data<T>();
  if (x_grad) {
    // x_grad is an element-wise copy of out_grad.
    auto *x_grad_data = ctx.template Alloc<T>(x_grad);
    ret = xpu::copy<XPUT>(ctx.x_context(),
                          reinterpret_cast<const XPUT *>(out_grad_data),
                          reinterpret_cast<XPUT *>(x_grad_data),
                          out_grad.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
  }

  if (updates_grad) {
    auto *updates_grad_data = ctx.template Alloc<T>(updates_grad);
    if (updates_grad->numel() == 0) {
      return;
    }
    if (index.numel() == 0) {
      // Empty index with a zero-length last dim: every leading-dim slot of
      // updates receives the whole out_grad, so broadcast it remain_numel times.
      auto index_dims = index.dims();
      auto index_dims_size = index_dims.size();
      int64_t end_size = index_dims[index_dims_size - 1];
      PADDLE_ENFORCE_EQ(
          end_size,
          0,
          errors::InvalidArgument(
              "Size of the last dim of the index tensor [%d] should be 0",
              end_size));
      auto remain_dims = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
      int64_t remain_numel = phi::product(remain_dims);
      int64_t updates_grad_numel = updates_grad->numel();
      int64_t out_grad_numel = out_grad.numel();
      PADDLE_ENFORCE_EQ(
          remain_numel * out_grad_numel,
          updates_grad_numel,
          errors::InvalidArgument("out_grad numel[%d] * remain numel[%d] "
                                  "should match updates_grad numel[%d]",
                                  out_grad_numel,
                                  remain_numel,
                                  updates_grad_numel));
      ret = xpu::broadcast<XPUT>(ctx.x_context(),
                                 reinterpret_cast<const XPUT *>(out_grad_data),
                                 reinterpret_cast<XPUT *>(updates_grad_data),
                                 {1, out_grad_numel},
                                 {remain_numel, out_grad_numel});
      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast");
      return;
    }

    // General case: updates_grad = gather_nd(out_grad, index).
    auto index_shape_vec = vectorize<int64_t>(index.dims());
    if (index_shape_vec.size() == 1) {
      index_shape_vec.insert(index_shape_vec.begin(), 1);
    }
    auto out_grad_shape_vec = vectorize<int64_t>(out_grad.dims());
    xpu::VectorParam<int64_t> out_grad_shape_param = {
        out_grad_shape_vec.data(),
        static_cast<int64_t>(out_grad_shape_vec.size()),
        nullptr};

    if (index.dtype() == DataType::INT32) {
      ret = xpu::gather_nd<XPUT, int>(
          ctx.x_context(),
          reinterpret_cast<const XPUT *>(out_grad_data),
          index.data<int>(),
          reinterpret_cast<XPUT *>(updates_grad_data),
          out_grad_shape_param,
          index_shape_vec);
    } else {
      ret = xpu::gather_nd<XPUT, int64_t>(
          ctx.x_context(),
          reinterpret_cast<const XPUT *>(out_grad_data),
          index.data<int64_t>(),
          reinterpret_cast<XPUT *>(updates_grad_data),
          out_grad_shape_param,
          index_shape_vec);
    }
    PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd");
  }
}
}  // namespace phi

PD_REGISTER_KERNEL(scatter_nd_add_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::ScatterNdAddGradKernel,
                   float,
                   phi::dtype::float16,
                   int,
                   int64_t) {}
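The backward rule this kernel implements is simple: since scatter_nd_add(x, index, updates) is x plus a scatter, x_grad is a straight copy of out_grad, and updates_grad is gather_nd(out_grad, index). A minimal CPU reference for that gather step, as a sanity-check sketch only (the function name, flat-vector layout, and the [n, k] index shape are assumptions for illustration, not the XPU API):

#include <cstdint>
#include <vector>

// Hedged CPU reference for the gather_nd step of the backward pass:
// updates_grad[i, ...] = out_grad[index[i,0], ..., index[i,k-1], ...].
// Assumes a 2-D index of shape [n, k] addressing the first k dims of out_grad.
std::vector<float> GatherNdRef(const std::vector<float>& out_grad,
                               const std::vector<int64_t>& out_grad_shape,
                               const std::vector<int64_t>& index,  // [n, k]
                               int64_t n, int64_t k) {
  // Size of one gathered slice: product of the trailing dims of out_grad.
  int64_t slice = 1;
  for (std::size_t d = k; d < out_grad_shape.size(); ++d) {
    slice *= out_grad_shape[d];
  }
  // Row-major strides over the first k dims.
  std::vector<int64_t> stride(k, slice);
  for (int64_t d = k - 2; d >= 0; --d) {
    stride[d] = stride[d + 1] * out_grad_shape[d + 1];
  }
  std::vector<float> updates_grad(n * slice);
  for (int64_t i = 0; i < n; ++i) {
    int64_t offset = 0;
    for (int64_t d = 0; d < k; ++d) {
      offset += index[i * k + d] * stride[d];
    }
    for (int64_t s = 0; s < slice; ++s) {
      updates_grad[i * slice + s] = out_grad[offset + s];
    }
  }
  return updates_grad;
}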
@@ -34,8 +34,11 @@ void ScatterNdAddKernel(const Context &ctx,
   if (updates.numel() == 0) return;

   if (index.numel() == 0) {
-    int loop_time =
-        static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
+    int64_t index_dims_size = index.dims().size();
+    int loop_time = static_cast<int>(
+        index_dims_size == 0 ? 1
+                             : phi::product(phi::slice_ddim(
+                                   index.dims(), 0, index_dims_size - 1)));
Reviewer note: when the last dimension of the index tensor has length 0, updates must be indexed over all leading dimensions and accumulated into the output, not just over the first dimension; see the sketch after this hunk.
     for (int i = 0; i < loop_time; i++) {
       r = xpu::broadcast_add<T>(ctx.x_context(),
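A quick illustration of why the fix matters, assuming an index of shape [2, 3, 0] (the shape is made up for this example):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Index tensor of shape [2, 3, 0]: the last dim is empty, so every slot
  // over the leading dims [2, 3] must receive one broadcast_add.
  std::vector<int64_t> index_dims = {2, 3, 0};
  int64_t loop_time = 1;
  for (std::size_t d = 0; d + 1 < index_dims.size(); ++d) {
    loop_time *= index_dims[d];  // product of all leading dims
  }
  // The old code used only index_dims[0] and would have looped 2 times.
  std::cout << "loop_time = " << loop_time << "\n";  // prints loop_time = 6
  return 0;
}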
@@ -120,4 +120,5 @@ PD_REGISTER_KERNEL(slice,
                    float,
                    int,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int64_t) {}
Reviewer: This looks like it needs a change: once common_type_value satisfies the range condition, is_out_range stays false even when value is inf or NaN.
Author: That is the intended behavior. A user may create a tensor filled with NaN via paddle.full_like, so the logic here accepts any value that is either within the representable range of the data type or is NaN/inf. It follows the GPU implementation:
Paddle/paddle/phi/kernels/gpu/full_kernel.cu
Line 87 in 5d4320b
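For context, a self-contained sketch of the kind of check being debated (assuming the GPU full kernel's approach; the function name and exact bounds handling below are illustrative, not the Paddle source):

#include <cmath>
#include <limits>
#include <type_traits>

// Hedged sketch: decide whether a fill value is out of range for dtype T.
// Mirrors the behavior discussed above: values inside T's representable range
// pass, and NaN/inf are deliberately allowed through as well.
template <typename T>
bool IsOutOfRange(double value) {
  using CommonType = typename std::common_type<double, T>::type;
  auto common_value = static_cast<CommonType>(value);
  bool is_out_range =
      common_value < static_cast<CommonType>(std::numeric_limits<T>::lowest()) ||
      common_value > static_cast<CommonType>(std::numeric_limits<T>::max());
  // NaN/inf are treated as valid fill values (e.g. paddle.full_like with NaN),
  // so they never trip the range check.
  if (std::isnan(value) || std::isinf(value)) {
    is_out_range = false;
  }
  return is_out_range;
}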