Refactor op info parser (#54859)

* add kernel dialect

* change DenseTensorTypeStorage to DenseTensorType

* add test case

* add first pd_op to kernel dialect

* lower pd op to kernel dialect

* update

* update

* remove useless code

* add attribute print test

* fix bug

* update

* update

* update

* update

* polish code

* fix bug

* polish code and add python test

* add test

* fix test error

* add env flag

* fix bug

* revert test env

* change cc_test_old to cc_test

* fix build_static bug

* fix type test error

* update cmake

* disable test in windows

* update

* update

* fix bug

* split file

* fix conflict

* polish code and fix conflict

* support place transformer

* fix bug

* add gpu flags

* fix with cuda macro

* add fetch kernel

* support fetch var in new ir

* fix bug

* polish code

* change array equal to np.testing

* support feed in new ir

* update

* fix bug

* try to hack combine op

* add scope guard

* revert atan2 op

* add scope guard

* update

* polish code

* update

* refactor build kernel context

* fix unittest bug

* polish code

* use original order

* remove useless code

* polish code

* fix bug
phlrain authored Jun 29, 2023
1 parent b94b3ac commit f18d538
Showing 11 changed files with 268 additions and 282 deletions.
@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
#include "paddle/fluid/memory/stats.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
@@ -951,31 +952,27 @@ void BuildOpFuncList(
auto attr_map = (*it)->attributes();

auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
op_func_node.phi_op_name_ = op_name;

if (op_name == "builtin.combine" || op_name == "pd.feed") {
VLOG(6) << "skip process " << op_name;
continue;
}

op_func_node.phi_op_name_ = op_name;

::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);

auto impl =
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
auto yaml_info = impl->get_op_info_();

auto attr_info = std::get<1>(yaml_info);

op_func_node.infer_meta_interface_ =
op_info.GetInterfaceImpl<paddle::dialect::InferMetaInterface>();

VLOG(6) << "op name" << op_func_node.phi_op_name_;

dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_());
::ir::BuildInferMetaContext((*it),
value_2_name_map,
scope,
yaml_info,
op_yaml_info_parser,
&(op_func_node.infer_meta_context_));

auto kernel_name =
@@ -996,7 +993,7 @@ void BuildOpFuncList(
::ir::BuildPhiKernelContext((*it),
value_2_name_map,
scope,
yaml_info,
op_yaml_info_parser,
&(op_func_node.kernel_context_),
&(op_func_node.input_index),
&(op_func_node.output_index));
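For readability, here is the new call pattern from the hunk above gathered in one place: the raw OpInfoTuple is no longer threaded through the context builders; instead a dialect::OpYamlInfoParser is built once per op and passed to both. This is a condensed sketch assembled from the diff, not a standalone program — op_info, value_2_name_map, scope, and op_func_node are the surrounding locals of BuildOpFuncList.

    // Condensed from the hunk above; assumes the surrounding BuildOpFuncList scope.
    auto impl = op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
    dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_());

    // Both context builders now take the parser instead of the raw OpInfoTuple.
    ::ir::BuildInferMetaContext((*it),
                                value_2_name_map,
                                scope,
                                op_yaml_info_parser,
                                &(op_func_node.infer_meta_context_));

    ::ir::BuildPhiKernelContext((*it),
                                value_2_name_map,
                                scope,
                                op_yaml_info_parser,
                                &(op_func_node.kernel_context_),
                                &(op_func_node.input_index),
                                &(op_func_node.output_index));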
16 changes: 5 additions & 11 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -116,7 +116,9 @@ void NewIRInterpreter::RunImpl() {
// &&
// (sync_op_num_ == 0)) {
VLOG(4) << "Tracing Instruction List";

TraceInstructionList(vec_instruction_);

// } else {
// VLOG(4) << "Non-tracing";
// // For the program that only run once, it is no need to
@@ -938,15 +940,6 @@ void NewIRInterpreter::RunOperator(const Instruction& instr_node) {
}

void NewIRInterpreter::RunInstruction(const Instruction& instr_node) {
VLOG(5) << __func__ << " OP id:" << instr_node.Id()
<< " name:" << instr_node.OpBase()->Type() << " type:"
<< (instr_node.KernelType() == OpFuncType::kCpuSync
? "kCpuSync"
: (instr_node.KernelType() == OpFuncType::kGpuSync
? "kGpuSync"
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();

OperatorBase* op = nullptr;
if (instr_node.OpBaseValid()) {
op = instr_node.OpBase();
@@ -1377,8 +1370,9 @@ void NewIRInterpreter::TraceInstructionList(
}
}

for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
auto instr_id = trace_execute_order_[idx];
// TODO(phlrain) use original order for now, use better dependency
for (size_t instr_id = 0; instr_id < vec_instruction_.size(); ++instr_id) {
/// auto instr_id = trace_execute_order_[idx];
auto& instr_node = vec_instruction_.at(instr_id);

RunInstruction(instr_node);
31 changes: 28 additions & 3 deletions paddle/fluid/ir/interface/op_yaml_info_parser.cc
@@ -22,7 +22,7 @@ OpYamlInfoParser::OpYamlInfoParser(const OpInfoTuple& op_info_tuple)
parse();
}

bool OpYamlInfoParser::IsTensorArrtibute(size_t index) const {
bool OpYamlInfoParser::IsTensorAttribute(size_t index) const {
PADDLE_ENFORCE_LT(
index,
InputInfo().size(),
@@ -48,6 +48,21 @@ const std::string& OpYamlInfoParser::AttrTypeName(
return it->second.type_name;
}

const std::string& OpYamlInfoParser::TensorAttrTypeName(
const std::string& name) const {
auto it = map_input_info_.find(name);

PADDLE_ENFORCE_NE(it,
map_input_info_.end(),
phi::errors::NotFound("Not found [%s] in input map", name));

PADDLE_ENFORCE_EQ(
it->second.is_mutable_attribute,
true,
phi::errors::PreconditionNotMet("[%s] MUST be a tensor attribute", name));
return it->second.type_name;
}

const std::vector<std::string>& OpYamlInfoParser::InferMetaTensorParams()
const {
return vec_infer_meta_tensor_params_;
@@ -62,6 +77,14 @@ const std::vector<std::string>& OpYamlInfoParser::KernelFnAttrParams() const {
return vec_kernel_fn_attr_params_;
}

const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
return std::get<3>(op_info_tuple_);
}

const std::map<std::string, int>& OpYamlInfoParser::Name2Id() const {
return map_name2id_;
}

void OpYamlInfoParser::parse() {
auto input_info = std::get<0>(op_info_tuple_);

@@ -91,15 +114,17 @@ void OpYamlInfoParser::parse() {
auto runtime_info = std::get<3>(op_info_tuple_);

for (auto& name : runtime_info.infer_meta_param) {
if (map_name2id_.count(name)) {
if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_infer_meta_tensor_params_.push_back(name);
} else {
vec_infer_meta_attr_params_.push_back(name);
}
}

for (auto& name : runtime_info.kernel_param) {
if (map_name2id_.count(name)) {
if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_kernel_fn_tensor_params_.push_back(name);
} else {
vec_kernel_fn_attr_params_.push_back(name);
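The behavioral change in parse() is the routing rule shown above: a runtime parameter name counts as a tensor argument only if it names an input that is not a mutable attribute; mutable-attribute inputs fall through to the attribute list instead. Below is a minimal standalone illustration of that rule with toy types — not Paddle code, just the classification logic in isolation.

    #include <map>
    #include <string>
    #include <vector>

    // Toy stand-in for OpInputInfo; only the flag relevant to the rule.
    struct ToyInputInfo {
      bool is_mutable_attribute = false;
    };

    // Mirrors the split applied to both infer_meta_param and kernel_param:
    // inputs that are not mutable attributes become tensor params, everything
    // else becomes an attribute param.
    void SplitParams(const std::vector<std::string>& param_names,
                     const std::map<std::string, ToyInputInfo>& input_info,
                     std::vector<std::string>* tensor_params,
                     std::vector<std::string>* attr_params) {
      for (const auto& name : param_names) {
        auto it = input_info.find(name);
        if (it != input_info.end() && !it->second.is_mutable_attribute) {
          tensor_params->push_back(name);
        } else {
          attr_params->push_back(name);
        }
      }
    }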
7 changes: 5 additions & 2 deletions paddle/fluid/ir/interface/op_yaml_info_parser.h
@@ -25,23 +25,26 @@ class OpYamlInfoParser {

explicit OpYamlInfoParser(const OpInfoTuple& op_info_tuple);

bool IsTensorArrtibute(size_t index) const;
bool IsTensorAttribute(size_t index) const;
size_t InputTensorNumber() const;

const std::string& AttrTypeName(const std::string& name) const;
const std::string& TensorAttrTypeName(const std::string& name) const;

const std::vector<std::string>& InferMetaTensorParams() const;
const std::vector<std::string>& InferMetaAttrParams() const;
const std::vector<std::string>& KernelFnTensorParams() const;
const std::vector<std::string>& KernelFnAttrParams() const;
const OpRunTimeInfo& OpRuntimeInfo() const;
const std::map<std::string, int>& Name2Id() const;

private:
void parse();
inline const std::vector<OpInputInfo>& InputInfo() const {
return std::get<0>(op_info_tuple_);
}

const OpInfoTuple& op_info_tuple_;
OpInfoTuple op_info_tuple_;

std::map<std::string, int> map_name2id_;

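Putting the header together, a hypothetical usage sketch of the refactored parser follows. Method names are taken from the declarations above; the namespace qualification and the origin of op_info_tuple are assumptions (it would normally come from an op's OpYamlInfoInterface, as in the BuildOpFuncList hunk earlier in this commit).

    #include "paddle/fluid/ir/interface/op_yaml_info_parser.h"

    // Hypothetical caller, sketched against the declarations above.
    void InspectOp(const paddle::dialect::OpInfoTuple& op_info_tuple) {
      paddle::dialect::OpYamlInfoParser parser(op_info_tuple);

      // Tensor/attribute split consumed by the infer-meta and kernel context builders.
      const auto& kernel_tensor_params = parser.KernelFnTensorParams();
      const auto& kernel_attr_params = parser.KernelFnAttrParams();

      // Accessors added by this commit: kernel runtime info and the
      // input name -> index map.
      const auto& runtime_info = parser.OpRuntimeInfo();
      const auto& name2id = parser.Name2Id();

      (void)kernel_tensor_params;
      (void)kernel_attr_params;
      (void)runtime_info;
      (void)name2id;
    }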
2 changes: 1 addition & 1 deletion paddle/fluid/ir/pass/CMakeLists.txt
@@ -4,4 +4,4 @@ file(GLOB PD_PASS_SRCS "*.cc")
cc_library(
pd_op_to_kernel_pass
SRCS ${PD_PASS_SRCS}
DEPS ir phi_utils)
DEPS ir phi_utils pd_interface)
