Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Prim][PIR] PIR Prim support intarray, scalar, combineop #58581

Merged

Conversation

cyber-pioneer
Copy link
Contributor

@cyber-pioneer cyber-pioneer commented Nov 1, 2023

PR types

New features

PR changes

Others

Description

Pcard-66975

  1. Sink composite rules of squeeze and add_n into cpp.
  2. PIR decomp forward rule support intarray, scalar, combineop.
  3. Add api get_output_intermediate_value to get PIR op output intermediate values.
std::vector<std::vector<pir::OpResult>> GatherOp::Decomp(pir::Operation* op) {
  GatherOp op_obj = op->dyn_cast<GatherOp>();
  (void)op_obj;

  VLOG(4) << "Decomp Prepare inputs of gather";

  Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
  Tensor index(std::make_shared<primitive::LazyTensor>(op_obj.index()));

  VLOG(4) << "Decomp prepare attributes of gather";



  Tensor axis_(std::make_shared<primitive::LazyTensor>(op_obj.axis()));

  auto* axis_define_op =
      std::static_pointer_cast<primitive::LazyTensor>(axis_.impl())
          ->value()
          .dyn_cast<pir::OpResult>()
          .owner();
  if (axis_define_op->name() != "pd_op.full") {
    PADDLE_THROW(
        platform::errors::Unimplemented("We don't support dynamic tensors "
                                        "attribute axis for gather decomposition "
                                        "for now. "));
  }
  Scalar axis = axis_define_op->attribute("value").dyn_cast<paddle::dialect::ScalarAttribute>().data();



  VLOG(4) << "Decomp prepare call gather's decomp interface";

  auto org_res = op->results();
  std::vector<std::vector<pir::OpResult>> res(org_res.size());

  Tensor op_res = paddle::primitive::details::gather_decomp<primitive::LazyTensor>(x, index, axis);
  res[0].push_back(
    std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
        ->value()
        .dyn_cast<pir::OpResult>());
  return res;

}

std::vector<std::vector<pir::OpResult>> SqueezeOp::Decomp(pir::Operation* op) {
  SqueezeOp op_obj = op->dyn_cast<SqueezeOp>();
  (void)op_obj;

  VLOG(4) << "Decomp Prepare inputs of squeeze";

  Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));

  VLOG(4) << "Decomp prepare attributes of squeeze";



  Tensor axis_(std::make_shared<primitive::LazyTensor>(op_obj.axis()));

  auto* axis_define_op =
      std::static_pointer_cast<primitive::LazyTensor>(axis_.impl())
          ->value()
          .dyn_cast<pir::OpResult>()
          .owner();
  if (axis_define_op->name() != "pd_op.full_int_array") {
    PADDLE_THROW(
        platform::errors::Unimplemented("We don't support dynamic tensors "
                                        "attribute axis for squeeze decomposition "
                                        "for now. "));
  }
  IntArray axis = phi::IntArray(
      paddle::dialect::GetInt64Vector(axis_define_op->attribute("value")));


  VLOG(4) << "Decomp prepare call squeeze's decomp interface";

  auto org_res = op->results();
  std::vector<std::vector<pir::OpResult>> res(org_res.size());

  std::tuple<Tensor, Tensor> op_res = paddle::primitive::details::squeeze_decomp<primitive::LazyTensor>(
          x, axis);
  res[0].push_back(std::static_pointer_cast<primitive::LazyTensor>(std::get<0>(op_res).impl())->value().dyn_cast<pir::OpResult>());
  pir::OpResult xshape;
  res[1].push_back(xshape);

  return res;

}

std::vector<std::vector<pir::OpResult>> AddNOp::Decomp(pir::Operation* op) {
  AddNOp op_obj = op->dyn_cast<AddNOp>();
  (void)op_obj;

  VLOG(4) << "Decomp Prepare inputs of add_n";

  pir::CombineOp combine_op_obj_inputs =
    op_obj.inputs().dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::CombineOp>();
  std::vector<Tensor> inputs;
  for (size_t idx = 0; idx < combine_op_obj_inputs.inputs().size(); idx++) {
      inputs.emplace_back(
          std::make_shared<primitive::LazyTensor>(combine_op_obj_inputs.inputs()[idx]));
  }

  VLOG(4) << "Decomp prepare attributes of add_n";


  VLOG(4) << "Decomp prepare call add_n's decomp interface";

  auto org_res = op->results();
  std::vector<std::vector<pir::OpResult>> res(org_res.size());

  Tensor op_res = paddle::primitive::details::add_n_decomp<primitive::LazyTensor>(inputs);
  res[0].push_back(
    std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
        ->value()
        .dyn_cast<pir::OpResult>());
  return res;

}

@paddle-bot paddle-bot bot added the contributor External developers label Nov 1, 2023
@cyber-pioneer cyber-pioneer changed the title fix intarray codegen [Prim][PIR] PIR Prim support intarray, scalar, combineop Nov 2, 2023
@@ -351,6 +351,17 @@ void BindOperation(py::module *m) {
}
return op_list;
})
.def("get_output_intermediate_value",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.def("get_output_intermediate_value",
.def("get_output_value_intermediate_status",

换成这个名字更容易理解一点?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

或许叫get_output_intermediate_status?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

def _analyse_decomp_results(orig_outs, decomp_outs):
assert len(orig_outs) == len(decomp_outs)
def _analyse_decomp_results(orig_outs, decomp_outs, op):
intermediate_values = op.get_output_intermediate_value()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

res = []
for org_item, new_item in zip(orig_outs, decomp_outs):
for org_item, new_item, value in zip(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这儿value 改成output_intermediate_status吧 他就是要给状态的标识

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, thks

@cyber-pioneer cyber-pioneer merged commit 5cb08af into PaddlePaddle:develop Nov 3, 2023
28 checks passed
@paddle-bot paddle-bot bot removed the contributor External developers label Nov 3, 2023
zeroRains pushed a commit to zeroRains/Paddle that referenced this pull request Nov 8, 2023
…e#58581)

* fix intarray codegen

* fix code

* remove unused code

* reset code

* support scalar

* fix code

* support combineop case
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
…e#58581)

* fix intarray codegen

* fix code

* remove unused code

* reset code

* support scalar

* fix code

* support combineop case
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants