diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index ce02232291e03..c5d6c36ad26ef 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -794,7 +794,10 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sqrt", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -911,24 +914,28 @@ XPUOpMap& get_kl2_ops() { {"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 7366ec185c33a..858f7189cae6d 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -768,7 +768,12 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, - {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sqrt", + XPUKernelSet({ + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, + })}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -885,24 +890,28 @@ XPUOpMap& get_kl3_ops() { {"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, {"transpose", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32, phi::DataType::BOOL})}, diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index efac9b30ae2eb..c9b1136793e5e 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -580,8 +580,13 @@ PD_REGISTER_KERNEL(leaky_relu, phi::LeakyReluKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(sqrt, + XPU, + ALL_LAYOUT, + phi::SqrtKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL( tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc index 043d2c8e3df5a..71b2187bddce1 100644 --- a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc @@ -65,6 +65,7 @@ PD_REGISTER_KERNEL(transpose_grad, phi::TransposeGradKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t, int, bool) {} diff --git a/paddle/phi/kernels/xpu/transpose_kernel.cc b/paddle/phi/kernels/xpu/transpose_kernel.cc index 398a2281dcea8..dd985ddc7ebc5 100644 --- a/paddle/phi/kernels/xpu/transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_kernel.cc @@ -60,6 +60,7 @@ PD_REGISTER_KERNEL(transpose, phi::TransposeKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t, int, bool) {}