[XPU][PHI Kernels] add scatter_nd_add_grad kernel & bf16 support for slice OPs #58580

Status: Merged (6 commits, Nov 2, 2023). Changes shown from 5 commits.
17 changes: 15 additions & 2 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -649,7 +649,10 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum",
XPUKernelSet({phi::DataType::FLOAT16,
@@ -709,6 +712,11 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"scatter_nd_add_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"sampling_id",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
{"set_value",
@@ -757,10 +765,12 @@ XPUOpMap& get_kl2_ops() {
{"slice_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT32})},
{"slice",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"softmax",
@@ -854,12 +864,14 @@ XPUOpMap& get_kl2_ops() {
{"strided_slice",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"strided_slice_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT16,
phi::DataType::INT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -1008,7 +1020,8 @@ XPUOpMap& get_kl2_ops() {
{"where_index",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::BOOL,
- phi::DataType::FLOAT32})},
phi::DataType::FLOAT32,
phi::DataType::INT64})},
{"where_grad",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
17 changes: 15 additions & 2 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -627,7 +627,10 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum",
XPUKernelSet({phi::DataType::FLOAT16,
@@ -685,6 +688,11 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"scatter_nd_add_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"sampling_id",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
{"set_value",
@@ -732,10 +740,12 @@ XPUOpMap& get_kl3_ops() {
{"slice_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT32})},
{"slice",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"softmax",
@@ -830,12 +840,14 @@ XPUOpMap& get_kl3_ops() {
{"strided_slice",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"strided_slice_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT16,
phi::DataType::INT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -984,7 +996,8 @@ XPUOpMap& get_kl3_ops() {
{"where_index",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::BOOL,
- phi::DataType::FLOAT32})},
phi::DataType::FLOAT32,
phi::DataType::INT64})},
{"where_grad",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
24 changes: 12 additions & 12 deletions paddle/phi/kernels/xpu/full_kernel.cc
@@ -63,13 +63,20 @@ void FullLikeKernel(const Context& dev_ctx,
T>::type>::type;

auto common_type_value = static_cast<CommonType>(value);
bool is_out_range = true;
if (std::isinf(value) || std::isnan(value)) {
is_out_range = false;
}
if ((common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max()))) {
is_out_range = false;
}
Comment on lines +66 to +75

Contributor: This looks like it needs a change: once common_type_value satisfies the range check, is_out_range ends up false even when value is inf or NaN.

Contributor (author): That is the intended behavior. A user can call paddle.full_like with a NaN fill value; the logic here treats the value as valid if it is within the representable range of the target dtype, or if it is NaN/inf. This follows the GPU implementation.
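A minimal dygraph sketch of the behavior the author describes (hypothetical shapes and values, not part of this diff):

import numpy as np
import paddle

paddle.disable_static()
ref = paddle.empty([2, 3], dtype="float32")
# NaN (and inf) count as valid fill values here, so no InvalidArgument is raised.
t = paddle.full_like(ref, float("nan"))
assert np.isnan(t.numpy()).all()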


PADDLE_ENFORCE_EQ(
- (common_type_value >=
- static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
- (common_type_value <=
- static_cast<CommonType>(std::numeric_limits<T>::max())),
- true,
is_out_range,
false,
phi::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
@@ -79,13 +86,6 @@ void FullLikeKernel(const Context& dev_ctx,
static_cast<CommonType>(std::numeric_limits<T>::max()),
static_cast<float>(value)));

- PADDLE_ENFORCE_EQ(std::isnan(value),
- false,
- phi::errors::InvalidArgument("The filled value is NaN."));
- PADDLE_ENFORCE_EQ(std::isinf(value),
- false,
- phi::errors::InvalidArgument("The filled value is Inf."));

auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
if (out->numel() > 0) {
int r = xpu::constant(dev_ctx.x_context(),
17 changes: 3 additions & 14 deletions paddle/phi/kernels/xpu/nonzero_kernel.cc
@@ -34,13 +34,7 @@ void NonZeroKernel(const Context& dev_ctx,
int* true_num = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
int true_num_cpu;
int ret = xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel);
- PADDLE_ENFORCE_EQ(
- ret,
- XPU_SUCCESS,
- phi::errors::External(
- "XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex",
- ret,
- XPUAPIErrorMsg[ret]));
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "nonzero_count");

memory_utils::Copy(phi::CPUPlace(),
static_cast<void*>(&true_num_cpu),
@@ -58,17 +52,12 @@
auto condition_shape = phi::vectorize<int>(dims);
ret = xpu::where(
dev_ctx.x_context(), cond_data, out_data, condition_shape, true_num_cpu);
- PADDLE_ENFORCE_EQ(ret,
- XPU_SUCCESS,
- phi::errors::External(
- "XPU masked_select kernel return wrong value[%d %s]",
- ret,
- XPUAPIErrorMsg[ret]));
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "where");
}

} // namespace phi

PD_REGISTER_KERNEL(
- nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float, int64_t) {
Contributor: Does the data type also need to be registered in the op list here?

Contributor (author): nonzero appears in the op list under the name where_index; that entry has already been added.
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
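A short usage sketch for the newly registered int64 input path (hypothetical values, not part of this diff):

import paddle

x = paddle.to_tensor([0, 3, 0, 7], dtype="int64")
idx = paddle.nonzero(x)  # output dtype is forced to INT64 by SetDataType above
# idx.numpy() -> [[1], [3]]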
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/prod_kernel.cc
@@ -50,4 +50,5 @@ void ProdKernel(const Context& dev_ctx,

} // namespace phi

- PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
PD_REGISTER_KERNEL(
prod, XPU, ALL_LAYOUT, phi::ProdKernel, float, int, int64_t) {}
Contributor: Same question as above.

Contributor (author): prod appears in the op list under the name reduce_prod; that entry has already been added.
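A short usage sketch for the added integer registrations (hypothetical values, not part of this diff; after this change the int32 kernel should be dispatched on XPU):

import paddle

x = paddle.to_tensor([[1, 2], [3, 4]], dtype="int32")
y = paddle.prod(x)  # Tensor with value 24, kept in int32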

115 changes: 115 additions & 0 deletions paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
@@ -0,0 +1,115 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/scatter_nd_add_grad_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
template <typename T, typename Context>
void ScatterNdAddGradKernel(const Context &ctx,
const DenseTensor &index,
const DenseTensor &updates,
const DenseTensor &out_grad,
DenseTensor *x_grad,
DenseTensor *updates_grad) {
using XPUT = typename XPUTypeTrait<T>::Type;
int ret = xpu::SUCCESS;
const T *out_grad_data = out_grad.data<T>();
if (x_grad) {
auto *x_grad_data = ctx.template Alloc<T>(x_grad);
ret = xpu::copy<XPUT>(ctx.x_context(),
reinterpret_cast<const XPUT *>(out_grad_data),
reinterpret_cast<XPUT *>(x_grad_data),
out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
}

if (updates_grad) {
auto *updates_grad_data = ctx.template Alloc<T>(updates_grad);
if (updates_grad->numel() == 0) {
return;
}
if (index.numel() == 0) {
auto index_dims = index.dims();
auto index_dims_size = index_dims.size();
int64_t end_size = index_dims[index_dims_size - 1];
PADDLE_ENFORCE_EQ(
end_size,
0,
errors::InvalidArgument(
"Size of the last dim of the index tensor [%d] should be 0",
end_size));
auto remain_dims = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
int64_t remain_numel = phi::product(remain_dims);
int64_t updates_grad_numel = updates_grad->numel();
int64_t out_grad_numel = out_grad.numel();
PADDLE_ENFORCE_EQ(
remain_numel * out_grad_numel,
updates_grad_numel,
errors::InvalidArgument("out_grad numel[%d] * remain numel[%d] "
"should math updates_grad numel[%d]",
out_grad_numel,
remain_numel,
updates_grad_numel));
ret = xpu::broadcast<XPUT>(ctx.x_context(),
reinterpret_cast<const XPUT *>(out_grad_data),
reinterpret_cast<XPUT *>(updates_grad_data),
{1, out_grad_numel},
{remain_numel, out_grad_numel});
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast");
return;
}

auto index_shape_vec = vectorize<int64_t>(index.dims());
if (index_shape_vec.size() == 1) {
index_shape_vec.insert(index_shape_vec.begin(), 1);
}
auto out_grad_shape_vec = vectorize<int64_t>(out_grad.dims());
xpu::VectorParam<int64_t> out_grad_shape_param = {
out_grad_shape_vec.data(),
static_cast<int64_t>(out_grad_shape_vec.size()),
nullptr};

if (index.dtype() == DataType::INT32) {
ret = xpu::gather_nd<XPUT, int>(
ctx.x_context(),
reinterpret_cast<const XPUT *>(out_grad_data),
index.data<int>(),
reinterpret_cast<XPUT *>(updates_grad_data),
out_grad_shape_param,
index_shape_vec);
} else {
ret = xpu::gather_nd<XPUT, int64_t>(
ctx.x_context(),
reinterpret_cast<const XPUT *>(out_grad_data),
index.data<int64_t>(),
reinterpret_cast<XPUT *>(updates_grad_data),
out_grad_shape_param,
index_shape_vec);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd");
}
}
} // namespace phi

PD_REGISTER_KERNEL(scatter_nd_add_grad,
XPU,
ALL_LAYOUT,
phi::ScatterNdAddGradKernel,
float,
phi::dtype::float16,
int,
int64_t) {}
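For orientation, a NumPy sketch of the gradients the kernel above computes (illustrative shapes, not part of this diff):

import numpy as np

# forward: out = x + scatter(updates, index), so the backward pass is:
#   x_grad       = out_grad                   (plain copy, as in the xpu::copy call)
#   updates_grad = gather_nd(out_grad, index) (as in the xpu::gather_nd call)
out_grad = np.arange(12.0).reshape(3, 4)
index = np.array([[0], [2]])            # last dim 1: indexes rows of out_grad
updates_grad = out_grad[index[:, 0]]    # rows 0 and 2, shape (2, 4)
x_grad = out_grad.copy()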
7 changes: 5 additions & 2 deletions paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc
@@ -34,8 +34,11 @@ void ScatterNdAddKernel(const Context &ctx,
if (updates.numel() == 0) return;

if (index.numel() == 0) {
- int loop_time =
- static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
int64_t index_dims_size = index.dims().size();
int loop_time = static_cast<int>(
index_dims_size == 0 ? 1
: phi::product(phi::slice_ddim(
index.dims(), 0, index_dims_size - 1)));
Contributor (author): When the last dimension of the index tensor has length 0, the updates tensor must be indexed over all of its leading dimensions, with each slice accumulated into the output, rather than accumulating over the first dimension only.
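A NumPy sketch of that case (hypothetical shapes, not part of this diff; the index tensor's last dim has length 0):

import numpy as np

x = np.zeros((3,), dtype=np.float32)
index = np.empty((4, 2, 0), dtype=np.int64)     # last dim 0: each entry addresses all of x
updates = np.ones((4, 2, 3), dtype=np.float32)  # leading index dims + x.shape
# loop_time must be 4 * 2 (product of the leading dims), not 4 (first dim only):
out = x + updates.reshape(-1, 3).sum(axis=0)    # accumulates all 8 update slices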


for (int i = 0; i < loop_time; i++) {
r = xpu::broadcast_add<T>(ctx.x_context(),
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/slice_grad_kernel.cc
@@ -85,4 +85,5 @@ PD_REGISTER_KERNEL(slice_grad,
phi::SliceGradKernel,
float,
int,
- phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/slice_kernel.cc
@@ -120,4 +120,5 @@ PD_REGISTER_KERNEL(slice,
float,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc
@@ -163,4 +163,5 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
int,
int16_t,
float,
- phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/stride_slice_kernel.cc
@@ -171,4 +171,5 @@ PD_REGISTER_KERNEL(strided_slice_raw,
int16_t,
int64_t,
float,
- phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
17 changes: 17 additions & 0 deletions test/xpu/test_fill_any_op_xpu.py
@@ -111,6 +111,23 @@ def test_backward(self):
)


class TestFillAnyLikeOpSpecialValue(unittest.TestCase):
    def setUp(self):
        self.special_values = [float("nan"), float("+inf"), float("-inf")]
        self.dtypes = ["float32", "float16"]

    def test_dygraph_api(self):
        paddle.disable_static()
        paddle.set_device("xpu")
        for dtype in self.dtypes:
            for value in self.special_values:
                ref = paddle.empty([4, 4], dtype=dtype)
                val_pd = paddle.full_like(ref, value, dtype=dtype)
                val_np = np.full([4, 4], value, dtype=dtype)
                np.testing.assert_equal(val_pd.numpy(), val_np)
        paddle.enable_static()


support_types = get_xpu_op_support_types('fill_any')
for stype in support_types:
    create_test_class(globals(), XPUTestFillAnyOp, stype)