Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NewIR]new ir support yaml backend config #56570

Merged
merged 5 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions paddle/fluid/ir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{inplace}}}, {{{view}}});

paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}
"""
Expand Down Expand Up @@ -1013,6 +1012,7 @@ def OpGenerator(
kernel_func_str = ""
kernel_param_str = ""
kernel_key_dtype = ""
kernel_key_backend = ""
if op_kernel_map is not None:
kernel_func_str = '", "'.join(op_kernel_map['func'])
kernel_param_str = '", "'.join(op_kernel_map['param'])
Expand All @@ -1022,6 +1022,12 @@ def OpGenerator(
)
if kernel_key_dtype != "":
kernel_key_dtype = '"' + kernel_key_dtype + '"'
if 'backend' in op_kernel_map and op_kernel_map['backend']:
kernel_key_backend = '", "'.join(
op_kernel_map['backend']['candidates']
)
if kernel_key_backend != "":
kernel_key_backend = '"' + kernel_key_backend + '"'

inplace_str = ""
view_str = ""
Expand All @@ -1045,6 +1051,7 @@ def OpGenerator(
kernel_func=kernel_func_str,
kernel_param=kernel_param_str,
kernel_key_dtype=kernel_key_dtype,
kernel_key_backend=kernel_key_backend,
inplace=inplace_str,
view=view_str,
origin_op_name=op_info.op_yaml_item['name'],
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ OpInfoTuple AddNOp::GetOpInfo() {
std::vector<paddle::dialect::OpOutputInfo> outputs = {
OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)};
paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {});
OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}, {});

return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,23 @@ struct OpRunTimeInfo {
std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param;
std::vector<std::string> kernel_key_dtype;
std::vector<std::string> kernel_key_backend;
std::vector<std::pair<std::string, std::string>> inplace;
std::vector<std::pair<std::string, std::string>> view;
OpRunTimeInfo(const std::string& infer_meta_func,
const std::vector<std::string>& infer_meta_param,
const std::vector<std::string>& kernel_func,
const std::vector<std::string>& kernel_param,
const std::vector<std::string>& dtype,
const std::vector<std::string>& backend,
const std::vector<std::pair<std::string, std::string>>& inplace,
const std::vector<std::pair<std::string, std::string>>& view)
: infer_meta_func(infer_meta_func),
infer_meta_param(infer_meta_param),
kernel_func(kernel_func),
kernel_param(kernel_param),
kernel_key_dtype(dtype),
kernel_key_backend(backend),
inplace(inplace),
view(view) {}
};
Expand Down
209 changes: 149 additions & 60 deletions paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,150 @@ ir::OpResult AddPlaceTransferOp(ir::OpResult in,
}
}

// Resolve the kernel data type for `op` from its YAML-declared
// `kernel_key_dtype` slots. Slots are tried in priority order
// (yaml: `data_type : dtype > x`); the first slot that yields a defined
// dtype wins. A slot may be:
//   - a literal dtype name present in Str2PhiDataType,
//   - an input name (dtype read from the lowered value's type), or
//   - an attribute name holding a paddle::dialect::DataTypeAttribute.
// Returns phi::DataType::UNDEFINED when no slot resolves.
phi::DataType GetKernelDataTypeByYamlInfo(
    const ir::Operation* op,
    const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
    const dialect::OpYamlInfoParser* op_info_parser) {
  auto& attr_map = op->attributes();
  auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype;
  phi::DataType kernel_data_type = phi::DataType::UNDEFINED;

  for (size_t i = 0; i < data_type_info.size(); ++i) {
    auto slot_name = data_type_info[i];
    auto& input_map = op_info_parser->InputName2Id();

    auto find_it = Str2PhiDataType.find(slot_name);
    if (find_it != Str2PhiDataType.end()) {
      // Slot is a literal dtype name, e.g. "float32".
      kernel_data_type = find_it->second;
    } else if (input_map.count(slot_name)) {
      // parse from input
      int in_index = input_map.at(slot_name);
      // NOTE(review): the type is read from the already-lowered OpResult in
      // map_value_pair rather than from the original operand — confirm this
      // is required (the Allocated* casts below depend on it).
      auto type = map_value_pair.at(op->operand_source(in_index)).type();

      if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
        kernel_data_type = TransToPhiDataType(
            type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>().dtype());
      } else if (type.isa<ir::VectorType>()) {
        auto vec_data = type.dyn_cast<ir::VectorType>().data();
        if (vec_data.empty()) {
          kernel_data_type = phi::DataType::UNDEFINED;
        } else {
          // Use the first element's dtype for the whole vector input.
          if (vec_data[0].isa<paddle::dialect::AllocatedDenseTensorType>()) {
            kernel_data_type = TransToPhiDataType(
                vec_data[0]
                    .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
                    .dtype());
          } else {
            PADDLE_THROW(phi::errors::Unimplemented(
                "Only support DenseTensorType in vector"));
          }
        }
      } else if (type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
        kernel_data_type = TransToPhiDataType(
            type.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
                .dtype());
      } else {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Only support DenseTensorType, SelectedRows, VectorType"));
      }

    } else {
      // parse from attribute
      PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
                        true,
                        phi::errors::PreconditionNotMet(
                            "[%s] MUST in attribute map", slot_name));

      auto attr_type = op_info_parser->AttrTypeName(slot_name);
      PADDLE_ENFORCE_EQ(attr_type,
                        "paddle::dialect::DataTypeAttribute",
                        phi::errors::PreconditionNotMet(
                            "Type of [%s] should be DataType", slot_name));
      kernel_data_type = attr_map.at(slot_name)
                             .dyn_cast<paddle::dialect::DataTypeAttribute>()
                             .data();
    }

    if (kernel_data_type != phi::DataType::UNDEFINED) {
      // In yaml definition, data type have an order
      // like: data_type : dtype > x
      // Should break when found a defined data type
      break;
    }
  }

  return kernel_data_type;
}

// Resolve the kernel backend for `op` from its YAML-declared
// `kernel_key_backend` slots. Slots are tried in priority order
// (yaml: `backend : place > x`); the first slot that yields a defined
// backend wins. A slot may be an input name (backend parsed from the
// lowered value's place) or an attribute name holding a
// paddle::dialect::PlaceAttribute.
// Returns phi::Backend::UNDEFINED when no slot resolves.
// NOTE(review): this mirrors GetKernelDataTypeByYamlInfo almost line for
// line; the two could be merged into one templated helper later.
phi::Backend GetKernelBackendByYamlInfo(
    const ir::Operation* op,
    const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
    const dialect::OpYamlInfoParser* op_info_parser) {
  auto& attr_map = op->attributes();
  auto& backend_info = op_info_parser->OpRuntimeInfo().kernel_key_backend;
  phi::Backend kernel_backend = phi::Backend::UNDEFINED;
  for (size_t i = 0; i < backend_info.size(); ++i) {
    auto slot_name = backend_info[i];
    auto& input_map = op_info_parser->InputName2Id();

    if (input_map.count(slot_name)) {
      // parse from input
      int in_index = input_map.at(slot_name);
      auto type = map_value_pair.at(op->operand_source(in_index)).type();

      if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
        kernel_backend = paddle::experimental::ParseBackend(
            type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>().place());
      } else if (type.isa<ir::VectorType>()) {
        auto vec_data = type.dyn_cast<ir::VectorType>().data();
        if (vec_data.empty()) {
          kernel_backend = phi::Backend::UNDEFINED;
        } else {
          // Use the first element's place for the whole vector input.
          if (vec_data[0].isa<paddle::dialect::AllocatedDenseTensorType>()) {
            kernel_backend = paddle::experimental::ParseBackend(
                vec_data[0]
                    .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
                    .place());
          } else {
            PADDLE_THROW(phi::errors::Unimplemented(
                "Only support DenseTensorType in vector"));
          }
        }
      } else if (type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
        kernel_backend = paddle::experimental::ParseBackend(
            type.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
                .place());
      } else {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Only support DenseTensorType, SelectedRows, VectorType"));
      }

    } else {
      // parse from attribute
      PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
                        true,
                        phi::errors::PreconditionNotMet(
                            "[%s] MUST in attribute map", slot_name));

      auto attr_type = op_info_parser->AttrTypeName(slot_name);
      // Fixed error message: this enforces PlaceAttribute, so the message
      // should say Place, not DataType.
      PADDLE_ENFORCE_EQ(attr_type,
                        "paddle::dialect::PlaceAttribute",
                        phi::errors::PreconditionNotMet(
                            "Type of [%s] should be Place", slot_name));
      kernel_backend = paddle::experimental::ParseBackend(
          attr_map.at(slot_name)
              .dyn_cast<paddle::dialect::PlaceAttribute>()
              .data());
    }
    if (kernel_backend != phi::Backend::UNDEFINED) {
      // In yaml definition, backend have an order
      // like: backend : place > x
      // Should break when found a defined backend
      break;
    }
  }

  return kernel_backend;
}

phi::KernelKey GetKernelKey(
ir::Operation* op,
const phi::Place& place,
Expand Down Expand Up @@ -245,66 +389,11 @@ phi::KernelKey GetKernelKey(
// only suppurt non vector input for now
int tensor_input_number = op_info_parser->InputTensorNumber();

auto attr_map = op->attributes();
auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype;

if (!data_type_info.empty()) {
// only support single input and attribute
auto slot_name = data_type_info[0];
auto& input_map = op_info_parser->InputName2Id();

auto find_it = Str2PhiDataType.find(slot_name);
if (find_it != Str2PhiDataType.end()) {
kernel_data_type = find_it->second;
} else if (input_map.count(slot_name)) {
// parse from input
int in_index = input_map.at(slot_name);
auto type = map_value_pair.at(op->operand_source(in_index)).type();

if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.dtype());
} else if (type.isa<ir::VectorType>()) {
auto vec_data = type.dyn_cast<ir::VectorType>().data();
if (vec_data.empty()) {
kernel_data_type = phi::DataType::UNDEFINED;
} else {
if (vec_data[0].isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
vec_data[0]
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.dtype());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support DenseTensorType in vector"));
}
}
} else if (type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
kernel_data_type = TransToPhiDataType(
type.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
.dtype());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support DenseTensorType, SelectedRows, VectorType"));
}

} else {
PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
true,
phi::errors::PreconditionNotMet(
"[%s] MUST in attribute map", slot_name));

auto attr_type = op_info_parser->AttrTypeName(slot_name);
PADDLE_ENFORCE_EQ(attr_type,
"paddle::dialect::DataTypeAttribute",
phi::errors::PreconditionNotMet(
"Type of [%s] should be DataType", slot_name));
kernel_data_type = attr_map.at(slot_name)
.dyn_cast<paddle::dialect::DataTypeAttribute>()
.data();
}
}
// get datatype info
kernel_data_type =
GetKernelDataTypeByYamlInfo(op, map_value_pair, op_info_parser);
kernel_backend =
GetKernelBackendByYamlInfo(op, map_value_pair, op_info_parser);

// parse all the input tensor
if (tensor_input_number == 0 || op->name() == "pd.full_") {
Expand Down
1 change: 1 addition & 0 deletions test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
"user_workspace_size"},
{"input"},
{},
{},
{});

return std::make_tuple(
Expand Down