Skip to content

Commit

Permalink
add use_pt_kernel Flags to control whether to use pt kernel (#13)
Browse files Browse the repository at this point in the history
* add use_pt_kernel Flags to control whether to use pt kernel

* change the default value to true for cheking pt kernels
  • Loading branch information
MingMingShangTian authored Oct 12, 2021
1 parent 9b33270 commit aa6ed57
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
4 changes: 3 additions & 1 deletion paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check);
PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0,
"number of threads for inner op");
DECLARE_bool(use_pt_kernel);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -1155,7 +1156,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// phase

// VLOG(1) << "Pt KernelFactory: " << pt::KernelFactory::Instance();
if (pt::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
if (FLAGS_use_pt_kernel &&
pt::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
if (pt_kernel_key_.get() == nullptr || pt_kernel_.get() == nullptr) {
ChoosePtKernel(*runtime_ctx, *dev_ctx);
}
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/fluid/platform/xpu/xpu_op_list.h"
#endif
DECLARE_bool(check_nan_inf);
DECLARE_bool(use_pt_kernel);

namespace paddle {
namespace imperative {
Expand Down Expand Up @@ -205,7 +206,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif

// 1. get expected kernel key
if (pt::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
if (FLAGS_use_pt_kernel &&
pt::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
auto kernel_name =
ConstructPtKernelName<VarType>(op.Type(), (*op.Info().proto_), ins);
auto inputs = BuildInputMap<VarType>(ins);
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,17 @@ PADDLE_DEFINE_EXPORTED_int32(get_host_by_name_time, 120,
PADDLE_DEFINE_EXPORTED_bool(
apply_pass_to_program, false,
"It controls whether to apply IR pass to program when using Fleet APIs");

/**
* Pt kernel related FLAG
* Name: FLAGS_use_pt_kernel
* Since Version: 2.2.0
* Value Range: bool, default=false
* Example: FLAGS_use_pt_kernel=true would use the pt kernel to compute in the
* Op.
* Note:
*/
// TODO(chentianyu03): change default value to false before merge into develop
// branch
PADDLE_DEFINE_EXPORTED_bool(use_pt_kernel, true,
"It controls whether to use pt kernel");

0 comments on commit aa6ed57

Please sign in to comment.