[Prim][PIR] PIR Prim support intarray, scalar, combineop #58581

Merged
@@ -18,8 +18,9 @@


# come into effect in generated file pd_op.h
decomp_interface_declare_gen_op_list = ['mean']
# manual decomp interface declarations are located in manual_op.h
decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n"]

# come into effect in generated file op_decomp.cc
# manual decomp interface implementations are located in manual_op_decomp.cc
decomp_interface_implementation_gen_op_list = ["mean"]
decomp_interface_implementation_gen_op_list = ["mean", "squeeze", "add_n"]
5 changes: 4 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -16,6 +16,7 @@
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
@@ -34,7 +35,8 @@ namespace dialect {
class AddNOp : public pir::Op<AddNOp,
paddle::dialect::OpYamlInfoInterface,
paddle::dialect::InferMetaInterface,
paddle::dialect::VjpInterface> {
paddle::dialect::VjpInterface,
paddle::dialect::DecompInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.add_n"; }
@@ -55,6 +57,7 @@ class AddNOp : public pir::Op<AddNOp,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation *op);
};

class AddN_Op : public pir::Op<AddN_Op,
75 changes: 1 addition & 74 deletions paddle/fluid/primitive/codegen/decomp_gen.py
@@ -27,6 +27,7 @@
)
import filters as op_gen_filters
import tests_utils as op_gen_tests
from gen import extend_compat_info, filter_compat_info
from parse_utils import to_named_dict
from type_mapping import output_type_map

@@ -128,80 +129,6 @@ def save(content: str, path: pathlib.Path):
f.write(content)


def filter_compat_info(items):
for item in items:
item['op'] = item['op'].split('(')[0].strip()
if 'backward' in item:
item_backwards = item['backward'].split(',')
for idx, item_backward in enumerate(item_backwards):
item_backward = item_backward.split('(')[0].strip()
item_backwards[idx] = item_backward
item['backward'] = (
','.join(item_backwards)
if len(item_backwards) > 0
else item_backwards[0]
)


def extend_compat_info(apis, compats):
for api in apis:
attrs = api["attrs"]
for attr in attrs:
if op_gen_tests.is_scalar(
attr['typename']
) or op_gen_tests.is_intarray(attr['typename']):
attr["support_tensor"] = False
apis_dict = to_named_dict(apis)
for compat_item in compats:
fwd_op_name = compat_item["op"]
if fwd_op_name not in apis_dict:
continue
fwd_api = apis_dict[fwd_op_name]
backward_op_names = []
while fwd_op_name is not None and fwd_op_name in apis_dict:
backward_op_names.append(apis_dict[fwd_op_name]['backward'])
fwd_op_name = apis_dict[fwd_op_name]['backward']
backward_apis = []
for backward_op_name in backward_op_names:
if backward_op_name in apis_dict:
backward_apis.append(apis_dict[backward_op_name])
support_tensor_attrs_names = []
compat_attrs_data_type = {}
if 'scalar' in compat_item and compat_item['op'] != "pow":
for attr_name, attr_info in compat_item['scalar'].items():
if (
'support_tensor' in attr_info
and attr_info['support_tensor'] is True
or 'tensor_name' in attr_info
):
support_tensor_attrs_names.append(attr_name)
if 'data_type' in attr_info:
compat_attrs_data_type.update(
{attr_name: attr_info['data_type']}
)
if 'int_array' in compat_item:
for attr_name, attr_info in compat_item['int_array'].items():
if (
'support_tensor' in attr_info
and attr_info['support_tensor'] is True
or 'tensor_name' in attr_info
or 'tensors_name' in attr_info
):
support_tensor_attrs_names.append(attr_name)
if len(support_tensor_attrs_names) > 0:
for api in [fwd_api] + backward_apis:
attrs = api["attrs"]
for attr in attrs:
if attr['name'] in support_tensor_attrs_names:
attr['support_tensor'] = True
for api in [fwd_api] + backward_apis:
attrs = api["attrs"]
for attr in attrs:
if attr['name'] in compat_attrs_data_type:
attr['data_type'] = compat_attrs_data_type[attr['name']]
return apis


def process_optional_output_info(apis):
for api in apis:
inputs_dict = to_named_dict(api['inputs'])
@@ -60,6 +60,7 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
for (size_t idx = 0; idx < combine_op_obj_{{item.name}}.inputs().size(); idx++) {
{{item.name}}.emplace_back(
std::make_shared<primitive::LazyTensor>(combine_op_obj_{{item.name}}.inputs()[idx]));
}
{% endif %}
{% endif %}
{% endfor %}
@@ -69,7 +70,47 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
{% if attrs %}
{% for item in attrs %}
{% do attr_names.append(item.name) %}

{% if item.typename == "Scalar" and item.support_tensor %}

Tensor {{item.name}}_(std::make_shared<primitive::LazyTensor>(op_obj.{{item.name}}()));

auto* {{item.name}}_define_op =
std::static_pointer_cast<primitive::LazyTensor>({{item.name}}_.impl())
->value()
.dyn_cast<pir::OpResult>()
.owner();
if ({{item.name}}_define_op->name() != "pd_op.full") {
PADDLE_THROW(
platform::errors::Unimplemented("We don't support dynamic tensors "
"attribute {{item.name}} for {{fwd_name}} decomposition "
"for now. "));
}
Scalar {{item.name}} = {{item.name}}_define_op->attribute("value").dyn_cast<paddle::dialect::ScalarAttribute>().data();


{% elif item.typename == "IntArray" and item.support_tensor %}

Tensor {{item.name}}_(std::make_shared<primitive::LazyTensor>(op_obj.{{item.name}}()));

auto* {{item.name}}_define_op =
std::static_pointer_cast<primitive::LazyTensor>({{item.name}}_.impl())
->value()
.dyn_cast<pir::OpResult>()
.owner();
if ({{item.name}}_define_op->name() != "pd_op.full_int_array") {
PADDLE_THROW(
platform::errors::Unimplemented("We don't support dynamic tensors "
"attribute {{item.name}} for {{fwd_name}} decomposition "
"for now. "));
}
IntArray {{item.name}} = phi::IntArray(
paddle::dialect::GetInt64Vector({{item.name}}_define_op->attribute("value")));

{% else %}

{{item.typename}} {{item.name}} = op->attribute("{{item.name}}").dyn_cast<{{item.mapped_type}}>().data();
{% endif %}
{% endfor %}
{% endif %}

@@ -92,9 +133,14 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
{% endfor %}
std::tuple<{{common.sequence('', '', ', ', output_types)}}> op_res = paddle::primitive::details::{{fwd_name}}_decomp<primitive::LazyTensor>(
{{common.args(input_names, attr_names)}});
for (size_t i = 0; i < org_res.size(); ++i) {
res[i].push_back(std::static_pointer_cast<primitive::LazyTensor>(std::get<i>(op_res).impl())->value().dyn_cast<pir::OpResult>());
}
{% for k in range(outputs|length) %}
{% if outputs[k].intermediate %}
pir::OpResult {{outputs[k].name}};
res[{{k}}].push_back({{outputs[k].name}});
{% else %}
res[{{k}}].push_back(std::static_pointer_cast<primitive::LazyTensor>(std::get<{{k}}>(op_res).impl())->value().dyn_cast<pir::OpResult>());
{% endif %}
{% endfor %}
{% endif %}

return res;
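
The template change above covers Scalar/IntArray attributes that may arrive as tensors: the generated Decomp walks the attribute input back to its defining op, accepts only a static pd_op.full / pd_op.full_int_array and reads the constant from its "value" attribute, and otherwise throws Unimplemented; outputs flagged as intermediate are left as empty OpResults instead of being mapped from the decomposition tuple. A rough Python illustration of that attribute-recovery rule follows; the get_defining_op/attrs accessors are stand-ins, not the exact pir Python API.

# Illustrative sketch of the logic the generated C++ performs for a
# support_tensor Scalar/IntArray attribute; accessor names are assumed.
def recover_static_attr(attr_input, expected_define_op):
    define_op = attr_input.get_defining_op()      # assumed accessor
    if define_op.name() != expected_define_op:    # e.g. "pd_op.full_int_array"
        raise NotImplementedError(
            "dynamic tensor attributes are not supported in decomposition yet"
        )
    # A static full / full_int_array op keeps the constant in its "value" attribute.
    return define_op.attrs()["value"]             # assumed accessor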
19 changes: 19 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -62,6 +62,25 @@ Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
}
}

template <typename T>
std::tuple<Tensor, Tensor> squeeze_decomp(const Tensor& x,
const IntArray& axis) {
auto axis_ = process_dims(x, axis.GetData());
auto out_shape = get_squeeze_dims(x, axis_);
Tensor out = reshape<T>(x, out_shape);
Tensor xshape;
return std::make_tuple(out, xshape);
}

template <typename T>
Tensor add_n_decomp(const std::vector<Tensor>& x) {
Tensor res = x[0];
for (size_t i = 1; i < x.size(); i++) {
res = add<T>(res, x[i]);
}
return res;
}

} // namespace details

} // namespace primitive
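
For reference, the two new composite rules are small: add_n_decomp folds the input list with repeated add calls, and squeeze_decomp reshapes to the input shape with the selected size-1 dims removed while leaving the intermediate xshape output empty. A NumPy re-statement of the same semantics, as a sanity check rather than the real implementation:

import numpy as np

def add_n_decomp(xs):
    # Fold the list with pairwise adds, mirroring the C++ loop.
    res = xs[0]
    for x in xs[1:]:
        res = res + x
    return res

def squeeze_decomp(x, axis):
    # Normalize axes (all dims when axis is empty), then drop only the size-1
    # dims named in axis and express the squeeze as a reshape.
    axes = list(range(x.ndim)) if not axis else [a % x.ndim for a in axis]
    out_shape = [d for i, d in enumerate(x.shape) if not (d == 1 and i in axes)]
    return np.reshape(x, out_shape)

assert squeeze_decomp(np.zeros((4, 5, 1)), [2]).shape == (4, 5)
assert add_n_decomp([np.ones(3)] * 3).tolist() == [3.0, 3.0, 3.0]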
41 changes: 41 additions & 0 deletions paddle/fluid/primitive/utils/utils.h
@@ -55,6 +55,47 @@ static std::vector<int64_t> get_unsqueeze_dims(
return result;
}

// This function computes squeeze dims for reshape to replace squeeze.
static std::vector<int64_t> get_squeeze_dims(const Tensor& origin,
const std::vector<int64_t>& axis) {
auto origin_dims = origin.shape();
auto total_shape_size = origin_dims.size();
std::vector<int64_t> result;
for (size_t i = 0; i < total_shape_size; ++i) {
if (origin_dims[i] != 1) {
result.push_back(origin_dims[i]);
} else if (origin_dims[i] == 1 &&
std::find(axis.begin(), axis.end(), int64_t(i)) == axis.end()) {
result.push_back(1);
} else {
continue;
}
}
return result;
}

static std::vector<int64_t> process_dims(const Tensor& origin,
const std::vector<int64_t>& axis) {
auto origin_dims = origin.shape();
auto total_shape_size = origin_dims.size();
std::vector<int64_t> result;
auto axis_size = axis.size();
if (axis_size == 0) {
for (size_t i = 0; i < total_shape_size; ++i) {
result.push_back(i);
}
} else {
for (size_t i = 0; i < axis_size; ++i) {
if (axis[i] < 0) {
result.push_back(axis[i] + total_shape_size);
} else {
result.push_back(axis[i]);
}
}
}
return result;
}

// These methods don't need to be specified
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
const phi::DDim& in_dims) {
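
A quick worked example of the two helpers above: process_dims wraps negative axes and falls back to every dim when the axis list is empty, and get_squeeze_dims then keeps every dim except the size-1 dims selected by that list. Python re-statement for illustration only:

def process_dims(shape, axis):
    n = len(shape)
    return list(range(n)) if not axis else [a + n if a < 0 else a for a in axis]

def get_squeeze_dims(shape, axis):
    return [d for i, d in enumerate(shape) if d != 1 or i not in axis]

shape = [3, 1, 4, 1]
print(process_dims(shape, [-1]))                          # [3]
print(get_squeeze_dims(shape, [3]))                       # [3, 1, 4]
print(get_squeeze_dims(shape, process_dims(shape, [])))   # [3, 4]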
11 changes: 11 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -351,6 +351,17 @@ void BindOperation(py::module *m) {
}
return op_list;
})
.def("get_output_intermediate_value",
Contributor:

Suggested change
.def("get_output_intermediate_value",
.def("get_output_value_intermediate_status",

Would this name be easier to understand?

Contributor:

Or maybe call it get_output_intermediate_status?

Contributor (Author):

nice

[](Operation &self) -> py::list {
py::list op_list;
paddle::dialect::OpYamlInfoInterface yaml_interface =
self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto outputs_info = std::get<2>(yaml_interface.GetOpInfo());
for (auto &output_info : outputs_info) {
op_list.append(output_info.intermediate);
}
return op_list;
})
.def("get_input_grad_semantics",
[](Operation &self) -> py::list {
py::list op_list;
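
The new get_output_intermediate_value binding returns one boolean per op output, read from the op's YAML info, marking which outputs are intermediate (for example unsqueeze's xshape). A minimal usage sketch, assuming a program built the same way as in the updated test below:

# Minimal usage sketch; newir_program is assumed to be built as in the test below.
for op in newir_program.global_block().ops:
    flags = op.get_output_intermediate_value()
    # flags[i] is True when output i is intermediate; decomp.py uses this to
    # skip mapping that output onto a decomposed value.
    print(op.name(), flags)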
24 changes: 18 additions & 6 deletions python/paddle/decomposition/decomp.py
@@ -32,12 +32,20 @@ def _build_tensor_tuple(xs):
return TypeError(f"Type {type(xs)} is not supported.")


def _analyse_decomp_results(orig_outs, decomp_outs):
assert len(orig_outs) == len(decomp_outs)
def _analyse_decomp_results(orig_outs, decomp_outs, op):
intermediate_values = op.get_output_intermediate_value()
Contributor:

Same as above.

assert len(orig_outs) == len(decomp_outs) == len(intermediate_values)
res = []
for org_item, new_item in zip(orig_outs, decomp_outs):
for org_item, new_item, value in zip(
Contributor:

Rename value here to output_intermediate_status; it is just the status flag.

Contributor (Author):

ok, thanks

orig_outs, decomp_outs, intermediate_values
):
if isinstance(org_item, pir.OpResult):
assert len(new_item) == 1 and isinstance(new_item[0], pir.OpResult)
if value:
assert new_item[0] is None
else:
assert len(new_item) == 1 and isinstance(
new_item[0], pir.OpResult
)
res.append(new_item[0])
else:
res.append(new_item)
@@ -256,7 +264,9 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):
orig_outs = op.results()
if has_sink_decomp_rule:
decomp_outs = call_decomp(op)
new_outs = _analyse_decomp_results(orig_outs, decomp_outs)
new_outs = _analyse_decomp_results(
orig_outs, decomp_outs, op
)
else:
new_outs = _build_tensor_tuple(decom_rule(*input_args))

@@ -389,7 +399,9 @@ def decompose_fwd_op(
pir.set_insertion_point(fwd_op)
if has_sink_decomp_rule:
decomp_outs = call_decomp(fwd_op)
new_outs = _analyse_decomp_results(orig_outs, decomp_outs)
new_outs = _analyse_decomp_results(
orig_outs, decomp_outs, fwd_op
)
else:
new_outs = _build_tensor_tuple(decom_rule(*input_args))

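
Concretely, for an op like squeeze the decomposition only produces a real value for out, while the intermediate xshape slot stays empty, so _analyse_decomp_results keeps a None placeholder in that position. A toy illustration of the contract, with plain Python values standing in for pir.OpResult:

# Toy illustration of the intermediate-output handling; strings stand in for OpResults.
orig_outs = ["squeeze.out", "squeeze.xshape"]
decomp_outs = [["reshape.out"], [None]]     # the intermediate slot carries no value
intermediate_status = [False, True]         # from op.get_output_intermediate_value()

res = []
for org, new, is_intermediate in zip(orig_outs, decomp_outs, intermediate_status):
    if is_intermediate:
        assert new[0] is None
    else:
        assert len(new) == 1
    res.append(new[0])
print(res)  # ['reshape.out', None]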
16 changes: 12 additions & 4 deletions test/ir/new_ir/test_ir_pybind.py
@@ -32,6 +32,8 @@ def get_ir_program():
y_s = paddle.matmul(x_s, x_s)
z_s = paddle.add(y_s, y_s)
k_s = paddle.tanh(z_s)
q_s = paddle.unsqueeze(k_s, [2])

newir_program = pir.translate_to_new_ir(main_program.desc)
return newir_program

@@ -51,10 +53,10 @@ def test_block(self):
block = newir_program.global_block()
ops = block.ops
self.assertEqual(
len(ops), 4
len(ops), 6
) # pir program adds "builtin.get_parameter" by default, so size is 6
block.remove_op(ops[3])
self.assertEqual(len(block.ops), 3)
block.remove_op(ops[5])
self.assertEqual(len(block.ops), 5)

def test_operation(self):
newir_program = get_ir_program()
@@ -64,7 +66,7 @@ def test_operation(self):
tanh_op = newir_program.global_block().ops[3]
parent_block = tanh_op.get_parent_block()
parent_ops_num = len(parent_block.ops)
self.assertEqual(parent_ops_num, 4)
self.assertEqual(parent_ops_num, 6)
self.assertEqual(tanh_op.num_results(), 1)
self.assertEqual(len(matmul_op.get_input_names()), 2)
self.assertEqual(len(matmul_op.get_attr_names()), 2)
@@ -190,6 +192,12 @@ def test_results(self):
results = matmul_op.results()
self.assertEqual(len(results), 1)

def test_get_output_intermediate_value(self):
newir_program = get_ir_program()
unsqueeze_op = newir_program.global_block().ops[-1]
results = unsqueeze_op.get_output_intermediate_value()
self.assertEqual(results, [False, True])


if __name__ == "__main__":
unittest.main()