diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
index 334d410e7dab31..2068d0917e299f 100644
--- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
@@ -18,8 +18,9 @@
 # come into effect in generated file pd_op.h
-decomp_interface_declare_gen_op_list = ['mean']
+# manual decomp interface declare are located in manual_op.h
+decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n"]

 # come into effect in generated file op_decomp.cc
 # manual decomp interface implementation are located in manual_op_decomp.cc
-decomp_interface_implementation_gen_op_list = ["mean"]
+decomp_interface_implementation_gen_op_list = ["mean", "squeeze", "add_n"]
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
index 83ced4c1458fe1..cda6eb596d21eb 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -16,6 +16,7 @@
 #include <vector>

 #include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
 #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
 #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
 #include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
@@ -34,7 +35,8 @@ namespace dialect {
 class AddNOp : public pir::Op<AddNOp,
                               paddle::dialect::OpYamlInfoInterface,
                               paddle::dialect::InferMetaInterface,
-                              paddle::dialect::VjpInterface> {
+                              paddle::dialect::VjpInterface,
+                              paddle::dialect::DecompInterface> {
  public:
   using Op::Op;
   static const char *name() { return "pd_op.add_n"; }
@@ -55,6 +57,7 @@ class AddNOp : public pir::Op<AddNOp,
       const std::vector<std::vector<pir::Value>> &outputs,
       const std::vector<std::vector<pir::Value>> &out_grads,
       const std::vector<std::vector<bool>> &stop_gradients);
+  static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation *op);
 };

 class AddN_Op : public pir::Op<AddN_Op,
 > 0
-            else item_backwards[0]
-        )
-
-
-def extend_compat_info(apis, compats):
-    for api in apis:
-        attrs = api["attrs"]
-        for attr in attrs:
-            if op_gen_tests.is_scalar(
-                attr['typename']
-            ) or op_gen_tests.is_intarray(attr['typename']):
-                attr["support_tensor"] = False
-    apis_dict = to_named_dict(apis)
-    for compat_item in compats:
-        fwd_op_name = compat_item["op"]
-        if fwd_op_name not in apis_dict:
-            continue
-        fwd_api = apis_dict[fwd_op_name]
-        backward_op_names = []
-        while fwd_op_name is not None and fwd_op_name in apis_dict:
-            backward_op_names.append(apis_dict[fwd_op_name]['backward'])
-            fwd_op_name = apis_dict[fwd_op_name]['backward']
-        backward_apis = []
-        for backward_op_name in backward_op_names:
-            if backward_op_name in apis_dict:
-                backward_apis.append(apis_dict[backward_op_name])
-        support_tensor_attrs_names = []
-        compat_attrs_data_type = {}
-        if 'scalar' in compat_item and compat_item['op'] != "pow":
-            for attr_name, attr_info in compat_item['scalar'].items():
-                if (
-                    'support_tensor' in attr_info
-                    and attr_info['support_tensor'] is True
-                    or 'tensor_name' in attr_info
-                ):
-                    support_tensor_attrs_names.append(attr_name)
-                if 'data_type' in attr_info:
-                    compat_attrs_data_type.update(
-                        {attr_name: attr_info['data_type']}
-                    )
-        if 'int_array' in compat_item:
-            for attr_name, attr_info in compat_item['int_array'].items():
-                if (
-                    'support_tensor' in attr_info
-                    and attr_info['support_tensor'] is True
-                    or 'tensor_name' in attr_info
-                    or 'tensors_name' in attr_info
-                ):
-                    support_tensor_attrs_names.append(attr_name)
-        if len(support_tensor_attrs_names) > 0:
-            for api in [fwd_api] + backward_apis:
-                attrs = api["attrs"]
-                for attr in attrs:
-                    if attr['name'] in support_tensor_attrs_names:
-                        attr['support_tensor'] = True
-        for api in [fwd_api] + backward_apis:
-            attrs = api["attrs"]
-            for attr in attrs:
-                if attr['name'] in compat_attrs_data_type:
-                    attr['data_type'] = compat_attrs_data_type[attr['name']]
-    return apis
-
-
 def process_optional_output_info(apis):
     for api in apis:
         inputs_dict = to_named_dict(api['inputs'])
diff --git a/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2 b/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2
index 601a59802d4312..087a8757e5e17f 100644
--- a/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2
+++ b/paddle/fluid/primitive/codegen/templates/decomp/generated_decomp.j2
@@ -60,6 +60,7 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
   for (size_t idx = 0; idx < combine_op_obj_{{item.name}}.inputs().size(); idx++) {
     {{item.name}}.emplace_back(
         std::make_shared<primitive::LazyTensor>(combine_op_obj_{{item.name}}.inputs()[idx]));
+  }
   {% endif %}
   {% endif %}
   {% endfor %}
@@ -69,7 +70,47 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
   {% if attrs %}
   {% for item in attrs %}
   {% do attr_names.append(item.name) %}
+
+  {% if item.typename == "Scalar" and item.support_tensor %}
+
+  Tensor {{item.name}}_(std::make_shared<primitive::LazyTensor>(op_obj.{{item.name}}()));
+
+  auto* {{item.name}}_define_op =
+      std::static_pointer_cast<primitive::LazyTensor>({{item.name}}_.impl())
+          ->value()
+          .dyn_cast<pir::OpResult>()
+          .owner();
+  if ({{item.name}}_define_op->name() != "pd_op.full") {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("We don't support dynamic tensors "
+                                        "attribute {{item.name}} for {{fwd_name}} decomposition "
+                                        "for now. "));
+  }
+  Scalar {{item.name}} = {{item.name}}_define_op->attribute("value").dyn_cast<paddle::dialect::ScalarAttribute>().data();
+
+  {% elif item.typename == "IntArray" and item.support_tensor %}
+
+  Tensor {{item.name}}_(std::make_shared<primitive::LazyTensor>(op_obj.{{item.name}}()));
+
+  auto* {{item.name}}_define_op =
+      std::static_pointer_cast<primitive::LazyTensor>({{item.name}}_.impl())
+          ->value()
+          .dyn_cast<pir::OpResult>()
+          .owner();
+  if ({{item.name}}_define_op->name() != "pd_op.full_int_array") {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("We don't support dynamic tensors "
+                                        "attribute {{item.name}} for {{fwd_name}} decomposition "
+                                        "for now. "));
")); + } + IntArray {{item.name}} = phi::IntArray( + paddle::dialect::GetInt64Vector({{item.name}}_define_op->attribute("value"))); + + {% else %} + {{item.typename}} {{item.name}} = op->attribute("{{item.name}}").dyn_cast<{{item.mapped_type}}>().data(); + {% endif %} {% endfor %} {% endif %} @@ -92,9 +133,14 @@ std::vector> {{class_name}}::Decomp(pir::Operation* o {% endfor %} std::tuple<{{common.sequence('', '', ', ', output_types)}}> op_res = paddle::primitive::details::{{fwd_name}}_decomp( {{common.args(input_names, attr_names)}}); - for (size_t i = 0; i < org_res.size(); ++i) { - res[i].push_back(std::static_pointer_cast(std::get(op_res).impl())->value().dyn_cast()); - } + {% for k in range(outputs|length) %} + {% if outputs[k].intermediate %} + pir::OpResult {{outputs[k].name}}; + res[{{k}}].push_back({{outputs[k].name}}); + {% else %} + res[{{k}}].push_back(std::static_pointer_cast(std::get<{{k}}>(op_res).impl())->value().dyn_cast()); + {% endif %} + {% endfor %} {% endif %} return res; diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index e0da626ef4c938..f1230c33d43f9b 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -62,6 +62,25 @@ Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) { } } +template +std::tuple squeeze_decomp(const Tensor& x, + const IntArray& axis) { + auto axis_ = process_dims(x, axis.GetData()); + auto out_shape = get_squeeze_dims(x, axis_); + Tensor out = reshape(x, out_shape); + Tensor xshape; + return std::make_tuple(out, xshape); +} + +template +Tensor add_n_decomp(const std::vector& x) { + Tensor res = x[0]; + for (size_t i = 1; i < x.size(); i++) { + res = add(res, x[i]); + } + return res; +} + } // namespace details } // namespace primitive diff --git a/paddle/fluid/primitive/utils/utils.h b/paddle/fluid/primitive/utils/utils.h index 73fa68a0bb937c..da6cd28bfa476b 100644 --- a/paddle/fluid/primitive/utils/utils.h +++ b/paddle/fluid/primitive/utils/utils.h @@ -55,6 +55,47 @@ static std::vector get_unsqueeze_dims( return result; } +// This fucction compute unsqueeze dims for reshape to replace unsqueeze. 
+static std::vector<int64_t> get_squeeze_dims(const Tensor& origin,
+                                             const std::vector<int64_t>& axis) {
+  auto origin_dims = origin.shape();
+  auto total_shape_size = origin_dims.size();
+  std::vector<int64_t> result;
+  for (size_t i = 0; i < total_shape_size; ++i) {
+    if (origin_dims[i] != 1) {
+      result.push_back(origin_dims[i]);
+    } else if (origin_dims[i] == 1 &&
+               std::find(axis.begin(), axis.end(), int64_t(i)) == axis.end()) {
+      result.push_back(1);
+    } else {
+      continue;
+    }
+  }
+  return result;
+}
+
+static std::vector<int64_t> process_dims(const Tensor& origin,
+                                         const std::vector<int64_t>& axis) {
+  auto origin_dims = origin.shape();
+  auto total_shape_size = origin_dims.size();
+  std::vector<int64_t> result;
+  auto axis_size = axis.size();
+  if (axis_size == 0) {
+    for (size_t i = 0; i < total_shape_size; ++i) {
+      result.push_back(i);
+    }
+  } else {
+    for (size_t i = 0; i < axis_size; ++i) {
+      if (axis[i] < 0) {
+        result.push_back(axis[i] + total_shape_size);
+      } else {
+        result.push_back(axis[i]);
+      }
+    }
+  }
+  return result;
+}
+
 // These method don't need to be specified
 static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
                                           const phi::DDim& in_dims) {
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 485c63e26a4357..b273f93b88454a 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -373,6 +373,17 @@ void BindOperation(py::module *m) {
             }
             return op_list;
           })
+      .def("get_output_intermediate_value",
+           [](Operation &self) -> py::list {
+             py::list op_list;
+             paddle::dialect::OpYamlInfoInterface yaml_interface =
+                 self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
+             auto outputs_info = std::get<2>(yaml_interface.GetOpInfo());
+             for (auto &output_info : outputs_info) {
+               op_list.append(output_info.intermediate);
+             }
+             return op_list;
+           })
       .def("get_input_grad_semantics",
            [](Operation &self) -> py::list {
              py::list op_list;
diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py
index e89b5abc392211..af692982071e50 100644
--- a/python/paddle/decomposition/decomp.py
+++ b/python/paddle/decomposition/decomp.py
@@ -32,12 +32,20 @@ def _build_tensor_tuple(xs):
     return TypeError(f"Type {type(xs)} is not supported.")


-def _analyse_decomp_results(orig_outs, decomp_outs):
-    assert len(orig_outs) == len(decomp_outs)
+def _analyse_decomp_results(orig_outs, decomp_outs, op):
+    intermediate_values = op.get_output_intermediate_value()
+    assert len(orig_outs) == len(decomp_outs) == len(intermediate_values)
     res = []
-    for org_item, new_item in zip(orig_outs, decomp_outs):
+    for org_item, new_item, value in zip(
+        orig_outs, decomp_outs, intermediate_values
+    ):
         if isinstance(org_item, pir.OpResult):
-            assert len(new_item) == 1 and isinstance(new_item[0], pir.OpResult)
+            if value:
+                assert new_item[0] is None
+            else:
+                assert len(new_item) == 1 and isinstance(
+                    new_item[0], pir.OpResult
+                )
             res.append(new_item[0])
         else:
             res.append(new_item)
@@ -256,7 +264,9 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):
             orig_outs = op.results()
             if has_sink_decomp_rule:
                 decomp_outs = call_decomp(op)
-                new_outs = _analyse_decomp_results(orig_outs, decomp_outs)
+                new_outs = _analyse_decomp_results(
+                    orig_outs, decomp_outs, op
+                )
             else:
                 new_outs = _build_tensor_tuple(decom_rule(*input_args))
@@ -389,7 +399,9 @@ def decompose_fwd_op(
         pir.set_insertion_point(fwd_op)
         if has_sink_decomp_rule:
             decomp_outs = call_decomp(fwd_op)
-            new_outs = _analyse_decomp_results(orig_outs, decomp_outs)
+            new_outs = _analyse_decomp_results(
+                orig_outs, decomp_outs, fwd_op
+            )
         else:
             new_outs = _build_tensor_tuple(decom_rule(*input_args))
diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/new_ir/test_ir_pybind.py
index 6434b0eb65268a..d1a0e1de1f8788 100644
--- a/test/ir/new_ir/test_ir_pybind.py
+++ b/test/ir/new_ir/test_ir_pybind.py
@@ -32,6 +32,8 @@ def get_ir_program():
         y_s = paddle.matmul(x_s, x_s)
         z_s = paddle.add(y_s, y_s)
         k_s = paddle.tanh(z_s)
+        q_s = paddle.unsqueeze(k_s, [2])
+
     newir_program = pir.translate_to_new_ir(main_program.desc)
     return newir_program

@@ -51,10 +53,10 @@ def test_block(self):
         block = newir_program.global_block()
         ops = block.ops
         self.assertEqual(
-            len(ops), 4
+            len(ops), 6
         )  # pir program add "builtin.get_parameter" by default, so size is 4
-        block.remove_op(ops[3])
-        self.assertEqual(len(block.ops), 3)
+        block.remove_op(ops[5])
+        self.assertEqual(len(block.ops), 5)

     def test_operation(self):
         newir_program = get_ir_program()
@@ -64,7 +66,7 @@ def test_operation(self):
         tanh_op = newir_program.global_block().ops[3]
         parent_block = tanh_op.get_parent_block()
         parent_ops_num = len(parent_block.ops)
-        self.assertEqual(parent_ops_num, 4)
+        self.assertEqual(parent_ops_num, 6)
         self.assertEqual(tanh_op.num_results(), 1)
         self.assertEqual(len(matmul_op.get_input_names()), 2)
         self.assertEqual(len(matmul_op.get_attr_names()), 2)
@@ -190,6 +192,12 @@ def test_results(self):
         results = matmul_op.results()
         self.assertEqual(len(results), 1)

+    def test_get_output_intermediate_value(self):
+        newir_program = get_ir_program()
+        unsqueeze_op = newir_program.global_block().ops[-1]
+        results = unsqueeze_op.get_output_intermediate_value()
+        self.assertEqual(results, [False, True])
+

 if __name__ == "__main__":
     unittest.main()
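
Not part of the patch above: a minimal, hedged usage sketch of the new get_output_intermediate_value binding applied to pd_op.squeeze, written in the same style as get_ir_program() in test_ir_pybind.py. The test name, the input shape, and the expected [False, True] result are illustrative assumptions; they presume the translated squeeze op exposes `out` plus an intermediate `xshape` output, mirroring the unsqueeze case the PR's test checks.

import unittest

import paddle
from paddle import pir

paddle.enable_static()


class TestSqueezeIntermediateOutput(unittest.TestCase):
    def test_get_output_intermediate_value_for_squeeze(self):
        # Build a small legacy program and translate it to the new IR,
        # following the same pattern as get_ir_program() in the test above.
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            x = paddle.static.data('x', [4, 1, 3], dtype='float32')
            out = paddle.squeeze(x, [1])

        newir_program = pir.translate_to_new_ir(main_program.desc)
        squeeze_op = newir_program.global_block().ops[-1]

        # Assumed: pd_op.squeeze has two outputs, `out` and the bookkeeping
        # `xshape`, and only `xshape` is flagged as intermediate.
        self.assertEqual(
            squeeze_op.get_output_intermediate_value(), [False, True]
        )


if __name__ == "__main__":
    unittest.main()

If squeeze's xshape were not marked intermediate in the op yaml, the expected list would differ; the check is only meant to illustrate how the binding exposes per-output intermediate flags that decomp.py uses to skip intermediate results.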