diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 3a4ebb63679f3..3bb572250032f 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -15,16 +15,17 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include +#include "glog/logging.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/enforce.h" #include "paddle/pir/core/op_base.h" +#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" namespace cinn { namespace dialect { const char *GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; -// TODO(Aurlius84): Need to figure out how to rebuild relation info of ops outer -// GroupOp void GroupOp::Build(pir::Builder &builder, pir::OperationArgument &argument, const std::vector &output_types) { @@ -32,6 +33,20 @@ void GroupOp::Build(pir::Builder &builder, argument.output_types = output_types; } +void GroupOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + std::unique_ptr &&block) { + VLOG(4) << "Start build GroupOp"; + if (block && !block->empty()) { + IR_ENFORCE(block->back()->isa()); + auto *op = block->back(); + for (size_t i = 0; i < op->num_operands(); ++i) { + argument.AddOutput(op->operand(i).type()); + } + } + argument.AddRegion()->push_back(block.release()); +} + pir::Block *GroupOp::block() { pir::Region ®ion = (*this)->region(0); if (region.empty()) region.emplace_back(); diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 39d433790be78..ba116d52a98c0 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -33,6 +33,10 @@ class GroupOp : public pir::Op { pir::OperationArgument &argument, // NOLINT const std::vector &output_types); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + std::unique_ptr &&block); + pir::Block *block(); std::vector ops(); diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc index 245e26cabad7e..2d8ea8ff0f501 100644 --- a/paddle/fluid/pir/transforms/build_cinn_pass.cc +++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc @@ -551,7 +551,6 @@ void ReplaceWithGroupOp(pir::Block* block, // step 1: Ensure the insert point and create GroupOp here. auto* laste_input_op = group_ops.back(); builder.SetInsertionPointAfter(laste_input_op); - // TODO(Aurelius84): Need confirm how many YieldOps we need. std::vector output_types; std::vector outputs = AnalysisOutputs(group_ops); for (auto& value : outputs) { diff --git a/test/cpp/pir/cinn/group_op_test.cc b/test/cpp/pir/cinn/group_op_test.cc index 6e0f05a8cb244..c252c06a3cccd 100644 --- a/test/cpp/pir/cinn/group_op_test.cc +++ b/test/cpp/pir/cinn/group_op_test.cc @@ -89,3 +89,58 @@ TEST(GroupOp, TestBuild) { ++i; } } + +std::shared_ptr<::pir::Program> BuildGroupProgramByBlock() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + // ------- Group op 1 --------- + const float value_one = 1.0; + const std::vector shape = {64, 128}; + std::unique_ptr<::pir::Block> block1(new ::pir::Block()); + builder.SetInsertionPointToEnd(block1.get()); + auto full_op_x = builder.Build( + shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace()); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{full_op_x.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op1 = builder.Build(std::move(block1)); + + // ------- Group op 2 --------- + std::unique_ptr<::pir::Block> block2(new ::pir::Block()); + builder.SetInsertionPointToEnd(block2.get()); + auto tan_op_x = builder.Build(group_op1->result(0)); + auto relu_op_x = builder.Build(tan_op_x->result(0)); + auto tan_op_y = builder.Build(relu_op_x->result(0)); + auto relu_op_y = builder.Build(tan_op_y->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()}); + + builder.SetInsertionPointToEnd(program->block()); + auto group_op2 = builder.Build(std::move(block2)); + + return program; +} + +TEST(GroupOp, TestBuildByBlock) { + // Step 1: Construct pir::Program + std::shared_ptr<::pir::Program> program = BuildGroupProgramByBlock(); + std::stringstream ss; + program->Print(ss); + LOG(INFO) << ss.str(); + + EXPECT_EQ(program->block()->size(), 2u); + LOG(INFO) << program->block()->size(); + std::vector op_num = {2, 5}; + int i = 0; + for (auto* sub_op : *(program->block())) { + EXPECT_TRUE(sub_op->isa()); + EXPECT_EQ(sub_op->dyn_cast().ops().size(), + op_num[i]); + ++i; + } +}