diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index f962c1332093ab..bc9fefb58f4527 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -16,12 +16,12 @@ function(kernel_declare TARGET_LIST) foreach(kernel_path ${TARGET_LIST}) file(READ ${kernel_path} kernel_impl) - # TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL + # NOTE(chenweihang): PT_REGISTER_CTX_KERNEL has been renamed to PT_REGISTER_KERNEL # NOTE(chenweihang): now we don't recommend to use digit in kernel name - string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") + string(REGEX MATCH "(PT_REGISTER_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") if (NOT first_registry STREQUAL "") # parse the first kernel name - string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}") + string(REPLACE "PT_REGISTER_KERNEL(" "" kernel_name "${first_registry}") string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}") string(REPLACE "," "" kernel_name "${kernel_name}") string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}") diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index f08ef4acfd9ce7..194ab52d25688a 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -213,20 +213,20 @@ struct KernelRegistrar { * pointer of the corresponding data type is automatically instantiated * during registration. * - * Note: `1TA` means `1 template argument` + * Note: `2TA` means `2 template arguments` */ #define PT_REGISTER_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) 
\ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ "PT_REGISTER_KERNEL must be called in global namespace."); \ - _PT_REGISTER_1TA_KERNEL( \ + _PT_REGISTER_2TA_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) #ifndef _WIN32 -#define _PT_REGISTER_1TA_KERNEL( \ +#define _PT_REGISTER_2TA_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ + PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel*); \ PT_KERNEL_REGISTRAR_INIT( \ @@ -252,7 +252,7 @@ struct KernelRegistrar { * * And msvc can work without template instantiation */ -#define _PT_REGISTER_1TA_KERNEL( \ +#define _PT_REGISTER_2TA_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ ::pten::Kernel*); \ @@ -268,60 +268,76 @@ struct KernelRegistrar { ::pten::Kernel* kernel) #endif -#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ - _PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ - meta_kernel_fn, \ - cpp_dtype, \ +#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \ + _PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + meta_kernel_fn, \ + backend, \ + cpp_dtype, \ __VA_ARGS__) -#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \ - PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ - (meta_kernel_fn, cpp_dtype, __VA_ARGS__) +#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \ + PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ + (meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__) -#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) 
\ - template decltype(meta_kernel_fn) meta_kernel_fn -#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) 
\ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn +#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) 
\ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) 
\ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__)) #define PT_KERNEL_REGISTRAR_INIT( \ kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ @@ -373,10 +389,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } #define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ backend, \ @@ -393,10 +410,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ backend, \ layout, \ @@ -419,10 +437,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ backend, \ layout, \ @@ -445,10 +464,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ 
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ backend, \ layout, \ @@ -471,10 +491,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ backend, \ layout, \ @@ -497,10 +518,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ backend, \ layout, \ @@ -523,10 +545,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ backend, \ layout, \ @@ -549,10 +572,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ backend, \ layout, \ @@ -575,10 +599,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, 
\ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ backend, \ layout, \ @@ -601,10 +626,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ backend, \ layout, \ @@ -627,10 +653,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ backend, \ layout, \ @@ -653,10 +680,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ backend, \ layout, \ @@ -679,10 +707,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ backend, \ layout, \ @@ -705,10 +734,11 @@ 
struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ backend, \ layout, \ @@ -731,10 +761,11 @@ struct KernelRegistrar { DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ ::pten::KernelArgsParseFunctor)>::Parse, \ + &meta_kernel_fn)>::Parse, \ args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ backend, \ layout, \ @@ -743,41 +774,6 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)) -/** PT_REGISTER_NO_TEMPLATE_KERNEL - * - * Basic Kernel register marco, used to register a no template argument kernel - * function, pass in the complete function pointe of the kernel, this - * registration macro will not do automatic template instantiation. - * - * Note: developer maybe register 2 kernel with same name, backend and diff - * layout, so the layout also need to be a part of symbol var name. If developer - * register 2 kernel with same name, backend, layout and diff dtype, he should - * use another register marco PT_REGISTER_KERNEL. 
- * - * TODO(chenweihang): remove this marco later - */ -#define PT_REGISTER_NO_TEMPLATE_KERNEL( \ - kernel_name, backend, layout, kernel_fn, dtype) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \ - "PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \ - static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel*); \ - static const ::pten::KernelRegistrar \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::pten::KernelArgsParseFunctor::Parse, \ - &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ - PT_KERNEL(kernel_fn), \ - PT_VARIADIC_KERNEL(kernel_fn)); \ - int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \ - return 0; \ - } \ - void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel* kernel) - /** PT_REGISTER_GENERAL_KERNEL * * Basic Kernel register marco, used to register a instantiated kernel function @@ -832,558 +828,6 @@ struct KernelRegistrar { ::pten::Kernel* kernel) #endif -/** PT_REGISTER_CTX_KERNEL - * - * Used for kernel registration with device context and data type as - * template parameter. - */ -#define PT_REGISTER_CTX_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \ - "PT_REGISTER_CTX_KERNEL must be called in global namespace."); \ - _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) - -#ifndef _WIN32 -#define _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) 
\ - PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \ - static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT2( \ - kernel_name, \ - backend, \ - layout, \ - &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ - void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel* kernel) -#else -#define _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel*); \ - PT_KERNEL_REGISTRAR_INIT2( \ - kernel_name, \ - backend, \ - layout, \ - &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__); \ - void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel* kernel) -#endif - -#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \ - _PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \ - meta_kernel_fn, \ - backend, \ - cpp_dtype, \ - __VA_ARGS__) - -#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \ - PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \ - (meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__) - -#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn -#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) 
\ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) 
\ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__)) - -#define PT_KERNEL_REGISTRAR_INIT2( \ - kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ - _PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \ - kernel_name, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__) - -// clang-format off - -/* The =pre-commit always treats this macro into the wrong format, - and multi-line macros cannot be skipped with NOLINT.*/ -#define _PT_KERNEL_REGISTRAR_INIT2(N, \ - kernel_name, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \ - kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - __VA_ARGS__) - -// clang-format on - -#define _PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } -#define _PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT2_15(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) 
\ - static const ::pten::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pten::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL( \ - meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - __VA_ARGS__)) - /** PT_DECLARE_KERNEL * * Used to export the symbols of the file where the kernel is located, diff --git a/paddle/pten/kernels/cpu/cast_kernel.cc b/paddle/pten/kernels/cpu/cast_kernel.cc index c6736cdd1bcf03..a0006f49a2b383 100644 --- a/paddle/pten/kernels/cpu/cast_kernel.cc +++ b/paddle/pten/kernels/cpu/cast_kernel.cc @@ -58,20 +58,20 @@ void CastKernel(const Context& dev_ctx, } // namespace pten -PT_REGISTER_CTX_KERNEL(cast, - CPU, - ALL_LAYOUT, - pten::CastKernel, - float, - double, - int, - int64_t, - int16_t, - bool, - uint8_t, - paddle::platform::float16, - paddle::platform::bfloat16, - paddle::platform::complex, - paddle::platform::complex) { +PT_REGISTER_KERNEL(cast, + CPU, + ALL_LAYOUT, + pten::CastKernel, + float, + double, + int, + int64_t, + int16_t, + bool, + uint8_t, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } diff --git a/paddle/pten/kernels/cpu/complex_kernel.cc b/paddle/pten/kernels/cpu/complex_kernel.cc index 10e7e684db3c1a..59a7577153a618 100644 --- a/paddle/pten/kernels/cpu/complex_kernel.cc +++ b/paddle/pten/kernels/cpu/complex_kernel.cc @@ -21,13 +21,13 @@ // See Note [ Why still include the fluid headers? 
] #include "paddle/fluid/platform/complex.h" -PT_REGISTER_CTX_KERNEL(conj, - CPU, - ALL_LAYOUT, - pten::ConjKernel, - paddle::platform::complex, - paddle::platform::complex, - float, - double, - int, - int64_t) {} +PT_REGISTER_KERNEL(conj, + CPU, + ALL_LAYOUT, + pten::ConjKernel, + paddle::platform::complex, + paddle::platform::complex, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cpu/dot_grad_kernel.cc b/paddle/pten/kernels/cpu/dot_grad_kernel.cc index c9d5c35e134c83..ed927f820f0e7e 100644 --- a/paddle/pten/kernels/cpu/dot_grad_kernel.cc +++ b/paddle/pten/kernels/cpu/dot_grad_kernel.cc @@ -20,13 +20,13 @@ #include "paddle/fluid/platform/complex.h" -PT_REGISTER_CTX_KERNEL(dot_grad, - CPU, - ALL_LAYOUT, - pten::DotGradKernel, - float, - double, - int, - int64_t, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(dot_grad, + CPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/dot_kernel.cc b/paddle/pten/kernels/cpu/dot_kernel.cc index 72e9e28907f909..0baf9ba0a8bdd3 100644 --- a/paddle/pten/kernels/cpu/dot_kernel.cc +++ b/paddle/pten/kernels/cpu/dot_kernel.cc @@ -49,13 +49,13 @@ void DotKernel(const Context& dev_ctx, using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_CTX_KERNEL(dot, - CPU, - ALL_LAYOUT, - pten::DotKernel, - float, - double, - int, - int64_t, - complex64, - complex128) {} +PT_REGISTER_KERNEL(dot, + CPU, + ALL_LAYOUT, + pten::DotKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/cpu/full_kernel.cc b/paddle/pten/kernels/cpu/full_kernel.cc index 1ae8001d79dc71..919471d86ac534 100644 --- a/paddle/pten/kernels/cpu/full_kernel.cc +++ b/paddle/pten/kernels/cpu/full_kernel.cc @@ -18,29 +18,29 @@ limitations under the License. 
*/ #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/kernels/impl/full_kernel_impl.h" -PT_REGISTER_CTX_KERNEL(full, - CPU, - ALL_LAYOUT, - pten::FullKernel, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::bfloat16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(full, + CPU, + ALL_LAYOUT, + pten::FullKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} -PT_REGISTER_CTX_KERNEL(full_like, - CPU, - ALL_LAYOUT, - pten::FullLikeKernel, - float, - double, - int, - int64_t, - bool, - paddle::platform::float16) {} +PT_REGISTER_KERNEL(full_like, + CPU, + ALL_LAYOUT, + pten::FullLikeKernel, + float, + double, + int, + int64_t, + bool, + paddle::platform::float16) {} diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc index be0d52355bce69..83388d0d9a80fd 100644 --- a/paddle/pten/kernels/cpu/math_kernel.cc +++ b/paddle/pten/kernels/cpu/math_kernel.cc @@ -118,60 +118,60 @@ using complex128 = ::paddle::platform::complex; // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // using bfloat16 = ::paddle::platform::bfloat16; -PT_REGISTER_CTX_KERNEL( +PT_REGISTER_KERNEL( mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} -PT_REGISTER_CTX_KERNEL(add, - CPU, - ALL_LAYOUT, - pten::AddKernel, - float, - double, - int, - int64_t, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(subtract, - CPU, - ALL_LAYOUT, - pten::SubtractKernel, - float, - double, - int, - int64_t, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(divide, - CPU, - ALL_LAYOUT, - pten::DivideKernel, - float, - double, - int, - int64_t, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(multiply, - CPU, - ALL_LAYOUT, - pten::MultiplyKernel, - float, - double, - int, - int64_t, 
- bool, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(sum, - CPU, - ALL_LAYOUT, - pten::SumKernel, - bool, - float, - double, - paddle::platform::float16, - int, - int64_t, - complex64, - complex128) { +PT_REGISTER_KERNEL(add, + CPU, + ALL_LAYOUT, + pten::AddKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} +PT_REGISTER_KERNEL(subtract, + CPU, + ALL_LAYOUT, + pten::SubtractKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} +PT_REGISTER_KERNEL(divide, + CPU, + ALL_LAYOUT, + pten::DivideKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} +PT_REGISTER_KERNEL(multiply, + CPU, + ALL_LAYOUT, + pten::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + complex64, + complex128) {} +PT_REGISTER_KERNEL(sum, + CPU, + ALL_LAYOUT, + pten::SumKernel, + bool, + float, + double, + paddle::platform::float16, + int, + int64_t, + complex64, + complex128) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } diff --git a/paddle/pten/kernels/cpu/matmul_grad_kernel.cc b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc index 5a8abb6701b0e0..4738e21573194b 100644 --- a/paddle/pten/kernels/cpu/matmul_grad_kernel.cc +++ b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc @@ -19,29 +19,29 @@ limitations under the License. 
*/ #include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" -PT_REGISTER_CTX_KERNEL(matmul_grad, - CPU, - ALL_LAYOUT, - pten::MatmulGradKernel, - float, - double, - paddle::platform::complex, - paddle::platform::complex) {} - -PT_REGISTER_CTX_KERNEL(matmul_double_grad, - CPU, - ALL_LAYOUT, - pten::MatmulDoubleGradKernel, - float, - double, - paddle::platform::complex, - paddle::platform::complex) {} - -PT_REGISTER_CTX_KERNEL(matmul_triple_grad, - CPU, - ALL_LAYOUT, - pten::MatmulTripleGradKernel, - float, - double, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(matmul_grad, + CPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(matmul_double_grad, + CPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(matmul_triple_grad, + CPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/matmul_kernel.cc b/paddle/pten/kernels/cpu/matmul_kernel.cc index edba402ec1d842..f749e9cb279792 100644 --- a/paddle/pten/kernels/cpu/matmul_kernel.cc +++ b/paddle/pten/kernels/cpu/matmul_kernel.cc @@ -20,11 +20,11 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/complex.h" #include "paddle/pten/kernels/impl/matmul_kernel_impl.h" -PT_REGISTER_CTX_KERNEL(matmul, - CPU, - ALL_LAYOUT, - pten::MatmulKernel, - float, - double, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(matmul, + CPU, + ALL_LAYOUT, + pten::MatmulKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/scale_kernel.cc b/paddle/pten/kernels/cpu/scale_kernel.cc index 0582fb87b4457d..7088bba01aa787 100644 --- a/paddle/pten/kernels/cpu/scale_kernel.cc +++ b/paddle/pten/kernels/cpu/scale_kernel.cc @@ -51,15 +51,15 @@ void ScaleKernel(const Context& dev_ctx, } // namespace pten -PT_REGISTER_CTX_KERNEL(scale, - CPU, - ALL_LAYOUT, - pten::ScaleKernel, - float, - double, - paddle::platform::bfloat16, - uint8_t, - int8_t, - int16_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(scale, + CPU, + ALL_LAYOUT, + pten::ScaleKernel, + float, + double, + paddle::platform::bfloat16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cpu/sign_kernel.cc b/paddle/pten/kernels/cpu/sign_kernel.cc index a7b62822d6e0fa..25fa2bb5fe4efb 100644 --- a/paddle/pten/kernels/cpu/sign_kernel.cc +++ b/paddle/pten/kernels/cpu/sign_kernel.cc @@ -21,5 +21,4 @@ limitations under the License. */ // See Note [ Why still include the fluid headers? 
] #include "paddle/fluid/platform/bfloat16.h" -PT_REGISTER_CTX_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) { -} +PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::SignKernel, float, double) {} diff --git a/paddle/pten/kernels/empty_kernel.cc b/paddle/pten/kernels/empty_kernel.cc index 2dd55a13e38e54..eb67ed6655f479 100644 --- a/paddle/pten/kernels/empty_kernel.cc +++ b/paddle/pten/kernels/empty_kernel.cc @@ -34,66 +34,66 @@ void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) { } // namespace pten -PT_REGISTER_CTX_KERNEL(empty, - CPU, - ALL_LAYOUT, - pten::EmptyKernel, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::bfloat16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(empty, + CPU, + ALL_LAYOUT, + pten::EmptyKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} -PT_REGISTER_CTX_KERNEL(empty_like, - CPU, - ALL_LAYOUT, - pten::EmptyLikeKernel, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::bfloat16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(empty_like, + CPU, + ALL_LAYOUT, + pten::EmptyLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_REGISTER_CTX_KERNEL(empty, - GPU, - ALL_LAYOUT, - pten::EmptyKernel, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(empty, + GPU, + ALL_LAYOUT, + pten::EmptyKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + 
bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} -PT_REGISTER_CTX_KERNEL(empty_like, - GPU, - ALL_LAYOUT, - pten::EmptyLikeKernel, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(empty_like, + GPU, + ALL_LAYOUT, + pten::EmptyLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} #endif diff --git a/paddle/pten/kernels/flatten_grad_kernel.cc b/paddle/pten/kernels/flatten_grad_kernel.cc index d6aea31748d6cf..45f3c6558d9c87 100644 --- a/paddle/pten/kernels/flatten_grad_kernel.cc +++ b/paddle/pten/kernels/flatten_grad_kernel.cc @@ -33,41 +33,41 @@ void FlattenGradKernel(const Context& dev_ctx, } // namespace pten -PT_REGISTER_CTX_KERNEL(flatten_grad, - CPU, - ALL_LAYOUT, - pten::FlattenGradKernel, - float, - double, - uint8_t, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten_grad, + CPU, + ALL_LAYOUT, + pten::FlattenGradKernel, + float, + double, + uint8_t, + int8_t, + int, + int64_t) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_REGISTER_CTX_KERNEL(flatten_grad, - GPU, - ALL_LAYOUT, - pten::FlattenGradKernel, - float, - paddle::platform::float16, - double, - uint8_t, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten_grad, + GPU, + ALL_LAYOUT, + pten::FlattenGradKernel, + float, + paddle::platform::float16, + double, + uint8_t, + int8_t, + int, + int64_t) {} #endif #ifdef PADDLE_WITH_XPU -PT_REGISTER_CTX_KERNEL(flatten_grad, - XPU, - ALL_LAYOUT, - pten::FlattenGradKernel, - float, - paddle::platform::float16, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten_grad, + XPU, + ALL_LAYOUT, + pten::FlattenGradKernel, + float, + paddle::platform::float16, + int8_t, + int, + int64_t) {} #endif diff --git a/paddle/pten/kernels/flatten_kernel.cc 
b/paddle/pten/kernels/flatten_kernel.cc index b284d3690830f7..9201a8df9d166c 100644 --- a/paddle/pten/kernels/flatten_kernel.cc +++ b/paddle/pten/kernels/flatten_kernel.cc @@ -48,72 +48,72 @@ void FlattenWithXShape(const Context& dev_ctx, } // namespace pten -PT_REGISTER_CTX_KERNEL(flatten, - CPU, - ALL_LAYOUT, - pten::FlattenKernel, - float, - double, - uint8_t, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten, + CPU, + ALL_LAYOUT, + pten::FlattenKernel, + float, + double, + uint8_t, + int8_t, + int, + int64_t) {} -PT_REGISTER_CTX_KERNEL(flatten_with_xshape, - CPU, - ALL_LAYOUT, - pten::FlattenWithXShape, - float, - double, - uint8_t, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten_with_xshape, + CPU, + ALL_LAYOUT, + pten::FlattenWithXShape, + float, + double, + uint8_t, + int8_t, + int, + int64_t) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_REGISTER_CTX_KERNEL(flatten, - GPU, - ALL_LAYOUT, - pten::FlattenKernel, - float, - paddle::platform::float16, - double, - uint8_t, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten, + GPU, + ALL_LAYOUT, + pten::FlattenKernel, + float, + paddle::platform::float16, + double, + uint8_t, + int8_t, + int, + int64_t) {} -PT_REGISTER_CTX_KERNEL(flatten_with_xshape, - GPU, - ALL_LAYOUT, - pten::FlattenWithXShape, - float, - paddle::platform::float16, - double, - uint8_t, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten_with_xshape, + GPU, + ALL_LAYOUT, + pten::FlattenWithXShape, + float, + paddle::platform::float16, + double, + uint8_t, + int8_t, + int, + int64_t) {} #endif #ifdef PADDLE_WITH_XPU -PT_REGISTER_CTX_KERNEL(flatten, - XPU, - ALL_LAYOUT, - pten::FlattenKernel, - float, - paddle::platform::float16, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten, + XPU, + ALL_LAYOUT, + pten::FlattenKernel, + float, + paddle::platform::float16, + int8_t, + int, + int64_t) {} -PT_REGISTER_CTX_KERNEL(flatten_with_xshape, - XPU, - ALL_LAYOUT, - pten::FlattenWithXShape, - 
float, - paddle::platform::float16, - int8_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(flatten_with_xshape, + XPU, + ALL_LAYOUT, + pten::FlattenWithXShape, + float, + paddle::platform::float16, + int8_t, + int, + int64_t) {} #endif diff --git a/paddle/pten/kernels/gpu/cast_kernel.cu b/paddle/pten/kernels/gpu/cast_kernel.cu index 0bbe7a3a132d1c..2f91c94ba5f75e 100644 --- a/paddle/pten/kernels/gpu/cast_kernel.cu +++ b/paddle/pten/kernels/gpu/cast_kernel.cu @@ -60,24 +60,24 @@ void CastKernel(const Context& dev_ctx, } // namespace pten -#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ - PT_REGISTER_CTX_KERNEL(cast, \ - GPU, \ - ALL_LAYOUT, \ - pten::CastKernel, \ - float, \ - double, \ - int, \ - int64_t, \ - int16_t, \ - bool, \ - uint8_t, \ - paddle::platform::float16, \ - paddle::platform::complex, \ - paddle::platform::complex, \ - ##__VA_ARGS__) { \ - kernel->OutputAt(0).SetDataType( \ - paddle::experimental::DataType::UNDEFINED); \ +#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ + PT_REGISTER_KERNEL(cast, \ + GPU, \ + ALL_LAYOUT, \ + pten::CastKernel, \ + float, \ + double, \ + int, \ + int64_t, \ + int16_t, \ + bool, \ + uint8_t, \ + paddle::platform::float16, \ + paddle::platform::complex, \ + paddle::platform::complex, \ + ##__VA_ARGS__) { \ + kernel->OutputAt(0).SetDataType( \ + paddle::experimental::DataType::UNDEFINED); \ } #if !defined(PADDLE_WITH_HIP) diff --git a/paddle/pten/kernels/gpu/complex_kernel.cu b/paddle/pten/kernels/gpu/complex_kernel.cu index 02f050f5bc838b..1c82077793e0a6 100644 --- a/paddle/pten/kernels/gpu/complex_kernel.cu +++ b/paddle/pten/kernels/gpu/complex_kernel.cu @@ -21,14 +21,14 @@ // See Note [ Why still include the fluid headers? 
] #include "paddle/fluid/platform/complex.h" -PT_REGISTER_CTX_KERNEL(conj, - GPU, - ALL_LAYOUT, - pten::ConjKernel, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex, - float, - double, - int, - int64_t) {} +PT_REGISTER_KERNEL(conj, + GPU, + ALL_LAYOUT, + pten::ConjKernel, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/gpu/dot_grad_kernel.cu b/paddle/pten/kernels/gpu/dot_grad_kernel.cu index 42af96f7c7265d..4b0d7fed4c9fd3 100644 --- a/paddle/pten/kernels/gpu/dot_grad_kernel.cu +++ b/paddle/pten/kernels/gpu/dot_grad_kernel.cu @@ -20,13 +20,13 @@ limitations under the License. */ #include "paddle/fluid/platform/complex.h" -PT_REGISTER_CTX_KERNEL(dot_grad, - GPU, - ALL_LAYOUT, - pten::DotGradKernel, - float, - double, - int, - int64_t, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(dot_grad, + GPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/dot_kernel.cu b/paddle/pten/kernels/gpu/dot_kernel.cu index 08d8f83c408dea..18bab5c15a0585 100644 --- a/paddle/pten/kernels/gpu/dot_kernel.cu +++ b/paddle/pten/kernels/gpu/dot_kernel.cu @@ -52,13 +52,13 @@ void DotKernel(const Context& dev_ctx, using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_CTX_KERNEL(dot, - GPU, - ALL_LAYOUT, - pten::DotKernel, - float, - double, - int, - int64_t, - complex64, - complex128) {} +PT_REGISTER_KERNEL(dot, + GPU, + ALL_LAYOUT, + pten::DotKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/gpu/full_kernel.cu b/paddle/pten/kernels/gpu/full_kernel.cu index ae1f8529db3de8..2f6346daa888f3 100644 --- a/paddle/pten/kernels/gpu/full_kernel.cu +++ b/paddle/pten/kernels/gpu/full_kernel.cu @@ 
-18,28 +18,28 @@ limitations under the License. */ #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/kernels/impl/full_kernel_impl.h" -PT_REGISTER_CTX_KERNEL(full, - GPU, - ALL_LAYOUT, - pten::FullKernel, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(full, + GPU, + ALL_LAYOUT, + pten::FullKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} -PT_REGISTER_CTX_KERNEL(full_like, - GPU, - ALL_LAYOUT, - pten::FullLikeKernel, - float, - double, - int, - int64_t, - bool, - paddle::platform::float16) {} +PT_REGISTER_KERNEL(full_like, + GPU, + ALL_LAYOUT, + pten::FullLikeKernel, + float, + double, + int, + int64_t, + bool, + paddle::platform::float16) {} diff --git a/paddle/pten/kernels/gpu/math_kernel.cu b/paddle/pten/kernels/gpu/math_kernel.cu index 557080638038d4..1fd085ab5fe409 100644 --- a/paddle/pten/kernels/gpu/math_kernel.cu +++ b/paddle/pten/kernels/gpu/math_kernel.cu @@ -110,64 +110,64 @@ using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_CTX_KERNEL( +PT_REGISTER_KERNEL( mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {} -PT_REGISTER_CTX_KERNEL(add, - GPU, - ALL_LAYOUT, - pten::AddKernel, - float, - double, - int, - int64_t, - float16, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(subtract, - GPU, - ALL_LAYOUT, - pten::SubtractKernel, - float, - double, - int, - int64_t, - float16, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(divide, - GPU, - ALL_LAYOUT, - pten::DivideKernel, - float, - double, - int, - int64_t, - float16, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(multiply, - GPU, - ALL_LAYOUT, - pten::MultiplyKernel, - float, - double, - int, - int64_t, - 
bool, - float16, - complex64, - complex128) {} -PT_REGISTER_CTX_KERNEL(sum, - GPU, - ALL_LAYOUT, - pten::SumKernel, - bool, - float, - double, - float16, - int, - int64_t, - complex64, - complex128) { +PT_REGISTER_KERNEL(add, + GPU, + ALL_LAYOUT, + pten::AddKernel, + float, + double, + int, + int64_t, + float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(subtract, + GPU, + ALL_LAYOUT, + pten::SubtractKernel, + float, + double, + int, + int64_t, + float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(divide, + GPU, + ALL_LAYOUT, + pten::DivideKernel, + float, + double, + int, + int64_t, + float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(multiply, + GPU, + ALL_LAYOUT, + pten::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(sum, + GPU, + ALL_LAYOUT, + pten::SumKernel, + bool, + float, + double, + float16, + int, + int64_t, + complex64, + complex128) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } diff --git a/paddle/pten/kernels/gpu/matmul_grad_kernel.cu b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu index f20c3f82c92623..993b17f6b8ed0c 100644 --- a/paddle/pten/kernels/gpu/matmul_grad_kernel.cu +++ b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu @@ -19,32 +19,32 @@ limitations under the License. 
*/ #include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" -PT_REGISTER_CTX_KERNEL(matmul_grad, - GPU, - ALL_LAYOUT, - pten::MatmulGradKernel, - float, - double, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} - -PT_REGISTER_CTX_KERNEL(matmul_double_grad, - GPU, - ALL_LAYOUT, - pten::MatmulDoubleGradKernel, - float, - double, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} - -PT_REGISTER_CTX_KERNEL(matmul_triple_grad, - GPU, - ALL_LAYOUT, - pten::MatmulTripleGradKernel, - float, - double, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(matmul_grad, + GPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(matmul_double_grad, + GPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(matmul_triple_grad, + GPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/matmul_kernel.cu b/paddle/pten/kernels/gpu/matmul_kernel.cu index debda455818a95..a3ab88913a3b6d 100644 --- a/paddle/pten/kernels/gpu/matmul_kernel.cu +++ b/paddle/pten/kernels/gpu/matmul_kernel.cu @@ -20,12 +20,12 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/complex.h" #include "paddle/pten/kernels/impl/matmul_kernel_impl.h" -PT_REGISTER_CTX_KERNEL(matmul, - GPU, - ALL_LAYOUT, - pten::MatmulKernel, - float, - double, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} +PT_REGISTER_KERNEL(matmul, + GPU, + ALL_LAYOUT, + pten::MatmulKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/scale_kernel.cu b/paddle/pten/kernels/gpu/scale_kernel.cu index ff7e2a6ed284c1..4d63701413cd67 100644 --- a/paddle/pten/kernels/gpu/scale_kernel.cu +++ b/paddle/pten/kernels/gpu/scale_kernel.cu @@ -64,15 +64,15 @@ void ScaleKernel(const ContextT& dev_ctx, } // namespace pten -PT_REGISTER_CTX_KERNEL(scale, - GPU, - ALL_LAYOUT, - pten::ScaleKernel, - float, - double, - paddle::platform::float16, - uint8_t, - int8_t, - int16_t, - int, - int64_t) {} +PT_REGISTER_KERNEL(scale, + GPU, + ALL_LAYOUT, + pten::ScaleKernel, + float, + double, + paddle::platform::float16, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/pten/kernels/gpu/sign_kernel.cu b/paddle/pten/kernels/gpu/sign_kernel.cu index e7eb7e46861c8d..16356507dc8ea9 100644 --- a/paddle/pten/kernels/gpu/sign_kernel.cu +++ b/paddle/pten/kernels/gpu/sign_kernel.cu @@ -23,5 +23,5 @@ limitations under the License. */ using float16 = paddle::platform::float16; -PT_REGISTER_CTX_KERNEL( +PT_REGISTER_KERNEL( sign, GPU, ALL_LAYOUT, pten::SignKernel, float, double, float16) {}