Skip to content

Commit

Permalink
[IR] Support inplace execute logic for NewIrInterpreter (#55210)
Browse files Browse the repository at this point in the history
* add inplace interface

* support inplace

* refine code

* fix bug

* fix bug

* refien code
  • Loading branch information
zhangbo9674 authored Jul 10, 2023
1 parent 5f00305 commit e8cba1c
Show file tree
Hide file tree
Showing 14 changed files with 305 additions and 84 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(interface)
add_subdirectory(trait)
add_subdirectory(dialect)
add_subdirectory(transforms)
add_subdirectory(phi_kernel_adaptor)
2 changes: 1 addition & 1 deletion paddle/fluid/ir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ file(GLOB PD_DIALECT_SRCS "*.cc")
cc_library(
pd_dialect
SRCS ${PD_DIALECT_SRCS} ${op_source_file}
DEPS framework_proto phi phi_utils pd_interface ir)
DEPS framework_proto phi phi_utils pd_interface pd_trait ir)
target_include_directories(pd_dialect PRIVATE ${PD_DIALECT_BINARY_DIR})
5 changes: 5 additions & 0 deletions paddle/fluid/ir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/trait/inplace.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
Expand Down Expand Up @@ -708,6 +709,10 @@ def OpGenerator(
op_interfaces_str = ""
if len(op_interfaces) > 0:
op_interfaces_str = "," + ",".join(op_interfaces)

if op_name[-1] == "_":
op_traits += ["InplaceTrait"]

op_traits_str = ""
if len(op_traits) > 0:
op_traits_str = "," + ",".join(op_traits)
Expand Down
40 changes: 31 additions & 9 deletions paddle/fluid/ir/interface/op_yaml_info_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,30 @@ const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
return std::get<3>(op_info_tuple_);
}

const std::map<std::string, int>& OpYamlInfoParser::Name2Id() const {
return name2id_;
const std::map<std::string, int>& OpYamlInfoParser::InputName2Id() const {
return input_name2id_;
}

bool OpYamlInfoParser::HasInplace(const std::string& out_name) const {
auto inplace_info = std::get<3>(op_info_tuple_).inplace;
for (size_t i = 0; i < inplace_info.size(); i++) {
if (out_name == inplace_info[i].first) {
return true;
}
}
return false;
}

const std::string& OpYamlInfoParser::InplaceName(
const std::string& out_name) const {
auto inplace_info = std::get<3>(op_info_tuple_).inplace;
for (size_t i = 0; i < inplace_info.size(); i++) {
if (out_name == inplace_info[i].first) {
return inplace_info[i].second;
}
}
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Can not find inplace input of [%s].", out_name));
}

void OpYamlInfoParser::parse() {
Expand All @@ -94,38 +116,38 @@ void OpYamlInfoParser::parse() {
int start_index = 0;

for (size_t i = 0; i < input_info.size(); ++i) {
name2id_[input_info[i].name] = start_index++;

input_name2id_[input_info[i].name] = start_index++;
input_name_list_.push_back(input_info[i].name);
input_info_[input_info[i].name] = input_info[i];
if (!input_info[i].is_mutable_attribute) {
input_tensor_number_++;
}

input_info_[input_info[i].name] = input_info[i];
}

auto attribute_info = std::get<1>(op_info_tuple_);
for (size_t i = 0; i < attribute_info.size(); ++i) {
attribute_name_list_.push_back(attribute_info[i].name);
attr_info_[attribute_info[i].name] = attribute_info[i];
}

auto output_info = std::get<2>(op_info_tuple_);

for (size_t i = 0; i < output_info.size(); ++i) {
output_name_list_.push_back(output_info[i].name);
output_info_[output_info[i].name] = output_info[i];
}

auto runtime_info = std::get<3>(op_info_tuple_);

for (auto& name : runtime_info.infer_meta_param) {
if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
if (input_name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
infer_meta_tensor_params_.push_back(name);
} else {
infer_meta_attr_params_.push_back(name);
}
}

for (auto& name : runtime_info.kernel_param) {
if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
if (input_name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
kernel_fn_tensor_params_.push_back(name);
} else {
kernel_fn_attr_params_.push_back(name);
Expand Down
31 changes: 26 additions & 5 deletions paddle/fluid/ir/interface/op_yaml_info_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,21 @@ class OpYamlInfoParser {
const std::vector<std::string>& TensorParams(bool is_kernel = false) const;
const std::vector<std::string>& AttrParams(bool is_kernel = false) const;
const OpRunTimeInfo& OpRuntimeInfo() const;
const std::map<std::string, int>& Name2Id() const;
const std::map<std::string, int>& InputName2Id() const;

const std::vector<std::string>& InputNames() const {
return input_name_list_;
}
const std::vector<std::string>& AttributeNames() const {
return attribute_name_list_;
}
const std::vector<std::string>& OutputNames() const {
return output_name_list_;
}

bool HasInplace(const std::string& out_name) const;

const std::string& InplaceName(const std::string& out_name) const;

private:
void parse();
Expand All @@ -44,18 +58,25 @@ class OpYamlInfoParser {

OpInfoTuple op_info_tuple_;

std::map<std::string, int> name2id_;

// input info
std::map<std::string, int> input_name2id_;
std::vector<std::string> input_name_list_;
std::map<std::string, OpInputInfo> input_info_;
int input_tensor_number_{0};

// attribute info
std::vector<std::string> attribute_name_list_;
std::map<std::string, OpAttributeInfo> attr_info_;

// output info
std::vector<std::string> output_name_list_;
std::map<std::string, OpOutputInfo> output_info_;

// runtime info
std::vector<std::string> infer_meta_tensor_params_;
std::vector<std::string> infer_meta_attr_params_;
std::vector<std::string> kernel_fn_tensor_params_;
std::vector<std::string> kernel_fn_attr_params_;

int input_tensor_number_{0};
};

} // namespace dialect
Expand Down
Loading

0 comments on commit e8cba1c

Please sign in to comment.