Skip to content

Commit

Permalink
swish, swish_grad add bfloat16 datatype support for XPU (PaddlePaddle…
Browse files Browse the repository at this point in the history
  • Loading branch information
ykkk2333 authored Nov 2, 2023
1 parent bf4de60 commit e30b891
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ set(XPU_XFT_LIB_NAME "libxft.so")
set(XPU_XPTI_LIB_NAME "libxpti.so")

if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231023")
set(XPU_BASE_DATE "20231025")
endif()
set(XPU_XCCL_BASE_VERSION "1.0.53.6")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"swish_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"take_along_axis",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"tanh_grad",
Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -839,8 +839,14 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT16,
phi::DataType::INT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"swish",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"swish_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"take_along_axis",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"tanh_grad",
Expand Down
9 changes: 8 additions & 1 deletion paddle/phi/kernels/xpu/activation_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,14 @@ PD_REGISTER_KERNEL(square_grad,
float,
phi::dtype::float16) {}

PD_REGISTER_KERNEL(swish_grad,
XPU,
ALL_LAYOUT,
phi::SwishGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
Expand All @@ -710,7 +718,6 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel)
Expand Down
9 changes: 7 additions & 2 deletions paddle/phi/kernels/xpu/activation_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,13 @@ PD_REGISTER_KERNEL(
elu, XPU, ALL_LAYOUT, phi::EluKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(swish,
XPU,
ALL_LAYOUT,
phi::SwishKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(hardsigmoid,
XPU,
ALL_LAYOUT,
Expand Down

0 comments on commit e30b891

Please sign in to comment.