Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR]Support optional input and output for pir api #57492

Merged
merged 11 commits into from
Sep 22, 2023
199 changes: 167 additions & 32 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <vector>

#include "paddle/utils/optional.h"
#include "paddle/pir/core/value.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
Expand All @@ -47,6 +48,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"

{body}

Expand All @@ -67,14 +69,46 @@

API_IMPL_TEMPLATE = """
{ret_type} {api_name}({args}){{
{handle_optional_inputs}
{in_combine}
{compute_op}
{handle_optional_outputs}
{out_split}
{return_result}
}}

"""

OPTIONAL_VECTOR_VALUE_INPUT_TEMPLATE = """
paddle::optional<pir::Value> optional_{name};
if (!{name}) {{
optional_{name} = paddle::make_optional<pir::Value>(pir::Value());
}} else {{
auto optional_{name}_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>({name}.get());
optional_{name} = paddle::make_optional<pir::Value>(optional_{name}_combine_op.out());
}}"""

OPTIONAL_VALUE_INPUT_TEMPLATE = """
paddle::optional<pir::Value> optional_{name};
if (!{name}) {{
optional_{name} = paddle::make_optional<pir::Value>(pir::Value());
}} else {{
optional_{name} = {name};
}}"""

OPTIONAL_OPRESULT_OUTPUT_TEMPLATE = """
paddle::optional<pir::OpResult> optional_{name};
if (!IsEmptyOpResult({op_name}_op.result({index}))) {{
optional_{name} = paddle::make_optional<pir::OpResult>({op_name}_op.result({index}));
}}"""

OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE = """
paddle::optional<std::vector<pir::OpResult>> optional_{name};
if (!IsEmptyOpResult({op_name}_op.result({index}))) {{
auto optional_{name}_slice_op = APIBuilder::Instance().GetBuilder()->Build<pir::SplitOp>({op_name}_op.result({index}));
optional_{name} = paddle::make_optional<std::vector<pir::OpResult>>(optional_{name}_slice_op.outputs());
}}"""

COMBINE_OP_TEMPLATE = """
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>({in_name});"""

Expand All @@ -88,23 +122,35 @@
VECTOR_TYPE = 'pir::VectorType'
INTARRAY_ATTRIBUTE = "paddle::dialect::IntArrayAttribute"

INPUT_TYPE_MAP = {
'paddle::dialect::DenseTensorType': 'pir::Value',
'paddle::dialect::SelectedRowsType': 'pir::Value',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::Value>',
}
OPTIONAL_INPUT_TYPE_MAP = {
'paddle::dialect::DenseTensorType': 'paddle::optional<pir::Value>',
'paddle::dialect::SelectedRowsType': 'paddle::optional<pir::Value>',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'paddle::optional<std::vector<pir::Value>>',
}
OUTPUT_TYPE_MAP = {
'paddle::dialect::DenseTensorType': 'pir::OpResult',
'paddle::dialect::SelectedRowsType': 'pir::OpResult',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::OpResult>',
}
OPTIONAL_OUTPUT_TYPE_MAP = {
'paddle::dialect::DenseTensorType': 'paddle::optional<pir::OpResult>',
'paddle::dialect::SelectedRowsType': 'paddle::optional<pir::OpResult>',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'paddle::optional<std::vector<pir::OpResult>>',
}


def get_op_class_name(op_name):
return to_pascal_case(op_name) + 'Op'


class CodeGen:
def __init__(self) -> None:
self._type_map = {
'paddle::dialect::DenseTensorType': 'pir::Value',
'paddle::dialect::SelectedRowsType': 'pir::Value',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::Value>',
}
self._ret_type_map = {
'paddle::dialect::DenseTensorType': 'pir::OpResult',
'paddle::dialect::SelectedRowsType': 'pir::OpResult',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::OpResult>',
}
pass

def _parse_yaml(self, op_yaml_files, op_compat_yaml_file):
op_compat_parser = OpCompatParser(op_compat_yaml_file)
Expand Down Expand Up @@ -133,16 +179,43 @@ def _need_skip(self, op_info, op_name):
op_info.infer_meta_func is None and op_name not in PD_MANUAL_OP_LIST
)

def _is_optional_input(self, op_info, input_name):
name_list = op_info.input_name_list
optional_list = op_info.input_optional_list
if (
input_name in name_list
and optional_list[name_list.index(input_name)] == 'true'
):
return True
return False

def _is_optinonal_output(self, op_info, output_name):
inplace_map = op_info.inplace_map
input_optional_list = op_info.input_optional_list
input_name_list = op_info.input_name_list
if inplace_map is None:
return False

if output_name in inplace_map.keys():
input_index = input_name_list.index(inplace_map[output_name])
if input_optional_list[input_index] == 'true':
return True
return False

# =====================================
# Gen declare functions
# =====================================
def _gen_api_inputs(self, op_info):
name_list = op_info.input_name_list
type_list = op_info.input_type_list
assert len(name_list) == len(type_list)
optional_list = op_info.input_optional_list
assert len(name_list) == len(type_list) == len(optional_list)
ret = []
for name, type in zip(name_list, type_list):
ret.append(f'{self._type_map[type]} {name}')
for name, type, optional in zip(name_list, type_list, optional_list):
if optional == 'true':
ret.append(f'{OPTIONAL_INPUT_TYPE_MAP[type]} {name}')
else:
ret.append(f'{INPUT_TYPE_MAP[type]} {name}')
return ', '.join(ret)

def _gen_api_attrs(
Expand Down Expand Up @@ -191,26 +264,31 @@ def _gen_api_args(
return (inputs + ', ' + attrs).strip(', ')

def _gen_ret_type(self, op_info):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
assert len(type_list) == len(intermediate_list)
assert len(name_list) == len(type_list) == len(intermediate_list)

output_num = len(type_list) - intermediate_list.count('true')
if output_num > 1:
return 'std::tuple<{}>'.format(
', '.join(
[
self._ret_type_map[type]
for type, intermediate in zip(
type_list, intermediate_list
)
if intermediate == 'false'
]
)
)
ret = []
for name, type, intermediate in zip(
name_list, type_list, intermediate_list
):
if intermediate == 'true':
continue
if self._is_optinonal_output(op_info, name):
ret.append(OPTIONAL_OUTPUT_TYPE_MAP[type])
else:
ret.append(OUTPUT_TYPE_MAP[type])
return 'std::tuple<{}>'.format(', '.join(ret))
elif output_num == 1:
index = intermediate_list.index('false')
return self._ret_type_map[type_list[index]]
name = name_list[index]
if self._is_optinonal_output(op_info, name):
return OPTIONAL_OUTPUT_TYPE_MAP[type_list[index]]
else:
return OUTPUT_TYPE_MAP[type_list[index]]
elif output_num == 0:
return 'void'

Expand Down Expand Up @@ -255,14 +333,56 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
# =====================================
# Gen impl functions
# =====================================
def _gen_handle_optional_inputs(self, op_info):
name_list = op_info.input_name_list
optional_list = op_info.input_optional_list
type_list = op_info.input_type_list
assert len(name_list) == len(optional_list) == len(type_list)
ret = ''
for name, optional, type in zip(name_list, optional_list, type_list):
if optional == 'true':
if VECTOR_TYPE in type:
ret += OPTIONAL_VECTOR_VALUE_INPUT_TEMPLATE.format(
name=name
)
else:
ret += OPTIONAL_VALUE_INPUT_TEMPLATE.format(name=name)
return ret

def _gen_handle_optional_outputs(self, op_info, op_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
ret = ''
for i, (name, type, intermediate) in enumerate(
zip(name_list, type_list, intermediate_list)
):
if intermediate == 'true':
continue
if self._is_optinonal_output(op_info, name):
if VECTOR_TYPE in type:
ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format(
name=name,
op_name=op_name,
index=i,
)
else:
ret += OPTIONAL_OPRESULT_OUTPUT_TEMPLATE.format(
name=name,
op_name=op_name,
index=i,
)
return ret

def _gen_in_combine(self, op_info, is_mutable_attr, is_vector_mutable_attr):
name_list = op_info.input_name_list
type_list = op_info.input_type_list
assert len(name_list) == len(type_list)
optional_list = op_info.input_optional_list
assert len(name_list) == len(type_list) == len(optional_list)
combine_op = ''
combine_op_list = []
for name, type in zip(name_list, type_list):
if VECTOR_TYPE in type:
for name, type, optional in zip(name_list, type_list, optional_list):
if optional == 'false' and VECTOR_TYPE in type:
op_name = f'{name}_combine_op'
combine_op += COMBINE_OP_TEMPLATE.format(
op_name=op_name, in_name=name
Expand Down Expand Up @@ -305,7 +425,10 @@ def _gen_compute_op_args(

for input_name, combine_op in zip(name_list, in_combine_op_list):
if combine_op is None:
ret.append(input_name)
if self._is_optional_input(op_info, input_name):
ret.append(f'optional_{input_name}.get()')
else:
ret.append(input_name)
else:
ret.append(f'{combine_op}.out()')
if is_mutable_attr:
Expand Down Expand Up @@ -334,7 +457,13 @@ def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
assert len(name_list) == len(type_list) == len(intermediate_list)
optional_list = op_info.output_optional_list
assert (
len(name_list)
== len(type_list)
== len(intermediate_list)
== len(optional_list)
)

split_op_str = ''
ret_list = []
Expand All @@ -343,7 +472,9 @@ def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
):
if intermediate == 'true':
continue
if VECTOR_TYPE in type:
if self._is_optinonal_output(op_info, name):
ret_list.append(f'optional_{name}')
elif VECTOR_TYPE in type:
split_op_name = f'{name}_split_op'
split_op_str += SPLIT_OP_TEMPLATE.format(
op_name=split_op_name, in_name=f'{op_inst_name}.result({i})'
Expand Down Expand Up @@ -384,8 +515,12 @@ def _gen_one_impl(
args=self._gen_api_args(
op_info, False, is_mutable_attr, is_vector_mutable_attr
),
handle_optional_inputs=self._gen_handle_optional_inputs(op_info),
in_combine=in_combine,
compute_op=compute_op,
handle_optional_outputs=self._gen_handle_optional_outputs(
op_info, op_name
),
out_split=out_split,
return_result=self._gen_return_result(ret_list),
)
Expand Down
26 changes: 18 additions & 8 deletions paddle/fluid/pir/dialect/op_generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,24 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
def _gen_inputs(self, op_info, op_name):
name_list = op_info.input_name_list
type_list = op_info.input_type_list
assert len(name_list) == len(type_list)
optional_list = op_info.input_optional_list
assert len(name_list) == len(type_list) == len(optional_list)
ret = ''
for i, (name, type) in enumerate(zip(name_list, type_list)):
cast_func = (
'CastPyArg2VectorOfValue'
if VECTOR_TYPE in type
else 'CastPyArg2OpResult'
)
for i, (name, type, optional) in enumerate(
zip(name_list, type_list, optional_list)
):
if optional == 'true':
cast_func = (
'CastPyArg2OptionalVectorOfValue'
if VECTOR_TYPE in type
else 'CastPyArg2OptionalValue'
)
else:
cast_func = (
'CastPyArg2VectorOfValue'
if VECTOR_TYPE in type
else 'CastPyArg2Value'
)
ret += INPUT_TEMPLATE.format(
name=name, index=i, cast_func=cast_func, api_name=op_name
)
Expand Down Expand Up @@ -327,7 +337,7 @@ def _gen_cast_attrs(self, op_info, op_name):
type='',
name_=name,
name=name,
cast_func='CastPyArg2OpResult',
cast_func='CastPyArg2Value',
api_name=op_name,
index=input_size + i,
)
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ cc_library(
cc_library(
pd_op_dialect_api
SRCS ${api_source_file} manual_api.cc
DEPS api_builder pd_op_dialect_op)
DEPS api_builder pd_op_dialect_op pd_op_dialect_utils)

target_include_directories(pd_op_dialect_api INTERFACE ${PD_DIALECT_BINARY_DIR})

Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,9 @@ VariantType GetAttributeData(const pir::Attribute& attr) {

bool IsLegacyOp(const std::string& name) { return LegacyOpList.count(name); }

bool IsEmptyOpResult(const pir::OpResult& op_result) {
return !op_result.impl() || op_result.type().isa<pir::Type>();
}

} // namespace dialect
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,7 @@ VariantType GetAttributeData(const pir::Attribute& attr);

bool IsLegacyOp(const std::string& name);

bool IsEmptyOpResult(const pir::OpResult& op_result);

} // namespace dialect
} // namespace paddle
Loading