[PTen] Rename kernel register macro (#38861)
* rename register macro

* fix erroneous change

* fix format error
chenwhql authored Jan 13, 2022
1 parent dccdc71 commit 158bf13
Showing 25 changed files with 636 additions and 1,193 deletions.
6 changes: 3 additions & 3 deletions cmake/pten_kernel.cmake
@@ -16,12 +16,12 @@
 function(kernel_declare TARGET_LIST)
   foreach(kernel_path ${TARGET_LIST})
     file(READ ${kernel_path} kernel_impl)
-    # TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL
+    # TODO(chenweihang): rename PT_REGISTER_KERNEL to PT_REGISTER_KERNEL
     # NOTE(chenweihang): now we don't recommend to use digit in kernel name
-    string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
+    string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
     if (NOT first_registry STREQUAL "")
       # parse the first kernel name
-      string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}")
+      string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}")
       string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}")
       string(REPLACE "," "" kernel_name "${kernel_name}")
       string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}")
820 changes: 132 additions & 688 deletions paddle/pten/core/kernel_registry.h

Large diffs are not rendered by default.
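For orientation while that header diff is hidden, here is the conj registration from paddle/pten/kernels/cpu/complex_kernel.cc as it reads after this commit. The annotation comments are added for this write-up (they are not in the source), and the include list is assembled from headers visible elsewhere in this diff, so it may not match the file's full include list:

#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/kernel_registry.h"

// Argument order: kernel name, backend, data layout, kernel functor, then the
// element types to instantiate. The updated regex in cmake/pten_kernel.cmake
// matches "PT_REGISTER_KERNEL(" followed by the first argument ("conj,") and
// strips the macro prefix and trailing comma to recover the kernel name.
PT_REGISTER_KERNEL(conj,
                   CPU,
                   ALL_LAYOUT,
                   pten::ConjKernel,
                   paddle::platform::complex<float>,
                   paddle::platform::complex<double>,
                   float,
                   double,
                   int,
                   int64_t) {}  // empty body: no per-kernel output overrides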

30 changes: 15 additions & 15 deletions paddle/pten/kernels/cpu/cast_kernel.cc
@@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx,
 
 } // namespace pten
 
-PT_REGISTER_CTX_KERNEL(cast,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::CastKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       int16_t,
-                       bool,
-                       uint8_t,
-                       paddle::platform::float16,
-                       paddle::platform::bfloat16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {
+PT_REGISTER_KERNEL(cast,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::CastKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   int16_t,
+                   bool,
+                   uint8_t,
+                   paddle::platform::float16,
+                   paddle::platform::bfloat16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
20 changes: 10 additions & 10 deletions paddle/pten/kernels/cpu/complex_kernel.cc
@@ -21,13 +21,13 @@
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/platform/complex.h"
 
-PT_REGISTER_CTX_KERNEL(conj,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::ConjKernel,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>,
-                       float,
-                       double,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(conj,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::ConjKernel,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
20 changes: 10 additions & 10 deletions paddle/pten/kernels/cpu/dot_grad_kernel.cc
@@ -20,13 +20,13 @@
 
 #include "paddle/fluid/platform/complex.h"
 
-PT_REGISTER_CTX_KERNEL(dot_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::DotGradKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(dot_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::DotGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
20 changes: 10 additions & 10 deletions paddle/pten/kernels/cpu/dot_kernel.cc
@@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx,
 using complex64 = ::paddle::platform::complex<float>;
 using complex128 = ::paddle::platform::complex<double>;
 
-PT_REGISTER_CTX_KERNEL(dot,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::DotKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
+PT_REGISTER_KERNEL(dot,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::DotKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
50 changes: 25 additions & 25 deletions paddle/pten/kernels/cpu/full_kernel.cc
@@ -18,29 +18,29 @@ limitations under the License. */
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/kernels/impl/full_kernel_impl.h"
 
-PT_REGISTER_CTX_KERNEL(full,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::FullKernel,
-                       float,
-                       double,
-                       uint8_t,
-                       int16_t,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16,
-                       paddle::platform::bfloat16,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(full,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::FullKernel,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::bfloat16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
 
-PT_REGISTER_CTX_KERNEL(full_like,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::FullLikeKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       bool,
-                       paddle::platform::float16) {}
+PT_REGISTER_KERNEL(full_like,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::FullLikeKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16) {}
108 changes: 54 additions & 54 deletions paddle/pten/kernels/cpu/math_kernel.cc
@@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex<double>;
 
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
 // using bfloat16 = ::paddle::platform::bfloat16;
-PT_REGISTER_CTX_KERNEL(
+PT_REGISTER_KERNEL(
     mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {}
-PT_REGISTER_CTX_KERNEL(add,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::AddKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(subtract,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::SubtractKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(divide,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::DivideKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(multiply,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MultiplyKernel,
-                       float,
-                       double,
-                       int,
-                       int64_t,
-                       bool,
-                       complex64,
-                       complex128) {}
-PT_REGISTER_CTX_KERNEL(sum,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::SumKernel,
-                       bool,
-                       float,
-                       double,
-                       paddle::platform::float16,
-                       int,
-                       int64_t,
-                       complex64,
-                       complex128) {
+PT_REGISTER_KERNEL(add,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::AddKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(subtract,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SubtractKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(divide,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::DivideKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(multiply,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MultiplyKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool,
+                   complex64,
+                   complex128) {}
+PT_REGISTER_KERNEL(sum,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SumKernel,
+                   bool,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   int,
+                   int64_t,
+                   complex64,
+                   complex128) {
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
52 changes: 26 additions & 26 deletions paddle/pten/kernels/cpu/matmul_grad_kernel.cc
@@ -19,29 +19,29 @@ limitations under the License. */
 
 #include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
 
-PT_REGISTER_CTX_KERNEL(matmul_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulGradKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
-
-PT_REGISTER_CTX_KERNEL(matmul_double_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulDoubleGradKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
-
-PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulTripleGradKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulGradKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+
+PT_REGISTER_KERNEL(matmul_double_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulDoubleGradKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+
+PT_REGISTER_KERNEL(matmul_triple_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulTripleGradKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
16 changes: 8 additions & 8 deletions paddle/pten/kernels/cpu/matmul_kernel.cc
@@ -20,11 +20,11 @@ limitations under the License. */
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
 
-PT_REGISTER_CTX_KERNEL(matmul,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::MatmulKernel,
-                       float,
-                       double,
-                       paddle::platform::complex<float>,
-                       paddle::platform::complex<double>) {}
+PT_REGISTER_KERNEL(matmul,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::MatmulKernel,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
24 changes: 12 additions & 12 deletions paddle/pten/kernels/cpu/scale_kernel.cc
@@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx,
 
 } // namespace pten
 
-PT_REGISTER_CTX_KERNEL(scale,
-                       CPU,
-                       ALL_LAYOUT,
-                       pten::ScaleKernel,
-                       float,
-                       double,
-                       paddle::platform::bfloat16,
-                       uint8_t,
-                       int8_t,
-                       int16_t,
-                       int,
-                       int64_t) {}
+PT_REGISTER_KERNEL(scale,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::ScaleKernel,
+                   float,
+                   double,
+                   paddle::platform::bfloat16,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t) {}