Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI]Unify Fluid and PHI kernel #49328

Merged
merged 16 commits into from
Feb 8, 2023
6 changes: 2 additions & 4 deletions cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ endfunction()

function(find_phi_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT
set(options "")
set(oneValueArgs "")
set(multiValueArgs "")
file(READ ${FILENAME} CONTENT)

string(
REGEX
MATCH
Expand Down Expand Up @@ -402,6 +398,7 @@ function(op_library TARGET)
set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cc_src} "REGISTER_OPERATOR" op_name)
if(NOT ${op_name} EQUAL "")
Expand Down Expand Up @@ -442,6 +439,7 @@ function(op_library TARGET)
set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
if(NOT ${op_name} EQUAL "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,9 @@ bool BuildOpFuncList(const platform::Place& place,
}

// step 5. run kernel
if (run_phi_kernel) {
if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context);
Expand All @@ -831,6 +833,12 @@ bool BuildOpFuncList(const platform::Place& place,
op_with_kernel->PhiKernelSignature(),
&phi_kernel_context);
}
} else if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
(*op_func_node.phi_kernel_)(&execution_context);
} else {
// the place of exec_ctx maybe has changed.
if (!skip_run) {
Expand Down
32 changes: 32 additions & 0 deletions paddle/fluid/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/kernel_registry.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -484,5 +485,36 @@ struct OpKernelRegistrarFunctorEx<PlaceType,
USE_OP_KERNEL(op_type)
// clang-format on

template <typename StructureKernel>
struct StructKernelImpl {
static void Compute(phi::KernelContext* ctx) {
auto exe_ctx = static_cast<paddle::framework::ExecutionContext*>(ctx);
StructureKernel().Compute(*exe_ctx);
}
};

#define PHI_STRUCTURE_KERNEL(...) \
::paddle::framework::StructKernelImpl<__VA_ARGS__>::Compute
#define PHI_STRUCTURE_VARIADIC_KERNEL(...) nullptr
#define STRUCTURE_ARG_PARSE_FUNCTOR(...) nullptr

#define STRUCTURE_KERNEL_INSTANTIATION( \
meta_kernel_structure, cpp_dtype, context) \
template class meta_kernel_structure<cpp_dtype, context>;

#define PD_REGISTER_STRUCT_KERNEL( \
kernel_name, backend, layout, meta_kernel_structure, ...) \
_PD_REGISTER_KERNEL(::phi::RegType::INNER, \
kernel_name, \
backend, \
::phi::backend##Context, \
layout, \
meta_kernel_structure, \
STRUCTURE_KERNEL_INSTANTIATION, \
STRUCTURE_ARG_PARSE_FUNCTOR, \
PHI_STRUCTURE_KERNEL, \
PHI_STRUCTURE_VARIADIC_KERNEL, \
__VA_ARGS__)

} // namespace framework
} // namespace paddle
37 changes: 25 additions & 12 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1689,15 +1689,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::string phi_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) {
kernel_signature_.reset(new phi::KernelSignature(
std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *kernel_signature_.get();
if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
} else {
kernel_signature_.reset(new phi::KernelSignature(
std::move(GetExpectedPhiKernelArgs(exe_ctx))));
}

VLOG(6) << *kernel_signature_.get();
phi_kernel_name = kernel_signature_->name;
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_);

phi_kernel_name = kernel_signature_->name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
Expand Down Expand Up @@ -1753,7 +1756,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
} else {
phi_kernel_name = kernel_signature_->name;

// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
Expand Down Expand Up @@ -1939,7 +1941,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::TracerEventType::OperatorInner,
1,
platform::EventRole::kInnerOp);
if (run_phi_kernel_) {
if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context;
if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
!need_prepare_data_) {
Expand Down Expand Up @@ -1977,6 +1980,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &phi_kernel_context);
(*phi_kernel_)(&phi_kernel_context);
}
} else if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*this, exec_scope, *dev_ctx, *runtime_ctx);
(*phi_kernel_)(&execution_context);
} else {
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
Expand Down Expand Up @@ -2147,14 +2155,18 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(

phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const {
kernel_signature_.reset(
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
std::string phi_kernel_name;
if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
} else {
kernel_signature_.reset(
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
}
VLOG(6) << *kernel_signature_.get();

phi_kernel_name = kernel_signature_->name;
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));

auto phi_kernel_name = kernel_signature_->name;
auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
phi_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_kernel_key)));
Expand Down Expand Up @@ -2616,7 +2628,8 @@ Scope* OperatorWithKernel::PrepareData(
}
};

if (run_phi_kernel_) {
if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
const auto& input_names = kernel_signature_->input_names;
const auto& input_defs = phi_kernel_->args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(),
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ limitations under the License. */

#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"

Expand Down Expand Up @@ -290,7 +291,7 @@ class OperatorBase {
const platform::Place& place) const = 0;
};

class ExecutionContext {
class ExecutionContext : public phi::KernelContext {
public:
ExecutionContext(const OperatorBase& op,
const Scope& scope,
Expand Down
48 changes: 32 additions & 16 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,23 @@ PreparedOp PrepareImpl(
kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
default_kernel_signature =
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
if (phi::KernelFactory::Instance().HasStructuredKernel(op.Type())) {
has_phi_kernel = true;
kernel_signature = *default_kernel_signature;
kernel_signature = phi::KernelSignature(op.Type().c_str());
} else {
default_kernel_signature =
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
has_phi_kernel = true;
kernel_signature = *default_kernel_signature;
}
}
}

if (has_phi_kernel) {
VLOG(6) << kernel_signature;
phi_kernel_name = kernel_signature.name;

// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
Expand Down Expand Up @@ -648,6 +654,7 @@ static void PreparedOpRunPtImpl(
const phi::KernelSignature* default_kernel_signature,
const phi::KernelSignature& kernel_signature,
const phi::Kernel& phi_kernel,
const framework::RuntimeContext& ctx,
platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
Expand Down Expand Up @@ -678,19 +685,25 @@ static void PreparedOpRunPtImpl(
1,
platform::EventRole::kInnerOp);

PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);

phi::KernelContext phi_kernel_context;
BuildDygraphPhiKernelContext<VarType>(kernel_signature,
phi_kernel,
ins,
outs,
attrs,
default_attrs,
dev_ctx,
&phi_kernel_context);
if (phi_kernel.GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
phi::KernelContext phi_kernel_context;
BuildDygraphPhiKernelContext<VarType>(kernel_signature,
phi_kernel,
ins,
outs,
attrs,
default_attrs,
dev_ctx,
&phi_kernel_context);

phi_kernel(&phi_kernel_context);
phi_kernel(&phi_kernel_context);
} else {
DygraphExecutionContext<VarType> exe_ctx(
op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs);
phi_kernel(&exe_ctx);
}
}

if (FLAGS_check_nan_inf) {
Expand Down Expand Up @@ -722,6 +735,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
ctx_,
dev_ctx_,
ins,
outs,
Expand Down Expand Up @@ -753,6 +767,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
ctx_,
dev_ctx_,
ins,
outs,
Expand Down Expand Up @@ -784,6 +799,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
ctx_,
dev_ctx_,
ins,
outs,
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/imperative/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,12 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
"This op type:`%s` is not a OperatorWithKernel, only "
"OperatorWithKernel can get KernelSignature",
type));
return phi::KernelSignature(
std::move(opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
if (phi::KernelFactory::Instance().HasStructuredKernel(type)) {
return phi::KernelSignature(op->Type().c_str());
} else {
return phi::KernelSignature(std::move(
opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
}
}

} // namespace imperative
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/custom_device_common_op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ limitations under the License. */
phi::RegType::INNER, \
#kernel_name, \
dev_type, \
DATALAYOUT(layout), \
DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \
Expand Down
Loading