[IR] IR attribute printer and support mutable attribute #54369

Merged on Jun 8, 2023 · 60 commits

Changes from all commits:
2b5cc89 add vector type support for program translator (kangguangli, May 23, 2023)
489b2cc polish (kangguangli, May 23, 2023)
324d897 support basic attribute type (kangguangli, May 23, 2023)
f388e79 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into add_vec… (kangguangli, May 23, 2023)
f9e17d6 resolve conflicts (kangguangli, May 23, 2023)
9cd463f add verify for combine/slice and unittests (kangguangli, May 24, 2023)
6d17079 polish (kangguangli, May 24, 2023)
b654cc2 Merge branch 'add_vector_type_support_for_program_translator' into su… (kangguangli, May 24, 2023)
b45df9c support more type in attribute translator (kangguangli, May 24, 2023)
5940ac4 modify by reviews (kangguangli, May 24, 2023)
b3b1856 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support… (kangguangli, May 24, 2023)
b328cbb fix merge mistakes (kangguangli, May 24, 2023)
640c27f Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support… (kangguangli, May 25, 2023)
e35a288 refine code (zhangbo9674, May 26, 2023)
1ba19dc refine code (zhangbo9674, May 26, 2023)
6c2ac5d add interface (zhangbo9674, May 26, 2023)
8473bce Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support… (kangguangli, May 29, 2023)
dd2060b Merge pull/54130/head (kangguangli, May 29, 2023)
6967eba fix: op name normalization (kangguangli, May 29, 2023)
df0b5cd fix typo (kangguangli, May 29, 2023)
f1c02dd refactor input translator (kangguangli, May 30, 2023)
a053a11 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support… (kangguangli, May 30, 2023)
0854418 fix merge conflicts (kangguangli, May 30, 2023)
38f7453 fix op normalizer bug (kangguangli, May 30, 2023)
e8c6234 refactor attribute translator (kangguangli, May 30, 2023)
261fe97 fix bug (kangguangli, May 30, 2023)
6d57a5f refactor output translator (kangguangli, May 31, 2023)
0717f11 fix typo (kangguangli, May 31, 2023)
c9d7c63 fix (kangguangli, May 31, 2023)
edfad35 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support… (kangguangli, Jun 1, 2023)
2fb54d8 fix approval error (kangguangli, Jun 1, 2023)
e945a41 fix coverage (kangguangli, Jun 1, 2023)
c9358f4 fix op_compat parser (kangguangli, Jun 1, 2023)
30540e7 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support… (kangguangli, Jun 2, 2023)
b785d35 fix merge conflicts (kangguangli, Jun 2, 2023)
a0f21e6 fix merge conflicts (kangguangli, Jun 2, 2023)
c2ede6f fix merge conflicts (kangguangli, Jun 2, 2023)
2605b6e Merge branch 'develop' of github.com:PaddlePaddle/Paddle into support… (kangguangli, Jun 2, 2023)
7f97345 fix merge conflicts (kangguangli, Jun 2, 2023)
5f81713 fix merge conflicts (kangguangli, Jun 2, 2023)
f7bd3ab refactor scalar attribute (kangguangli, Jun 5, 2023)
cf97d27 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into ir_attr… (kangguangli, Jun 5, 2023)
a937468 Merge branch 'refactor_scalar_attribute' into ir_attribute_printer (kangguangli, Jun 5, 2023)
4a54f06 draft (kangguangli, Jun 5, 2023)
17edac4 fix (kangguangli, Jun 5, 2023)
dc150e9 fix op build (kangguangli, Jun 5, 2023)
e885d45 fix op build (kangguangli, Jun 5, 2023)
ee632fb Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (kangguangli, Jun 5, 2023)
d85b306 Merge commit 'refs/pull/54340/head' of github.com:PaddlePaddle/Paddle… (kangguangli, Jun 5, 2023)
34faeb9 temporarily save (kangguangli, Jun 5, 2023)
63c7dd7 adpat mutable attribute (kangguangli, Jun 6, 2023)
85167e2 refine op_comat_gen process (kangguangli, Jun 6, 2023)
cf8097f Merge branch 'develop' of github.com:PaddlePaddle/Paddle into ir_attr… (kangguangli, Jun 7, 2023)
fcccde6 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into ir_attr… (kangguangli, Jun 7, 2023)
623c2a3 fix merge conflicts (kangguangli, Jun 7, 2023)
c733235 fix merge conflicts (kangguangli, Jun 7, 2023)
e1c16e7 fix merge conflicts (kangguangli, Jun 7, 2023)
ed21d27 Merge branch 'develop' of github.com:PaddlePaddle/Paddle into ir_attr… (kangguangli, Jun 7, 2023)
84636fe complete dialect attribute printer and refine ir_throw (kangguangli, Jun 7, 2023)
d731403 polish code (kangguangli, Jun 7, 2023)
22 changes: 19 additions & 3 deletions paddle/fluid/ir/dialect/op_gen.py
@@ -1193,8 +1193,8 @@ def OpGenerator(

# generate get op info function: inputs
inputs_info_str = ""
input_info_list = []
if len(op_input_name_list) > 0:
input_info_list = []
for idx in range(len(op_input_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
@@ -1204,7 +1204,19 @@
no_need_buffer=op_input_no_need_buffer_list[idx],
)
)
inputs_info_str = ", ".join(input_info_list)

# add mutable attributes as inputs
if len(op_mutable_attribute_name_list) > 0:
for idx in range(len(op_mutable_attribute_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
typename=op_mutable_attribute_type_list[idx],
optional='false',
no_need_buffer='false',
)
)
inputs_info_str = ", ".join(input_info_list)

# generate get op info function: outputs
outputs_info_str = ""
@@ -1223,12 +1235,16 @@

# generate get op info function: attributes
attribute_info_str = ""
op_mutable_attribute_name_set = set(op_mutable_attribute_name_list)
if len(op_attribute_name_list) > 0:
attribute_info_list = []
for idx in range(len(op_attribute_name_list)):
attribute_name = op_attribute_name_list[idx]
if attribute_name in op_mutable_attribute_name_set:
continue
attribute_info_list.append(
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
name=op_attribute_name_list[idx],
name=attribute_name,
typename=op_attribute_type_list[idx],
data_type=op_attribute_data_type_list[idx],
)
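The net effect of the op_gen.py change is that mutable attributes are appended to the same input-info list as ordinary operands, so the generated op definition exposes them as extra, non-optional inputs. A minimal runnable sketch of that construction, using a stub string in place of the real CONSTRUCT_INPUT_INFO_TEMPLATE (whose exact field layout lives elsewhere in op_gen.py) and a hypothetical op with one input x and one mutable attribute shape:

```python
# Sketch of the input-info construction above. The template string is a
# stand-in for CONSTRUCT_INPUT_INFO_TEMPLATE; its real field layout is
# defined elsewhere in op_gen.py.
CONSTRUCT_INPUT_INFO_TEMPLATE = (
    'OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})'
)

# A hypothetical op with one regular input and one mutable attribute.
op_input_name_list = ["x"]
op_input_type_list = ["paddle::dialect::DenseTensorType"]
op_input_optional_list = ["false"]
op_input_no_need_buffer_list = ["false"]
op_mutable_attribute_name_list = ["shape"]
op_mutable_attribute_type_list = ["paddle::dialect::IntArrayAttribute"]

input_info_list = []
for idx in range(len(op_input_name_list)):
    input_info_list.append(
        CONSTRUCT_INPUT_INFO_TEMPLATE.format(
            name=op_input_name_list[idx],
            typename=op_input_type_list[idx],
            optional=op_input_optional_list[idx],
            no_need_buffer=op_input_no_need_buffer_list[idx],
        )
    )
# Mutable attributes are appended after the regular inputs and are
# always non-optional, exactly as in the generator code above.
for idx in range(len(op_mutable_attribute_name_list)):
    input_info_list.append(
        CONSTRUCT_INPUT_INFO_TEMPLATE.format(
            name=op_mutable_attribute_name_list[idx],
            typename=op_mutable_attribute_type_list[idx],
            optional='false',
            no_need_buffer='false',
        )
    )
inputs_info_str = ", ".join(input_info_list)
print(inputs_info_str)
```

The printed list ends with the shape entry, matching the generator's convention that mutable attributes follow the regular inputs.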
26 changes: 25 additions & 1 deletion paddle/fluid/ir/dialect/pd_dialect.cc
@@ -23,6 +23,8 @@
#include "paddle/fluid/ir/dialect/pd_type_storage.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/ir/core/dialect_interface.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"

namespace paddle {
@@ -107,7 +109,7 @@ void PaddleDialect::initialize() {
RegisterInterfaces<ParameterConvertInterface>();
}

void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();

os << "tensor<";
@@ -119,5 +121,27 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
os << ">";
}

void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) {
phi::IntArray data = int_array_attr.data();
os << "IntArray[";
const auto &inner_data = data.GetData();
ir::PrintInterleave(
inner_data.begin(),
inner_data.end(),
[&os](int64_t i) { os << i; },
[&os]() { os << ","; });
os << "]";
} else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) {
os << data_type_attr.data();
} else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) {
os << place_type_attr.data();
} else if (auto data_layout_attr = attr.dyn_cast<DataLayoutAttribute>()) {
os << data_layout_attr.data();
} else {
os << "<#AttrNotImplemented>";
}
}

} // namespace dialect
} // namespace paddle
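PrintAttribute dispatches on the concrete attribute kind; for IntArrayAttribute it brackets the elements and interleaves commas via ir::PrintInterleave, and unknown kinds fall back to the <#AttrNotImplemented> placeholder instead of aborting. A standalone Python sketch of the interleaving, assuming PrintInterleave's semantics from the call site above (an element callback per item, a separator callback between items):

```python
import io

# Mirrors the assumed semantics of ir::PrintInterleave: apply print_elem
# to each element, calling print_sep between consecutive elements.
def print_interleave(items, print_elem, print_sep):
    for i, item in enumerate(items):
        if i > 0:
            print_sep()
        print_elem(item)

buf = io.StringIO()
inner_data = [2, 4, 8]  # stands in for int_array_attr.data().GetData()
buf.write("IntArray[")
print_interleave(
    inner_data,
    lambda v: buf.write(str(v)),
    lambda: buf.write(","),
)
buf.write("]")
print(buf.getvalue())  # IntArray[2,4,8]
```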
3 changes: 2 additions & 1 deletion paddle/fluid/ir/dialect/pd_dialect.h
@@ -39,7 +39,8 @@ class PaddleDialect : public ir::Dialect {

static const char* name() { return "pd"; }

void PrintType(ir::Type type, std::ostream& os);
void PrintType(ir::Type type, std::ostream& os) const;
void PrintAttribute(ir::Attribute attr, std::ostream& os) const;

private:
void initialize();
4 changes: 3 additions & 1 deletion paddle/fluid/ir_adaptor/translator/CMakeLists.txt
@@ -5,12 +5,14 @@ set(PD_PROGRAM_TRANSLATOR_BINARY_DIR
set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py)
set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc)
set(op_compat_template_file
${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc.j2)

add_custom_command(
OUTPUT ${op_compat_source_file}
COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file
${op_compat_yaml_file} --output_source_file ${op_compat_source_file}
DEPENDS ${op_gen_file} ${op_compat_yaml_file}
DEPENDS ${op_gen_file} ${op_compat_yaml_file} ${op_compat_template_file}
VERBATIM)

file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc")
36 changes: 33 additions & 3 deletions paddle/fluid/ir_adaptor/translator/op_compat_gen.py
@@ -14,7 +14,7 @@

import argparse
from pathlib import Path
from typing import Dict
from typing import Dict, List, Set

import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined
@@ -46,8 +46,11 @@ def to_phi_and_fluid_op_name(op_item):

with open(op_compat_yaml_file, "r") as f:
op_compat_infos = yaml.safe_load(f)
op_name_mappings = {}
op_arg_name_mappings = {}
op_name_mappings: Dict[str, str] = {}
op_arg_name_mappings: Dict[str, Dict[str, str]] = {}
op_mutable_attributes: Dict[str, Set[str]] = {}
op_mutable_attribute_infos: Dict[str, Dict[str, List[str]]] = {}

for op_compat_item in op_compat_infos:

def insert_new_mappings(op_name_str: str) -> str:
@@ -64,6 +67,23 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
op_arg_name_mappings[op_name] = {}
op_arg_name_mappings[op_name].update(arg_mapping)

def insert_new_mutable_attributes(
op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]]
):
op_mutable_attributes[op_name] = set()
op_mutable_attribute_infos[op_name] = {}
for (
attribute_name,
mutable_attribute_info,
) in mutable_attribute_infos.items():
op_mutable_attributes[op_name].add(attribute_name)
op_mutable_attribute_infos[op_name][attribute_name] = []
for k, v in mutable_attribute_info.items():
if k == 'tensor_name' or k == 'tensors_name':
op_mutable_attribute_infos[op_name][
attribute_name
].append(v)

_, legacy_name = insert_new_mappings(op_compat_item["op"])
legacy_backward_op_names = []
if "backward" in op_compat_item:
@@ -88,6 +108,14 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
for backward_op in legacy_backward_op_names:
insert_new_arg_mappings(backward_op, op_compat_item["outputs"])

if "int_array" in op_compat_item:
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)

if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])

# special op mappings
op_name_mappings["fetch_v2"] = "fetch"

@@ -96,6 +124,8 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]):
op_compat_definition = op_name_normailzer_template.render(
op_name_pairs=op_name_mappings,
op_arg_name_pairs=op_arg_name_mappings,
op_mutable_attributes=op_mutable_attributes,
op_mutable_attribute_infos=op_mutable_attribute_infos,
)
f.write(op_compat_definition)

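To make the collected structures concrete, here is a self-contained rerun of the logic of insert_new_mutable_attributes on a hand-written entry shaped like an op_compat.yaml int_array item. The reshape fields below are illustrative rather than quoted from the yaml:

```python
# Stand-alone rerun of the collection logic above on one sample entry.
op_mutable_attributes = {}        # op name -> set of mutable attribute names
op_mutable_attribute_infos = {}   # op name -> attr name -> candidate var names

def insert_new_mutable_attributes(op_name, mutable_attribute_infos):
    op_mutable_attributes[op_name] = set()
    op_mutable_attribute_infos[op_name] = {}
    for attribute_name, info in mutable_attribute_infos.items():
        op_mutable_attributes[op_name].add(attribute_name)
        op_mutable_attribute_infos[op_name][attribute_name] = [
            v for k, v in info.items() if k in ("tensor_name", "tensors_name")
        ]

sample_int_array = {
    "shape": {
        "data_type": "int",
        "tensor_name": "Shape",         # single-tensor input carrying the value
        "tensors_name": "ShapeTensor",  # tensor-list input carrying the value
    }
}
insert_new_mutable_attributes("reshape", sample_int_array)
print(op_mutable_attributes)        # {'reshape': {'shape'}}
print(op_mutable_attribute_infos)   # {'reshape': {'shape': ['Shape', 'ShapeTensor']}}
```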
31 changes: 31 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2
@@ -21,6 +21,37 @@ OpNameNormalizer::OpNameNormalizer() {
},
{% endfor %}
};
op_mutable_attributes = {
{% for op_name, mutable_attributes in op_mutable_attributes.items() %}
{
"{{op_name}}",
{
{% for attribute_name in mutable_attributes %}
"{{attribute_name}}",
{% endfor %}
},
},
{% endfor %}
};
op_mutable_attribute_infos = {
{% for op_name, mutable_attribute_infos in op_mutable_attribute_infos.items() %}
{
"{{op_name}}",
{
{% for attribute_name, attribute_info in mutable_attribute_infos.items() %}
{
"{{attribute_name}}",
{
{% for candidate_var_name in attribute_info %}
"{{candidate_var_name}}",
{% endfor %}
},
},
{% endfor %}
},
},
{% endfor %}
};
}

} // namespace translator
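To see the C++ this template emits, the op_mutable_attribute_infos initializer can be rendered stand-alone with Jinja2. The template below is a trimmed copy whose whitespace control differs from the real op_compat_info.cc.j2; only the shape of the generated initializer is the point:

```python
from jinja2 import Environment

# Trimmed-down copy of the nested-initializer section above.
template_src = """op_mutable_attribute_infos = {
{%- for op_name, infos in op_mutable_attribute_infos.items() %}
  { "{{op_name}}", {
  {%- for attr, names in infos.items() %}
    { "{{attr}}", { {% for n in names %}"{{n}}", {% endfor %}} },
  {%- endfor %}
  } },
{%- endfor %}
};"""

env = Environment()
print(env.from_string(template_src).render(
    op_mutable_attribute_infos={"reshape": {"shape": ["Shape", "ShapeTensor"]}}
))
# Emits a nested brace initializer such as:
#   op_mutable_attribute_infos = {
#     { "reshape", {
#       { "shape", { "Shape", "ShapeTensor", } },
#     } },
#   };
```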
24 changes: 24 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_compat_info.h
@@ -15,6 +15,7 @@
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>  // for MutableAttributeInfo (std::vector<std::string>)

#include "glog/logging.h"

Expand All @@ -25,13 +26,21 @@
namespace paddle {
namespace translator {

using MutableAttributeInfo = std::vector<std::string>;

class OpNameNormalizer {
private:
OpNameNormalizer(); // Disallow instantiation outside of the class.
std::unordered_map<std::string, std::string> op_name_mappings;
std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
op_arg_name_mappings;

std::unordered_map<std::string,
std::unordered_map<std::string, MutableAttributeInfo>>
op_mutable_attribute_infos;
std::unordered_map<std::string, std::unordered_set<std::string>>
op_mutable_attributes;

public:
OpNameNormalizer(const OpNameNormalizer&) = delete;
OpNameNormalizer& operator=(const OpNameNormalizer&) = delete;
@@ -50,6 +59,21 @@ class OpNameNormalizer {
return op_name_mappings.at(op_type);
}

bool HasMutableAttribute(const std::string& op_type) {
return (op_mutable_attributes.find(op_type) != op_mutable_attributes.end());
}

const std::unordered_set<std::string>* GetMutableAttributes(
const std::string& op_type) {
if (!HasMutableAttribute(op_type)) return nullptr;
return &op_mutable_attributes.at(op_type);
}

const MutableAttributeInfo& GetMutableAttributeInfos(
const std::string& op_type, const std::string& arg_name) {
return op_mutable_attribute_infos.at(op_type).at(arg_name);
}

std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) {
bool is_grad_op = (op_type.find("grad") != std::string::npos);
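The lookup surface is small enough to mirror in a few lines. This Python sketch, with illustrative data for a single op, shows the intended contract: HasMutableAttribute as a cheap guard, GetMutableAttributes returning null when absent, and GetMutableAttributeInfos failing loudly (like unordered_map::at) on unknown keys:

```python
# Python mirror of the C++ lookup API above, with illustrative data for
# a single op. The real tables are generated from op_compat.yaml.
op_mutable_attributes = {"reshape": {"shape"}}
op_mutable_attribute_infos = {"reshape": {"shape": ["Shape", "ShapeTensor"]}}

def has_mutable_attribute(op_type):
    return op_type in op_mutable_attributes

def get_mutable_attributes(op_type):
    # Returns None when absent, like the nullptr in the C++ version.
    return op_mutable_attributes.get(op_type)

def get_mutable_attribute_infos(op_type, arg_name):
    # Raises KeyError on unknown keys, like std::unordered_map::at.
    return op_mutable_attribute_infos[op_type][arg_name]

assert has_mutable_attribute("reshape")
assert get_mutable_attributes("fetch") is None
assert get_mutable_attribute_infos("reshape", "shape") == ["Shape", "ShapeTensor"]
```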