[XPU] support fp16 for c_embedding and c_embedding_grad (PaddlePaddle#58845)

* [XPU] support fp16 for c_embedding and c_embedding_grad

* bugfix

* minor

* add to kl3 oplist
XiaociZhang authored Nov 10, 2023
1 parent 3d58705 commit aafbad4
Showing 4 changed files with 34 additions and 20 deletions.
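
For context: c_embedding is the model-parallel embedding lookup, where each rank holds rows [start_index, start_index + height) of the global table and writes zero rows for ids owned by other ranks. A minimal CPU sketch of that forward semantics, for illustration only (the helper name and signature are hypothetical, not Paddle API):

#include <cstdint>
#include <vector>

// Illustrative CPU reference of c_embedding's forward semantics; hypothetical
// helper, not Paddle code. Each rank owns rows [start_index, start_index +
// height) of the global table; ids outside this shard produce zero rows.
template <typename T>
void CEmbeddingRef(const std::vector<T>& table,   // height x width, row-major
                   const std::vector<int64_t>& ids,
                   int64_t height,
                   int64_t width,
                   int64_t start_index,
                   std::vector<T>* out) {         // ids.size() x width
  out->assign(ids.size() * width, static_cast<T>(0));
  for (size_t i = 0; i < ids.size(); ++i) {
    const int64_t local = ids[i] - start_index;   // shift into this shard
    if (local >= 0 && local < height) {
      for (int64_t j = 0; j < width; ++j) {
        (*out)[i * width + j] = table[local * width + j];
      }
    }
  }
}

The diff below extends exactly this lookup, plus its gradient, to fp16 on XPU.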
8 changes: 4 additions & 4 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -117,8 +117,10 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"c_concat",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
-    {"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
-    {"c_embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"c_embedding",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+    {"c_embedding_grad",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"c_identity",
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
@@ -1107,8 +1109,6 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
                    phi::DataType::INT32})},
-    {"c_embedding",
-     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
   };

   return s_xpu2_kernels;
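
The op list above acts as a per-device capability table consulted at dispatch time: an (op, dtype) pair runs on XPU only if the map lists it, which is why the duplicate entry at the bottom of the KL2 list could simply be dropped once the canonical entry gained FLOAT16. A rough sketch of that gating, assuming the shapes of XPUOpMap and XPUKernelSet implied by the entries (not the actual Paddle dispatcher):

#include <string>
#include <unordered_map>
#include <unordered_set>

// Assumed shapes of XPUKernelSet/XPUOpMap, inferred from the entries above;
// a sketch of the gating idea only, not Paddle's real dispatcher.
enum class DataType { FLOAT32, FLOAT16, INT32 };

using XPUKernelSet = std::unordered_set<DataType>;
using XPUOpMap = std::unordered_map<std::string, XPUKernelSet>;

bool IsSupportedOnXPU(const XPUOpMap& ops, const std::string& op, DataType dt) {
  auto it = ops.find(op);
  return it != ops.end() && it->second.count(dt) > 0;
}

With this change, IsSupportedOnXPU(ops, "c_embedding", DataType::FLOAT16) would hold for both the KL2 and KL3 tables.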
6 changes: 4 additions & 2 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -111,8 +111,10 @@ XPUOpMap& get_kl3_ops() {
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"c_concat",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
-    {"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
-    {"c_embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"c_embedding",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+    {"c_embedding_grad",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"c_identity",
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
17 changes: 11 additions & 6 deletions paddle/phi/kernels/xpu/c_embedding_kernel.cc
@@ -27,6 +27,7 @@ void CEmbeddingKernel(const Context& dev_ctx,
                        DenseTensor* out) {
   const T* table_data = w.data<T>();
   T* output_data = dev_ctx.template Alloc<T>(out);
+  using XPUType = typename XPUTypeTrait<T>::Type;

   const int64_t height = w.dims()[0];
   const int64_t width = w.dims()[1];
@@ -41,9 +42,9 @@ void CEmbeddingKernel(const Context& dev_ctx,
   const auto& index_type = ids.dtype();
   if (index_type == phi::DataType::INT32) {
     int r = xpu::embedding(dev_ctx.x_context(),
-                           table_data,
+                           reinterpret_cast<const XPUType*>(table_data),
                            ids.data<int32_t>(),
-                           output_data,
+                           reinterpret_cast<XPUType*>(output_data),
                            height,
                            width,
                            ids.numel(),
@@ -52,9 +53,9 @@
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
   } else if (index_type == phi::DataType::INT64) {
     int r = xpu::embedding(dev_ctx.x_context(),
-                           table_data,
+                           reinterpret_cast<const XPUType*>(table_data),
                            ids.data<int64_t>(),
-                           output_data,
+                           reinterpret_cast<XPUType*>(output_data),
                            height,
                            width,
                            ids.numel(),
@@ -68,5 +69,9 @@
 }
 }  // namespace phi

-PD_REGISTER_KERNEL(c_embedding, XPU, ALL_LAYOUT, phi::CEmbeddingKernel, float) {
-}
+PD_REGISTER_KERNEL(c_embedding,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::CEmbeddingKernel,
+                   float,
+                   phi::dtype::float16) {}
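
The casts introduced above follow the XPUTypeTrait pattern: the framework's phi::dtype::float16 and the device library's half type share a bit layout but are distinct C++ types, so the kernel remaps T once via the trait and reinterpret_casts pointers at the XDNN boundary. A simplified, self-contained sketch of that pattern (stand-in types, not the real trait definitions):

// Simplified illustration of the XPUTypeTrait pattern used above; the types
// here are stand-ins, not Paddle's or XDNN's real definitions.
struct float16 { unsigned short x; };                        // framework fp16
namespace xpu_api { struct float16 { unsigned short x; }; }  // device fp16

template <typename T>
struct XPUTypeTrait { using Type = T; };        // identity for float, etc.

template <>
struct XPUTypeTrait<float16> { using Type = xpu_api::float16; };  // fp16 remap

template <typename T>
void LaunchLookup(const T* table, T* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  // Same bits, different static type: the casts satisfy the device API
  // without converting or copying any data.
  const XPUType* dev_table = reinterpret_cast<const XPUType*>(table);
  XPUType* dev_out = reinterpret_cast<XPUType*>(out);
  (void)dev_table;
  (void)dev_out;  // the device embedding call would go here
}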
23 changes: 15 additions & 8 deletions paddle/phi/kernels/xpu/c_embedding_kernel_grad.cc
@@ -31,6 +31,7 @@ void CEmbeddingGradKernel(const Context& dev_ctx,
   w_grad->Resize(w.dims());
   dev_ctx.template Alloc(w_grad, w.dtype());
   T* table_grad_data = static_cast<T*>(w_grad->data());
+  using XPUType = typename XPUTypeTrait<T>::Type;

   size_t table_t_mem_size = w.numel() * phi::SizeOf(w_grad->dtype());
   size_t table_grad_t_mem_size = w_grad->numel() * phi::SizeOf(w_grad->dtype());
@@ -40,8 +41,10 @@
           << ", table_grad_t memory_size:" << table_grad_t_mem_size
           << ", start_index:" << start_index;

-  int r = xpu::constant(
-      dev_ctx.x_context(), table_grad_data, w_grad->numel(), (T)0);
+  int r = xpu::constant(dev_ctx.x_context(),
+                        reinterpret_cast<XPUType*>(table_grad_data),
+                        w_grad->numel(),
+                        (XPUType)0);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
   const T* d_output_data = out_grad.data<T>();

@@ -51,19 +54,19 @@
   const auto& index_type = ids.dtype();
   if (index_type == phi::DataType::INT32) {
     r = xpu::embedding_grad(dev_ctx.x_context(),
-                            d_output_data,
+                            reinterpret_cast<const XPUType*>(d_output_data),
                             ids.data<int32_t>(),
-                            table_grad_data,
+                            reinterpret_cast<XPUType*>(table_grad_data),
                             height,
                             width,
                             ids.numel(),
                             -1,
                             static_cast<int32_t>(start_index));
   } else if (index_type == phi::DataType::INT64) {
     r = xpu::embedding_grad(dev_ctx.x_context(),
-                            d_output_data,
+                            reinterpret_cast<const XPUType*>(d_output_data),
                             ids.data<int64_t>(),
-                            table_grad_data,
+                            reinterpret_cast<XPUType*>(table_grad_data),
                             height,
                             width,
                             ids.numel(),
@@ -78,5 +81,9 @@

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    c_embedding_grad, XPU, ALL_LAYOUT, phi::CEmbeddingGradKernel, float) {}
+PD_REGISTER_KERNEL(c_embedding_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::CEmbeddingGradKernel,
+                   float,
+                   phi::dtype::float16) {}
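
For completeness, the backward mirrors the forward: zero the shard's table gradient (the xpu::constant call above), then scatter-add each row of the output gradient into the local row for in-shard ids. A CPU reference sketch of that semantics, illustration only (hypothetical helper, not Paddle code):

#include <cstdint>
#include <vector>

// Illustrative CPU reference of c_embedding_grad's semantics; hypothetical
// helper, not Paddle code. Zero-init the shard's gradient (the xpu::constant
// step), then accumulate output-gradient rows into local rows; ids owned by
// other shards contribute nothing here.
template <typename T>
void CEmbeddingGradRef(const std::vector<T>& d_out,  // ids.size() x width
                       const std::vector<int64_t>& ids,
                       int64_t height,
                       int64_t width,
                       int64_t start_index,
                       std::vector<T>* d_table) {    // height x width
  d_table->assign(height * width, static_cast<T>(0));
  for (size_t i = 0; i < ids.size(); ++i) {
    const int64_t local = ids[i] - start_index;
    if (local >= 0 && local < height) {
      for (int64_t j = 0; j < width; ++j) {
        (*d_table)[local * width + j] += d_out[i * width + j];  // scatter-add
      }
    }
  }
}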
