Skip to content

Commit

Permalink
merge elementwise kernels under legacy/kps (PaddlePaddle#57261)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianhaodongbd authored Sep 15, 2023
1 parent bd5a814 commit c60bb47
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 214 deletions.
51 changes: 0 additions & 51 deletions paddle/phi/kernels/legacy/kps/elementwise_add_kernel.cu

This file was deleted.

52 changes: 0 additions & 52 deletions paddle/phi/kernels/legacy/kps/elementwise_divide_kernel.cu

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,17 @@

namespace phi {

DEFINE_CUDA_ELEMENTWISE_OP(Add)

// Create the definition of Divide
DEFINE_CUDA_ELEMENTWISE_OP(Divide)

// Create the definition of Multiply
DEFINE_CUDA_ELEMENTWISE_OP(Multiply)

// Create the definition of Subtract
DEFINE_CUDA_ELEMENTWISE_OP(Subtract)

template <typename T, typename Context>
void MaximumRawKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -115,6 +126,12 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
} // namespace phi

#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(add_raw, KPS, ALL_LAYOUT, phi::AddRawKernel, float) {}
PD_REGISTER_KERNEL(divide_raw, KPS, ALL_LAYOUT, phi::DivideRawKernel, float) {}
PD_REGISTER_KERNEL(
multiply_raw, KPS, ALL_LAYOUT, phi::MultiplyRawKernel, float) {}
PD_REGISTER_KERNEL(
subtract_raw, KPS, ALL_LAYOUT, phi::SubtractRawKernel, float) {}
PD_REGISTER_KERNEL(maximum_raw, KPS, ALL_LAYOUT, phi::MaximumRawKernel, float) {
}
PD_REGISTER_KERNEL(minimum_raw, KPS, ALL_LAYOUT, phi::MinimumRawKernel, float) {
Expand All @@ -124,10 +141,67 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
elementwise_pow_raw, KPS, ALL_LAYOUT, phi::ElementwisePowRawKernel, float) {
}

#else

using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

PD_REGISTER_KERNEL(add_raw,
KPS,
ALL_LAYOUT,
phi::AddRawKernel,
float,
double,
int16_t,
int,
int64_t,
float16,
bfloat16,
complex64,
complex128) {}

PD_REGISTER_KERNEL(divide_raw,
KPS,
ALL_LAYOUT,
phi::DivideRawKernel,
float,
double,
int,
int64_t,
float16,
bfloat16,
complex64,
complex128) {}

PD_REGISTER_KERNEL(multiply_raw,
KPS,
ALL_LAYOUT,
phi::MultiplyRawKernel,
float,
double,
int,
int64_t,
bool,
float16,
complex64,
complex128,
bfloat16) {}

PD_REGISTER_KERNEL(subtract_raw,
KPS,
ALL_LAYOUT,
phi::SubtractRawKernel,
float,
double,
int16_t,
int,
int64_t,
float16,
bfloat16,
complex64,
complex128) {}

PD_REGISTER_KERNEL(maximum_raw,
KPS,
Expand Down Expand Up @@ -182,5 +256,4 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
float16,
int64_t,
bfloat16) {}

#endif
54 changes: 0 additions & 54 deletions paddle/phi/kernels/legacy/kps/elementwise_multiply_kernel.cu

This file was deleted.

54 changes: 0 additions & 54 deletions paddle/phi/kernels/legacy/kps/elementwise_subtract_kernel.cu

This file was deleted.

0 comments on commit c60bb47

Please sign in to comment.