Skip to content

Commit

Permalink
[XPU] sqrt, transpose, transpose_grad support bf16 type (#58419)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyk0314 authored Oct 27, 2023
1 parent 59b137d commit 5c7bd35
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 4 deletions.
9 changes: 8 additions & 1 deletion paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sqrt",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"square_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down Expand Up @@ -911,24 +914,28 @@ XPUOpMap& get_kl2_ops() {
{"transpose2_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose2",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
Expand Down
11 changes: 10 additions & 1 deletion paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,12 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sqrt",
XPUKernelSet({
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
})},
{"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"square_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down Expand Up @@ -885,24 +890,28 @@ XPUOpMap& get_kl3_ops() {
{"transpose2_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose2",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
Expand Down
9 changes: 7 additions & 2 deletions paddle/phi/kernels/xpu/activation_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,13 @@ PD_REGISTER_KERNEL(leaky_relu,
phi::LeakyReluKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(sqrt,
XPU,
ALL_LAYOUT,
phi::SqrtKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(
tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/transpose_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ PD_REGISTER_KERNEL(transpose_grad,
phi::TransposeGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int,
bool) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/transpose_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ PD_REGISTER_KERNEL(transpose,
phi::TransposeKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int,
bool) {}

0 comments on commit 5c7bd35

Please sign in to comment.