[Prim][PIR] PIR Prim support intarray, scalar, combineop (PaddlePaddle#58581)

* fix intarray codegen

* fix code

* remove unused code

* reset code

* support scalar

* fix code

* support combineop case
cyber-pioneer authored Nov 3, 2023
1 parent 808239e commit b6abb66
Showing 9 changed files with 158 additions and 90 deletions.
@@ -18,8 +18,9 @@


# come into effect in generated file pd_op.h
decomp_interface_declare_gen_op_list = ['mean']
# manual decomp interface declare are located in manual_op.h
decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n"]

# come into effect in generated file op_decomp.cc
# manual decomp interface implementation are located in manual_op_decomp.cc
decomp_interface_implementation_gen_op_list = ["mean"]
decomp_interface_implementation_gen_op_list = ["mean", "squeeze", "add_n"]
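
The two lists above gate code generation: ops in the declare list get paddle::dialect::DecompInterface added to their generated op class in pd_op.h (add_n is a manually written op, so its declaration lands in manual_op.h instead, as the next file shows), and ops in the implementation list get a generated Decomp body in op_decomp.cc. A minimal sketch of that gating, assuming the generator simply tests membership (the real logic lives in Paddle's op-generator scripts and the helper names below are invented):

decomp_interface_declare_gen_op_list = ["mean", "squeeze", "add_n"]
decomp_interface_implementation_gen_op_list = ["mean", "squeeze", "add_n"]

def needs_decomp_interface(op_name: str) -> bool:
    # Ops listed here have DecompInterface attached to their op class so the
    # decomposition pass can dispatch to them.
    return op_name in decomp_interface_declare_gen_op_list

def needs_generated_decomp_body(op_name: str) -> bool:
    # Ops listed here get a Decomp() implementation emitted into op_decomp.cc;
    # anything else must provide one in manual_op_decomp.cc.
    return op_name in decomp_interface_implementation_gen_op_list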
5 changes: 4 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -16,6 +16,7 @@
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
@@ -34,7 +35,8 @@ namespace dialect {
class AddNOp : public pir::Op<AddNOp,
paddle::dialect::OpYamlInfoInterface,
paddle::dialect::InferMetaInterface,
paddle::dialect::VjpInterface> {
paddle::dialect::VjpInterface,
paddle::dialect::DecompInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.add_n"; }
@@ -55,6 +57,7 @@ class AddNOp : public pir::Op<AddNOp,
const std::vector<std::vector<pir::OpResult>> &outputs,
const std::vector<std::vector<pir::Value>> &out_grads,
const std::vector<std::vector<bool>> &stop_gradients);
static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation *op);
};

class AddN_Op : public pir::Op<AddN_Op,
75 changes: 1 addition & 74 deletions paddle/fluid/primitive/codegen/decomp_gen.py
@@ -27,6 +27,7 @@
)
import filters as op_gen_filters
import tests_utils as op_gen_tests
from gen import extend_compat_info, filter_compat_info
from parse_utils import to_named_dict
from type_mapping import output_type_map

@@ -128,80 +129,6 @@ def save(content: str, path: pathlib.Path):
f.write(content)


def filter_compat_info(items):
for item in items:
item['op'] = item['op'].split('(')[0].strip()
if 'backward' in item:
item_backwards = item['backward'].split(',')
for idx, item_backward in enumerate(item_backwards):
item_backward = item_backward.split('(')[0].strip()
item_backwards[idx] = item_backward
item['backward'] = (
','.join(item_backwards)
if len(item_backwards) > 0
else item_backwards[0]
)


def extend_compat_info(apis, compats):
for api in apis:
attrs = api["attrs"]
for attr in attrs:
if op_gen_tests.is_scalar(
attr['typename']
) or op_gen_tests.is_intarray(attr['typename']):
attr["support_tensor"] = False
apis_dict = to_named_dict(apis)
for compat_item in compats:
fwd_op_name = compat_item["op"]
if fwd_op_name not in apis_dict:
continue
fwd_api = apis_dict[fwd_op_name]
backward_op_names = []
while fwd_op_name is not None and fwd_op_name in apis_dict:
backward_op_names.append(apis_dict[fwd_op_name]['backward'])
fwd_op_name = apis_dict[fwd_op_name]['backward']
backward_apis = []
for backward_op_name in backward_op_names:
if backward_op_name in apis_dict:
backward_apis.append(apis_dict[backward_op_name])
support_tensor_attrs_names = []
compat_attrs_data_type = {}
if 'scalar' in compat_item and compat_item['op'] != "pow":
for attr_name, attr_info in compat_item['scalar'].items():
if (
'support_tensor' in attr_info
and attr_info['support_tensor'] is True
or 'tensor_name' in attr_info
):
support_tensor_attrs_names.append(attr_name)
if 'data_type' in attr_info:
compat_attrs_data_type.update(
{attr_name: attr_info['data_type']}
)
if 'int_array' in compat_item:
for attr_name, attr_info in compat_item['int_array'].items():
if (
'support_tensor' in attr_info
and attr_info['support_tensor'] is True
or 'tensor_name' in attr_info
or 'tensors_name' in attr_info
):
support_tensor_attrs_names.append(attr_name)
if len(support_tensor_attrs_names) > 0:
for api in [fwd_api] + backward_apis:
attrs = api["attrs"]
for attr in attrs:
if attr['name'] in support_tensor_attrs_names:
attr['support_tensor'] = True
for api in [fwd_api] + backward_apis:
attrs = api["attrs"]
for attr in attrs:
if attr['name'] in compat_attrs_data_type:
attr['data_type'] = compat_attrs_data_type[attr['name']]
return apis


def process_optional_output_info(apis):
for api in apis:
inputs_dict = to_named_dict(api['inputs'])
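
The two helpers deleted above are not removed from the codebase; the new import pulls them from gen.py so decomp_gen.py and the main generator share a single copy. For context, extend_compat_info is what converts op_compat.yaml entries into the support_tensor flags that the decomposition template below branches on. A hedged illustration with hand-written data (the dict shapes mirror the parsed YAML, but the concrete entries are invented):

# Hypothetical parsed-op and compat entries, shaped like the YAML the
# generator consumes (values invented for illustration).
apis = [
    {
        "name": "squeeze",
        "inputs": [{"name": "x", "typename": "Tensor"}],
        "attrs": [{"name": "axis", "typename": "IntArray"}],
        "backward": None,
    }
]
compats = [
    {
        "op": "squeeze",
        "int_array": {"axis": {"data_type": "int", "support_tensor": True}},
    }
]
# extend_compat_info(apis, compats) would set
#   apis[0]["attrs"][0]["support_tensor"] = True
# because the compat entry marks `axis` as tensor-capable, which in turn makes
# the decomp template emit the pd_op.full_int_array extraction branch instead
# of reading the attribute straight off the op.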
@@ -60,6 +60,7 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
for (size_t idx = 0; idx < combine_op_obj_{{item.name}}.inputs().size(); idx++) {
{{item.name}}.emplace_back(
std::make_shared<primitive::LazyTensor>(combine_op_obj_{{item.name}}.inputs()[idx]));
}
{% endif %}
{% endif %}
{% endfor %}
@@ -69,7 +70,47 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
{% if attrs %}
{% for item in attrs %}
{% do attr_names.append(item.name) %}

{% if item.typename == "Scalar" and item.support_tensor %}

Tensor {{item.name}}_(std::make_shared<primitive::LazyTensor>(op_obj.{{item.name}}()));

auto* {{item.name}}_define_op =
std::static_pointer_cast<primitive::LazyTensor>({{item.name}}_.impl())
->value()
.dyn_cast<pir::OpResult>()
.owner();
if ({{item.name}}_define_op->name() != "pd_op.full") {
PADDLE_THROW(
platform::errors::Unimplemented("We don't support dynamic tensors "
"attribute {{item.name}} for {{fwd_name}} decomposition "
"for now. "));
}
Scalar {{item.name}} = {{item.name}}_define_op->attribute("value").dyn_cast<paddle::dialect::ScalarAttribute>().data();


{% elif item.typename == "IntArray" and item.support_tensor %}

Tensor {{item.name}}_(std::make_shared<primitive::LazyTensor>(op_obj.{{item.name}}()));

auto* {{item.name}}_define_op =
std::static_pointer_cast<primitive::LazyTensor>({{item.name}}_.impl())
->value()
.dyn_cast<pir::OpResult>()
.owner();
if ({{item.name}}_define_op->name() != "pd_op.full_int_array") {
PADDLE_THROW(
platform::errors::Unimplemented("We don't support dynamic tensors "
"attribute {{item.name}} for {{fwd_name}} decomposition "
"for now. "));
}
IntArray {{item.name}} = phi::IntArray(
paddle::dialect::GetInt64Vector({{item.name}}_define_op->attribute("value")));

{% else %}

{{item.typename}} {{item.name}} = op->attribute("{{item.name}}").dyn_cast<{{item.mapped_type}}>().data();
{% endif %}
{% endfor %}
{% endif %}

@@ -92,9 +133,14 @@ std::vector<std::vector<pir::OpResult>> {{class_name}}::Decomp(pir::Operation* o
{% endfor %}
std::tuple<{{common.sequence('', '', ', ', output_types)}}> op_res = paddle::primitive::details::{{fwd_name}}_decomp<primitive::LazyTensor>(
{{common.args(input_names, attr_names)}});
for (size_t i = 0; i < org_res.size(); ++i) {
res[i].push_back(std::static_pointer_cast<primitive::LazyTensor>(std::get<i>(op_res).impl())->value().dyn_cast<pir::OpResult>());
}
{% for k in range(outputs|length) %}
{% if outputs[k].intermediate %}
pir::OpResult {{outputs[k].name}};
res[{{k}}].push_back({{outputs[k].name}});
{% else %}
res[{{k}}].push_back(std::static_pointer_cast<primitive::LazyTensor>(std::get<{{k}}>(op_res).impl())->value().dyn_cast<pir::OpResult>());
{% endif %}
{% endfor %}
{% endif %}

return res;
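
The template now distinguishes three cases per attribute: plain attributes are read directly off the op, while Scalar and IntArray attributes that support tensor inputs are only accepted when the feeding value is produced by a constant pd_op.full or pd_op.full_int_array; genuinely dynamic values still raise Unimplemented. A small Python sketch of that guard, for illustration only (function names are invented; the authoritative check is the generated C++ above):

def read_mutable_int_array_sketch(defining_op_name: str, constant_value):
    # Hypothetical mirror of the generated guard: an IntArray operand can be
    # recovered only when it comes from pd_op.full_int_array.
    if defining_op_name != "pd_op.full_int_array":
        raise NotImplementedError(
            "dynamic tensor attribute is not supported in decomposition yet")
    return [int(v) for v in constant_value]

def read_mutable_scalar_sketch(defining_op_name: str, constant_value):
    # Same idea for Scalar attributes, which must come from pd_op.full.
    if defining_op_name != "pd_op.full":
        raise NotImplementedError(
            "dynamic tensor attribute is not supported in decomposition yet")
    return constant_value

print(read_mutable_int_array_sketch("pd_op.full_int_array", (1, 2)))  # [1, 2]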
19 changes: 19 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -62,6 +62,25 @@ Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
}
}

template <typename T>
std::tuple<Tensor, Tensor> squeeze_decomp(const Tensor& x,
const IntArray& axis) {
auto axis_ = process_dims(x, axis.GetData());
auto out_shape = get_squeeze_dims(x, axis_);
Tensor out = reshape<T>(x, out_shape);
Tensor xshape;
return std::make_tuple(out, xshape);
}

template <typename T>
Tensor add_n_decomp(const std::vector<Tensor>& x) {
Tensor res = x[0];
for (size_t i = 1; i < x.size(); i++) {
res = add<T>(res, x[i]);
}
return res;
}

} // namespace details

} // namespace primitive
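
The two new composite rules are small: squeeze_decomp normalizes the axes, computes the squeezed shape, and lowers squeeze to a single reshape (leaving the xshape output empty), while add_n_decomp folds its input list into a chain of binary adds. A NumPy sketch of the same logic, for illustration only, mirroring the helpers process_dims and get_squeeze_dims added in the next file (utils.h):

import numpy as np

def squeeze_decomp_sketch(x: np.ndarray, axis: list):
    # Mirror of process_dims: an empty axis list means "all dims"; negative
    # axes wrap around.
    dims = [a + x.ndim if a < 0 else a for a in axis] or list(range(x.ndim))
    # Mirror of get_squeeze_dims: drop a dim only if it is 1 and listed in axis.
    out_shape = [d for i, d in enumerate(x.shape) if not (d == 1 and i in dims)]
    return x.reshape(out_shape)  # the xshape output stays an empty Tensor

def add_n_decomp_sketch(xs: list):
    # add_n becomes a left fold of binary adds.
    res = xs[0]
    for t in xs[1:]:
        res = res + t
    return res

print(squeeze_decomp_sketch(np.zeros((4, 1, 8)), [1]).shape)      # (4, 8)
print(add_n_decomp_sketch([np.ones(3), np.ones(3), np.ones(3)]))  # [3. 3. 3.]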
41 changes: 41 additions & 0 deletions paddle/fluid/primitive/utils/utils.h
@@ -55,6 +55,47 @@ static std::vector<int64_t> get_unsqueeze_dims(
return result;
}

// This function computes squeeze dims for reshape to replace squeeze.
static std::vector<int64_t> get_squeeze_dims(const Tensor& origin,
const std::vector<int64_t>& axis) {
auto origin_dims = origin.shape();
auto total_shape_size = origin_dims.size();
std::vector<int64_t> result;
for (size_t i = 0; i < total_shape_size; ++i) {
if (origin_dims[i] != 1) {
result.push_back(origin_dims[i]);
} else if (origin_dims[i] == 1 &&
std::find(axis.begin(), axis.end(), int64_t(i)) == axis.end()) {
result.push_back(1);
} else {
continue;
}
}
return result;
}

static std::vector<int64_t> process_dims(const Tensor& origin,
const std::vector<int64_t>& axis) {
auto origin_dims = origin.shape();
auto total_shape_size = origin_dims.size();
std::vector<int64_t> result;
auto axis_size = axis.size();
if (axis_size == 0) {
for (size_t i = 0; i < total_shape_size; ++i) {
result.push_back(i);
}
} else {
for (size_t i = 0; i < axis_size; ++i) {
if (axis[i] < 0) {
result.push_back(axis[i] + total_shape_size);
} else {
result.push_back(axis[i]);
}
}
}
return result;
}

// These method don't need to be specified
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
const phi::DDim& in_dims) {
11 changes: 11 additions & 0 deletions paddle/fluid/pybind/pir.cc
@@ -373,6 +373,17 @@ void BindOperation(py::module *m) {
}
return op_list;
})
.def("get_output_intermediate_value",
[](Operation &self) -> py::list {
py::list op_list;
paddle::dialect::OpYamlInfoInterface yaml_interface =
self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto outputs_info = std::get<2>(yaml_interface.GetOpInfo());
for (auto &output_info : outputs_info) {
op_list.append(output_info.intermediate);
}
return op_list;
})
.def("get_input_grad_semantics",
[](Operation &self) -> py::list {
py::list op_list;
24 changes: 18 additions & 6 deletions python/paddle/decomposition/decomp.py
@@ -32,12 +32,20 @@ def _build_tensor_tuple(xs):
return TypeError(f"Type {type(xs)} is not supported.")


def _analyse_decomp_results(orig_outs, decomp_outs):
assert len(orig_outs) == len(decomp_outs)
def _analyse_decomp_results(orig_outs, decomp_outs, op):
intermediate_values = op.get_output_intermediate_value()
assert len(orig_outs) == len(decomp_outs) == len(intermediate_values)
res = []
for org_item, new_item in zip(orig_outs, decomp_outs):
for org_item, new_item, value in zip(
orig_outs, decomp_outs, intermediate_values
):
if isinstance(org_item, pir.OpResult):
assert len(new_item) == 1 and isinstance(new_item[0], pir.OpResult)
if value:
assert new_item[0] is None
else:
assert len(new_item) == 1 and isinstance(
new_item[0], pir.OpResult
)
res.append(new_item[0])
else:
res.append(new_item)
@@ -256,7 +264,9 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):
orig_outs = op.results()
if has_sink_decomp_rule:
decomp_outs = call_decomp(op)
new_outs = _analyse_decomp_results(orig_outs, decomp_outs)
new_outs = _analyse_decomp_results(
orig_outs, decomp_outs, op
)
else:
new_outs = _build_tensor_tuple(decom_rule(*input_args))

@@ -389,7 +399,9 @@ def decompose_fwd_op(
pir.set_insertion_point(fwd_op)
if has_sink_decomp_rule:
decomp_outs = call_decomp(fwd_op)
new_outs = _analyse_decomp_results(orig_outs, decomp_outs)
new_outs = _analyse_decomp_results(
orig_outs, decomp_outs, fwd_op
)
else:
new_outs = _build_tensor_tuple(decom_rule(*input_args))

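
On the Python side, decomp.py now passes the op into _analyse_decomp_results so it can query the new get_output_intermediate_value() binding: outputs flagged as intermediate (for example squeeze's xshape) are returned by the generated Decomp as null results and are carried through unchanged, while every real output is rewired to its decomposed value. A self-contained sketch of that pairing, for illustration only, with plain Python stand-ins for pir.OpResult (the test file below checks the same flags for unsqueeze):

def analyse_decomp_results_sketch(orig_outs, decomp_outs, intermediate_flags):
    # Simplified mirror of _analyse_decomp_results: one entry per original output.
    assert len(orig_outs) == len(decomp_outs) == len(intermediate_flags)
    res = []
    for _orig, new, is_intermediate in zip(orig_outs, decomp_outs, intermediate_flags):
        if is_intermediate:
            # Intermediate outputs are not produced by the decomposition; the
            # generated Decomp leaves a null placeholder in their slot.
            assert new[0] is None
        else:
            assert len(new) == 1 and new[0] is not None
        res.append(new[0])
    return res

# squeeze has outputs (out, xshape); get_output_intermediate_value() reports
# [False, True], so only `out` is replaced by the decomposed reshape result.
print(analyse_decomp_results_sketch(
    ["orig_out", "orig_xshape"], [["decomp_out"], [None]], [False, True]))
# -> ['decomp_out', None]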
16 changes: 12 additions & 4 deletions test/ir/new_ir/test_ir_pybind.py
@@ -32,6 +32,8 @@ def get_ir_program():
y_s = paddle.matmul(x_s, x_s)
z_s = paddle.add(y_s, y_s)
k_s = paddle.tanh(z_s)
q_s = paddle.unsqueeze(k_s, [2])

newir_program = pir.translate_to_new_ir(main_program.desc)
return newir_program

@@ -51,10 +53,10 @@ def test_block(self):
block = newir_program.global_block()
ops = block.ops
self.assertEqual(
len(ops), 4
len(ops), 6
) # pir program add "builtin.get_parameter" by default, so size is 4
block.remove_op(ops[3])
self.assertEqual(len(block.ops), 3)
block.remove_op(ops[5])
self.assertEqual(len(block.ops), 5)

def test_operation(self):
newir_program = get_ir_program()
@@ -64,7 +66,7 @@ def test_operation(self):
tanh_op = newir_program.global_block().ops[3]
parent_block = tanh_op.get_parent_block()
parent_ops_num = len(parent_block.ops)
self.assertEqual(parent_ops_num, 4)
self.assertEqual(parent_ops_num, 6)
self.assertEqual(tanh_op.num_results(), 1)
self.assertEqual(len(matmul_op.get_input_names()), 2)
self.assertEqual(len(matmul_op.get_attr_names()), 2)
@@ -190,6 +192,12 @@ def test_results(self):
results = matmul_op.results()
self.assertEqual(len(results), 1)

def test_get_output_intermediate_value(self):
newir_program = get_ir_program()
unsqueeze_op = newir_program.global_block().ops[-1]
results = unsqueeze_op.get_output_intermediate_value()
self.assertEqual(results, [False, True])


if __name__ == "__main__":
unittest.main()
