Skip to content

Commit

Permalink
[XPU]add fp16 kernels (#54410)
Browse files Browse the repository at this point in the history
  • Loading branch information
wz1qqx authored Jun 8, 2023
1 parent 168fac1 commit fd9c555
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 10 deletions.
6 changes: 4 additions & 2 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
Expand Down Expand Up @@ -188,7 +189,8 @@ XPUOpMap& get_kl2_ops() {
{"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"deformable_conv_v1", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"depthwise_conv2d_transpose_grad",
XPUKernelSet({phi::DataType::FLOAT32})},
{"depthwise_conv2d_transpose",
Expand Down Expand Up @@ -599,7 +601,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/activation_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ PD_REGISTER_KERNEL(swish,
#endif

#if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
PD_REGISTER_KERNEL(
relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
#endif
Expand Down
8 changes: 7 additions & 1 deletion paddle/phi/kernels/xpu/activation_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,13 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}

PD_REGISTER_KERNEL(relu6_raw,
XPU,
ALL_LAYOUT,
phi::Relu6RawKernel,
float,
phi::dtype::float16) {}

#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}

Expand All @@ -581,7 +588,6 @@ PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
17 changes: 13 additions & 4 deletions paddle/phi/kernels/xpu/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
// limitations under the License.

#include "paddle/phi/kernels/clip_kernel.h"

#include "glog/logging.h"

#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/core/kernel_registry.h"
Expand All @@ -33,8 +36,8 @@ void ClipKernel(const Context& dev_ctx,
x_data,
out_data,
x.numel(),
min.to<XPUDataType>(),
max.to<XPUDataType>());
static_cast<XPUDataType>(min.to<T>()),
static_cast<XPUDataType>(max.to<T>()));

PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
Expand All @@ -46,5 +49,11 @@ void ClipKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(
clip, XPU, ALL_LAYOUT, phi::ClipKernel, float, int64_t, int) {}
PD_REGISTER_KERNEL(clip,
XPU,
ALL_LAYOUT,
phi::ClipKernel,
float,
phi::dtype::float16,
int64_t,
int) {}
8 changes: 6 additions & 2 deletions paddle/phi/kernels/xpu/conv_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,11 @@ void Conv3DKernel(const Context& dev_ctx,

PD_REGISTER_KERNEL(
conv2d, XPU, ALL_LAYOUT, phi::ConvKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
depthwise_conv2d, XPU, ALL_LAYOUT, phi::DepthwiseConvKernel, float) {}
PD_REGISTER_KERNEL(depthwise_conv2d,
XPU,
ALL_LAYOUT,
phi::DepthwiseConvKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
conv3d, XPU, ALL_LAYOUT, phi::Conv3DKernel, float, phi::dtype::float16) {}

0 comments on commit fd9c555

Please sign in to comment.