Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xpu: bind op scatter_nd_add. add data type for transpose2, clip & assign_value #50825

Merged
merged 2 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"assign_value", XPUKernelSet({phi::DataType::FLOAT32})},
{"assign_value",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"atan", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"atan_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down Expand Up @@ -109,10 +113,15 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32})},
{"check_finite_and_unscale",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
{"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"concat_grad",
Expand Down Expand Up @@ -524,6 +533,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32})},
{"scatter_nd_add", XPUKernelSet({phi::DataType::FLOAT32})},
{"sampling_id",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
{"set_value",
Expand Down Expand Up @@ -656,13 +666,29 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT32})},
{"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose2_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"truncated_gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"top_k", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"top_k_v2", XPUKernelSet({phi::DataType::FLOAT32})},
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/xpu/clip_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ void ClipGradKernel(const Context& ctx,
reinterpret_cast<const XPUDataType*>(out_grad.data<T>()),
reinterpret_cast<XPUDataType*>(x_grad->data<T>()),
x.numel(),
min.to<T>(),
max.to<T>());
min.to<XPUDataType>(),
max.to<XPUDataType>());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
}
} // namespace phi

PD_REGISTER_KERNEL(
clip_grad, XPU, ALL_LAYOUT, phi::ClipGradKernel, float, int) {}
clip_grad, XPU, ALL_LAYOUT, phi::ClipGradKernel, float, int64_t, int) {}
7 changes: 4 additions & 3 deletions paddle/phi/kernels/xpu/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ void ClipKernel(const Context& dev_ctx,
x_data,
out_data,
x.numel(),
min.to<float>(),
max.to<float>());
min.to<XPUDataType>(),
max.to<XPUDataType>());

PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
Expand All @@ -46,4 +46,5 @@ void ClipKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(clip, XPU, ALL_LAYOUT, phi::ClipKernel, float) {}
PD_REGISTER_KERNEL(
clip, XPU, ALL_LAYOUT, phi::ClipKernel, float, int64_t, int) {}
3 changes: 1 addition & 2 deletions paddle/phi/kernels/xpu/gather_nd_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ void GatherNdGradKernel(const Context &ctx,
phi::DataType::INT32,
phi::DataType::INT64));

int index_size =
static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
auto x_shape = phi::vectorize<int64_t>(x_grad->dims());
auto index_shape = phi::vectorize<int64_t>(index.dims());
if (index_shape.size() == 1) {
Expand All @@ -70,6 +68,7 @@ void GatherNdGradKernel(const Context &ctx,
DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);

int index_size = static_cast<int>(index.numel());
if (index_type == phi::DataType::INT32) {
auto index_data = const_cast<int *>(index.data<int>());
xpu::VectorParam<int> index_vec{
Expand Down
108 changes: 108 additions & 0 deletions paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/scatter_nd_add_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
template <typename T, typename Context>
void ScatterNdAddKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &index,
const DenseTensor &updates,
DenseTensor *out) {
const T *x_ptr = x.data<T>();
const T *updates_ptr = updates.data<T>();

T *out_ptr = ctx.template Alloc<T>(out);
int r = xpu::copy(ctx.x_context(), x_ptr, out_ptr, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");

if (updates.numel() == 0) return;

if (index.numel() == 0) {
int loop_time =
static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);

for (int i = 0; i < loop_time; i++) {
// xpu::add only support float or float16 template typename
// now, register this op only with float type
r = xpu::add<T>(ctx.x_context(),
updates_ptr + out->numel() * i,
out_ptr,
out_ptr,
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
return;
}

const phi::DataType index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s], but "
"desires to be [%s] or [%s].",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));

auto x_shape = phi::vectorize<int64_t>(x.dims());
auto index_shape = phi::vectorize<int64_t>(index.dims());
if (index_shape.size() == 1) {
index_shape.insert(index_shape.begin(), 1);
}
xpu::VectorParam<int64_t> x_vec = {
x_shape.data(), static_cast<int>(x_shape.size()), nullptr};

DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);

int index_size = static_cast<int>(index.numel());

if (index_type == phi::DataType::INT32) {
xpu::VectorParam<int> index_vec{index_cpu.data<int>(), index_size, nullptr};

r = xpu::scatter_nd<T, int>(ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
} else {
xpu::VectorParam<int64_t> index_vec{
index_cpu.data<int64_t>(), index_size, nullptr};

r = xpu::scatter_nd<T, int64_t>(ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
}

PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter_nd_add");
}
} // namespace phi

PD_REGISTER_KERNEL(
scatter_nd_add, XPU, ALL_LAYOUT, phi::ScatterNdAddKernel, float) {}
5 changes: 4 additions & 1 deletion paddle/phi/kernels/xpu/transpose_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,7 @@ PD_REGISTER_KERNEL(transpose_grad,
ALL_LAYOUT,
phi::TransposeGradKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
int64_t,
int,
bool) {}
5 changes: 4 additions & 1 deletion paddle/phi/kernels/xpu/transpose_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,7 @@ PD_REGISTER_KERNEL(transpose,
ALL_LAYOUT,
phi::TransposeKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
int64_t,
int,
bool) {}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def setUp(self):
self.outputs = {"Out": self.value}

def init_data(self):
self.value = np.random.random(size=(2, 5)).astype(self.dtype)
self.value = np.random.random(size=(2, 5)).astype(np.float32)
self.attrs["fp32_values"] = [float(v) for v in self.value.flat]

def test_forward(self):
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/fluid/tests/unittests/xpu/test_gather_nd_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,18 @@ def init_data(self):
self.inp = np.array([1, 2]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]

class XPUTestGatherNdOpMultiDimIndex1(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([2, 2]).astype("int32")
self.output = self.xnp[tuple(self.inp.T)]

class XPUTestGatherNdOpMultiDimIndex2(XPUTestGatherNdBase):
def init_data(self):
self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
self.inp = np.array([2, 2]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)]


support_types = get_xpu_op_support_types('gather_nd')
for stype in support_types:
Expand Down
Loading