diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 420937bdcfd250..4c54fdb25579e5 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -649,7 +649,10 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
-    {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"reduce_prod",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::INT32,
+                   phi::DataType::INT64})},
     {"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reduce_sum",
      XPUKernelSet({phi::DataType::FLOAT16,
@@ -709,6 +712,11 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
+    {"scatter_nd_add_grad",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::INT32,
+                   phi::DataType::INT64})},
     {"sampling_id",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
     {"set_value",
@@ -757,10 +765,12 @@ XPUOpMap& get_kl2_ops() {
     {"slice_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT32})},
     {"slice",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
     {"softmax",
@@ -854,12 +864,14 @@ XPUOpMap& get_kl2_ops() {
     {"strided_slice",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT16,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
     {"strided_slice_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT16,
                    phi::DataType::INT32})},
     {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -1008,7 +1020,8 @@ XPUOpMap& get_kl2_ops() {
     {"where_index",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::BOOL,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::INT64})},
     {"where_grad",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
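Note: the same additions are mirrored in the KL3 list below. As a quick sanity check of the widened `reduce_prod` coverage, a call like the following should now dispatch to the XPU kernel for integer inputs (a sketch assuming an XPU build of Paddle; not part of this PR's tests):

```python
import paddle

paddle.set_device("xpu")
x = paddle.to_tensor([[1, 2], [3, 4]], dtype="int64")
# reduce_prod is now registered for INT32/INT64 on XPU
print(paddle.prod(x))  # expected value: 24
```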
diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc
index 435de8330cc59e..8dae7c55aede82 100644
--- a/paddle/phi/backends/xpu/xpu3_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -627,7 +627,10 @@ XPUOpMap& get_kl3_ops() {
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
-    {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"reduce_prod",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::INT32,
+                   phi::DataType::INT64})},
     {"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reduce_sum",
      XPUKernelSet({phi::DataType::FLOAT16,
@@ -685,6 +688,11 @@ XPUOpMap& get_kl3_ops() {
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
+    {"scatter_nd_add_grad",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::INT32,
+                   phi::DataType::INT64})},
     {"sampling_id",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
     {"set_value",
@@ -732,10 +740,12 @@ XPUOpMap& get_kl3_ops() {
     {"slice_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT32})},
     {"slice",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
     {"softmax",
@@ -830,12 +840,14 @@ XPUOpMap& get_kl3_ops() {
     {"strided_slice",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT16,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
     {"strided_slice_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16,
                    phi::DataType::INT16,
                    phi::DataType::INT32})},
     {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -984,7 +996,8 @@ XPUOpMap& get_kl3_ops() {
     {"where_index",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::BOOL,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::INT64})},
     {"where_grad",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc
index 4adccbb4be813c..906078629f4884 100644
--- a/paddle/phi/kernels/xpu/full_kernel.cc
+++ b/paddle/phi/kernels/xpu/full_kernel.cc
@@ -63,13 +63,20 @@ void FullLikeKernel(const Context& dev_ctx,
                               T>::type>::type;
 
   auto common_type_value = static_cast<CommonType>(value);
 
+  bool is_out_range = true;
+  if (std::isinf(value) || std::isnan(value)) {
+    is_out_range = false;
+  }
+  if ((common_type_value >=
+       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
+      (common_type_value <=
+       static_cast<CommonType>(std::numeric_limits<T>::max()))) {
+    is_out_range = false;
+  }
   PADDLE_ENFORCE_EQ(
-      (common_type_value >=
-       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
-          (common_type_value <=
-           static_cast<CommonType>(std::numeric_limits<T>::max())),
-      true,
+      is_out_range,
+      false,
       phi::errors::InvalidArgument(
           "The filled value is out of range for target type, "
           "current kernel type is %s, the range should between %f "
@@ -79,13 +86,6 @@ void FullLikeKernel(const Context& dev_ctx,
           static_cast<CommonType>(std::numeric_limits<T>::max()),
           static_cast<float>(value)));
 
-  PADDLE_ENFORCE_EQ(std::isnan(value),
-                    false,
-                    phi::errors::InvalidArgument("The filled value is NaN."));
-  PADDLE_ENFORCE_EQ(std::isinf(value),
-                    false,
-                    phi::errors::InvalidArgument("The filled value is Inf."));
-
   auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
   if (out->numel() > 0) {
     int r = xpu::constant(dev_ctx.x_context(),
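Note: the `full_kernel.cc` change above folds the NaN/Inf checks into the range test, so a non-finite fill value now simply bypasses the out-of-range error instead of being rejected. The new `TestFillAnyLikeOpSpecialValue` test near the end of this diff exercises that path; a minimal standalone reproduction would be (a sketch assuming an XPU device):

```python
import numpy as np
import paddle

paddle.set_device("xpu")
ref = paddle.empty([4, 4], dtype="float32")
# Previously raised "The filled value is NaN/Inf"; now fills as requested.
val = paddle.full_like(ref, float("inf"))
np.testing.assert_equal(val.numpy(), np.full([4, 4], np.inf, dtype="float32"))
```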
diff --git a/paddle/phi/kernels/xpu/nonzero_kernel.cc b/paddle/phi/kernels/xpu/nonzero_kernel.cc
index edfdb1e6dfe8b5..fe241965fb5c69 100644
--- a/paddle/phi/kernels/xpu/nonzero_kernel.cc
+++ b/paddle/phi/kernels/xpu/nonzero_kernel.cc
@@ -14,8 +14,7 @@
 
 #include "paddle/phi/kernels/nonzero_kernel.h"
 
-#include "paddle/phi/backends/xpu/xpu_context.h"
-#include "paddle/phi/backends/xpu/xpu_header.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 
@@ -34,13 +33,7 @@ void NonZeroKernel(const Context& dev_ctx,
   int* true_num = RAII_GUARD.alloc_l3_or_gm<int>(1);
   int true_num_cpu;
   int ret = xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel);
-  PADDLE_ENFORCE_EQ(
-      ret,
-      XPU_SUCCESS,
-      phi::errors::External(
-          "XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex",
-          ret,
-          XPUAPIErrorMsg[ret]));
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "nonzero_count");
 
   memory_utils::Copy(phi::CPUPlace(),
                      static_cast<void*>(&true_num_cpu),
@@ -58,17 +51,12 @@ void NonZeroKernel(const Context& dev_ctx,
   auto condition_shape = phi::vectorize(dims);
   ret = xpu::where(
       dev_ctx.x_context(), cond_data, out_data, condition_shape, true_num_cpu);
-  PADDLE_ENFORCE_EQ(ret,
-                    XPU_SUCCESS,
-                    phi::errors::External(
-                        "XPU masked_select kernel return wrong value[%d %s]",
-                        ret,
-                        XPUAPIErrorMsg[ret]));
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "where");
 }
 
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
-    nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
+    nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float, int64_t) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }
diff --git a/paddle/phi/kernels/xpu/prod_kernel.cc b/paddle/phi/kernels/xpu/prod_kernel.cc
index 12f32959edb317..74e58ee63a7cad 100644
--- a/paddle/phi/kernels/xpu/prod_kernel.cc
+++ b/paddle/phi/kernels/xpu/prod_kernel.cc
@@ -50,4 +50,5 @@ void ProdKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
+PD_REGISTER_KERNEL(
+    prod, XPU, ALL_LAYOUT, phi::ProdKernel, float, int, int64_t) {}
diff --git a/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
new file mode 100644
index 00000000000000..a0fd86fcc3208d
--- /dev/null
+++ b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
@@ -0,0 +1,115 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/scatter_nd_add_grad_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+template <typename T, typename Context>
+void ScatterNdAddGradKernel(const Context &ctx,
+                            const DenseTensor &index,
+                            const DenseTensor &updates,
+                            const DenseTensor &out_grad,
+                            DenseTensor *x_grad,
+                            DenseTensor *updates_grad) {
+  using XPUT = typename XPUTypeTrait<T>::Type;
+  int ret = xpu::SUCCESS;
+  const T *out_grad_data = out_grad.data<T>();
+  if (x_grad) {
+    auto *x_grad_data = ctx.template Alloc<T>(x_grad);
+    ret = xpu::copy(ctx.x_context(),
+                    reinterpret_cast<const XPUT *>(out_grad_data),
+                    reinterpret_cast<XPUT *>(x_grad_data),
+                    out_grad.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
+  }
+
+  if (updates_grad) {
+    auto *updates_grad_data = ctx.template Alloc<T>(updates_grad);
+    if (updates_grad->numel() == 0) {
+      return;
+    }
+    if (index.numel() == 0) {
+      auto index_dims = index.dims();
+      auto index_dims_size = index_dims.size();
+      int64_t end_size = index_dims[index_dims_size - 1];
+      PADDLE_ENFORCE_EQ(
+          end_size,
+          0,
+          errors::InvalidArgument(
+              "Size of the last dim of the index tensor [%d] should be 0",
+              end_size));
+      auto remain_dims = phi::slice_ddim(index_dims, 0, index_dims_size - 1);
+      int64_t remain_numel = phi::product(remain_dims);
+      int64_t updates_grad_numel = updates_grad->numel();
+      int64_t out_grad_numel = out_grad.numel();
+      PADDLE_ENFORCE_EQ(
+          remain_numel * out_grad_numel,
+          updates_grad_numel,
+          errors::InvalidArgument("out_grad numel[%d] * remain numel[%d] "
+                                  "should match updates_grad numel[%d]",
+                                  out_grad_numel,
+                                  remain_numel,
+                                  updates_grad_numel));
+      ret = xpu::broadcast(ctx.x_context(),
+                           reinterpret_cast<const XPUT *>(out_grad_data),
+                           reinterpret_cast<XPUT *>(updates_grad_data),
+                           {1, out_grad_numel},
+                           {remain_numel, out_grad_numel});
+      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast");
+      return;
+    }
+
+    auto index_shape_vec = vectorize(index.dims());
+    if (index_shape_vec.size() == 1) {
+      index_shape_vec.insert(index_shape_vec.begin(), 1);
+    }
+    auto out_grad_shape_vec = vectorize(out_grad.dims());
+    xpu::VectorParam<int64_t> out_grad_shape_param = {
+        out_grad_shape_vec.data(),
+        static_cast<int>(out_grad_shape_vec.size()),
+        nullptr};
+
+    if (index.dtype() == DataType::INT32) {
+      ret = xpu::gather_nd<XPUT, int>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUT *>(out_grad_data),
+          index.data<int>(),
+          reinterpret_cast<XPUT *>(updates_grad_data),
+          out_grad_shape_param,
+          index_shape_vec);
+    } else {
+      ret = xpu::gather_nd<XPUT, int64_t>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUT *>(out_grad_data),
+          index.data<int64_t>(),
+          reinterpret_cast<XPUT *>(updates_grad_data),
+          out_grad_shape_param,
+          index_shape_vec);
+    }
+    PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd");
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(scatter_nd_add_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ScatterNdAddGradKernel,
+                   float,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
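Note: because `scatter_nd_add` is purely additive, its backward pass is simple — `x_grad` is `out_grad` passed through unchanged, and `updates_grad` gathers `out_grad` at `index` — which is why the kernel above reduces to `xpu::copy` plus `xpu::gather_nd` (or a broadcast in the empty-index case). A NumPy sketch of the reference semantics, illustrative only:

```python
import numpy as np

# out = scatter_nd_add(x, index, updates) => gradients flow as follows:
out_grad = np.random.rand(5, 3).astype(np.float32)
index = np.array([[0], [2], [2]])        # shape [N, k] with k = 1
x_grad = out_grad.copy()                 # d(out)/d(x) is the identity
updates_grad = out_grad[tuple(index.T)]  # gather_nd: shape [N, 3]
```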
diff --git a/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc b/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc
index c760a2d0166c9d..69e40994eb92de 100644
--- a/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc
+++ b/paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc
@@ -34,8 +34,11 @@ void ScatterNdAddKernel(const Context &ctx,
   if (updates.numel() == 0) return;
 
   if (index.numel() == 0) {
-    int loop_time =
-        static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
+    int64_t index_dims_size = index.dims().size();
+    int loop_time = static_cast<int>(
+        index_dims_size == 0 ? 1
+                             : phi::product(phi::slice_ddim(
+                                   index.dims(), 0, index_dims_size - 1)));
 
     for (int i = 0; i < loop_time; i++) {
       r = xpu::broadcast_add(ctx.x_context(),
diff --git a/paddle/phi/kernels/xpu/slice_grad_kernel.cc b/paddle/phi/kernels/xpu/slice_grad_kernel.cc
index 3e054f3d8f3424..ff5a49610fc549 100644
--- a/paddle/phi/kernels/xpu/slice_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/slice_grad_kernel.cc
@@ -85,4 +85,5 @@ PD_REGISTER_KERNEL(slice_grad,
                    phi::SliceGradKernel,
                    float,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/xpu/slice_kernel.cc b/paddle/phi/kernels/xpu/slice_kernel.cc
index a9bdf477d7e134..d3c114db2411bb 100644
--- a/paddle/phi/kernels/xpu/slice_kernel.cc
+++ b/paddle/phi/kernels/xpu/slice_kernel.cc
@@ -120,4 +120,5 @@ PD_REGISTER_KERNEL(slice,
                    float,
                    int,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int64_t) {}
diff --git a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc
index fbc7a0bf6abcb8..709eeaac49546b 100644
--- a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc
@@ -163,4 +163,5 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
                    int,
                    int16_t,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/xpu/stride_slice_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_kernel.cc
index a2de8c2c8ffc10..2f026bae02fe45 100644
--- a/paddle/phi/kernels/xpu/stride_slice_kernel.cc
+++ b/paddle/phi/kernels/xpu/stride_slice_kernel.cc
@@ -171,4 +171,5 @@ PD_REGISTER_KERNEL(strided_slice_raw,
                    int16_t,
                    int64_t,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
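Note: with the `bfloat16` registrations above, slicing bf16 tensors should stay on the XPU rather than falling back to other paths. A quick smoke test might look like this (a sketch; assumes the device supports bf16 casts and that basic/stepped indexing lowers to these kernels):

```python
import paddle

paddle.set_device("xpu")
x = paddle.randn([4, 8]).astype("bfloat16")
y = x[:, 2:6]  # basic slicing -> slice kernel
z = x[::2]     # stepped slicing -> strided_slice kernel
```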
diff --git a/test/xpu/test_fill_any_op_xpu.py b/test/xpu/test_fill_any_op_xpu.py
index 2d71f78e05c341..22e493be70b07b 100644
--- a/test/xpu/test_fill_any_op_xpu.py
+++ b/test/xpu/test_fill_any_op_xpu.py
@@ -111,6 +111,23 @@ def test_backward(self):
         )
 
 
+class TestFillAnyLikeOpSpecialValue(unittest.TestCase):
+    def setUp(self):
+        self.special_values = [float("nan"), float("+inf"), float("-inf")]
+        self.dtypes = ["float32", "float16"]
+
+    def test_dygraph_api(self):
+        paddle.disable_static()
+        paddle.set_device("xpu")
+        for dtype in self.dtypes:
+            for value in self.special_values:
+                ref = paddle.empty([4, 4], dtype=dtype)
+                val_pd = paddle.full_like(ref, value, dtype=dtype)
+                val_np = np.full([4, 4], value, dtype=dtype)
+                np.testing.assert_equal(val_pd.numpy(), val_np)
+        paddle.enable_static()
+
+
 support_types = get_xpu_op_support_types('fill_any')
 for stype in support_types:
     create_test_class(globals(), XPUTestFillAnyOp, stype)
diff --git a/test/xpu/test_scatter_nd_add_op_xpu.py b/test/xpu/test_scatter_nd_add_op_xpu.py
index f303cd9ce51503..6efb4fec3b0f7f 100644
--- a/test/xpu/test_scatter_nd_add_op_xpu.py
+++ b/test/xpu/test_scatter_nd_add_op_xpu.py
@@ -91,6 +91,9 @@ def setUp(self):
         def test_check_output(self):
             self.check_output_with_place(self.place)
 
+        def test_check_grad(self):
+            self.check_grad_with_place(self.place, ['X', 'Updates'], 'Out')
+
         def init_data(self):
             self.x_np = np.random.random([100]).astype(self.dtype)
             self.index_np = np.random.randint(0, 100, [100, 1]).astype("int32")
@@ -103,8 +106,10 @@ def infer_dtype_from_inputs_outputs(self, inputs, outputs):
     class TestScatterNdAddWithEmptyIndex(TestScatterNdAdd):
         def init_data(self):
             self.x_np = np.random.random((10, 10)).astype(self.dtype)
-            self.index_np = np.array([[], []]).astype("int32")
-            self.updates_np = np.random.random((2, 10, 10)).astype(self.dtype)
+            self.index_np = np.array([[[], []], [[], []]]).astype("int32")
+            self.updates_np = np.random.random((2, 2, 10, 10)).astype(
+                self.dtype
+            )
 
     class TestScatterNdAddOpWithHighRankSame(TestScatterNdAdd):
         def init_data(self):
@@ -138,6 +143,13 @@ def init_data(self):
             update_shape = judge_update_shape(self.x_np, self.index_np)
             self.updates_np = np.random.rand(*update_shape).astype(self.dtype)
 
+    class TestScatterNdAddWithZeroDimUpdates(TestScatterNdAdd):
+        def init_data(self):
+            shape = (10,)
+            self.x_np = np.random.rand(*shape).astype(self.dtype)
+            self.index_np = np.random.randint(0, 10, [1]).astype("int32")
+            self.updates_np = np.array(np.random.rand()).astype(self.dtype)
+
 
 support_types = get_xpu_op_support_types('scatter_nd_add')
 for stype in support_types:
diff --git a/test/xpu/test_scatter_op_xpu.py b/test/xpu/test_scatter_op_xpu.py
index c8b627fce82e5c..7ff92985b34b24 100644
--- a/test/xpu/test_scatter_op_xpu.py
+++ b/test/xpu/test_scatter_op_xpu.py
@@ -145,7 +145,7 @@ def test_check_output(self):
             self.check_output_with_place(self.place)
 
         def test_check_grad(self):
-            self.check_grad_with_place(self.place, ['X'], 'Out')
+            self.check_grad_with_place(self.place, ['X', 'Updates'], 'Out')
 
 
 support_types = get_xpu_op_support_types('scatter')
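Note: the empty-index case covered by the updated `TestScatterNdAddWithEmptyIndex` degenerates to broadcast adds — when the last dim of `index` is 0, every update slice addresses all of `x` — matching the `loop_time` computation in `scatter_nd_add_kernel.cc`. A NumPy sketch of those semantics, mirroring the test shapes:

```python
import numpy as np

x = np.random.rand(10, 10).astype(np.float32)
index = np.zeros((2, 2, 0), dtype=np.int32)  # last dim 0: each slice targets all of x
updates = np.random.rand(2, 2, 10, 10).astype(np.float32)
# The XPU kernel runs loop_time = prod(index.shape[:-1]) = 4 broadcast_adds:
out = x + updates.reshape(-1, 10, 10).sum(axis=0)
```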