Skip to content

Commit

Permalink
[PIR+CINN]Support construct block first for GroupOp (PaddlePaddle#58216)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 authored and jiahy0825 committed Oct 26, 2023
1 parent 84ef5bc commit 37cab3a
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 3 deletions.
19 changes: 17 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,38 @@
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"

#include <vector>
#include "glog/logging.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

namespace cinn {
namespace dialect {

const char *GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"};

// TODO(Aurlius84): Need to figure out how to rebuild relation info of ops outer
// GroupOp
void GroupOp::Build(pir::Builder &builder,
pir::OperationArgument &argument,
const std::vector<pir::Type> &output_types) {
argument.AddRegion(nullptr);
argument.output_types = output_types;
}

void GroupOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
std::unique_ptr<pir::Block> &&block) {
VLOG(4) << "Start build GroupOp";
if (block && !block->empty()) {
IR_ENFORCE(block->back()->isa<pir::YieldOp>());
auto *op = block->back();
for (size_t i = 0; i < op->num_operands(); ++i) {
argument.AddOutput(op->operand(i).type());
}
}
argument.AddRegion()->push_back(block.release());
}

pir::Block *GroupOp::block() {
pir::Region &region = (*this)->region(0);
if (region.empty()) region.emplace_back();
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class GroupOp : public pir::Op<GroupOp> {
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Type> &output_types);

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
std::unique_ptr<pir::Block> &&block);

pir::Block *block();
std::vector<pir::Operation *> ops();

Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/pir/transforms/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,6 @@ void ReplaceWithGroupOp(pir::Block* block,
// step 1: Ensure the insert point and create GroupOp here.
auto* laste_input_op = group_ops.back();
builder.SetInsertionPointAfter(laste_input_op);
// TODO(Aurelius84): Need confirm how many YieldOps we need.
std::vector<pir::Type> output_types;
std::vector<pir::Value> outputs = AnalysisOutputs(group_ops);
for (auto& value : outputs) {
Expand Down
55 changes: 55 additions & 0 deletions test/cpp/pir/cinn/group_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,58 @@ TEST(GroupOp, TestBuild) {
++i;
}
}

std::shared_ptr<::pir::Program> BuildGroupProgramByBlock() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>();

auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

// ------- Group op 1 ---------
const float value_one = 1.0;
const std::vector<int64_t> shape = {64, 128};
std::unique_ptr<::pir::Block> block1(new ::pir::Block());
builder.SetInsertionPointToEnd(block1.get());
auto full_op_x = builder.Build<paddle::dialect::FullOp>(
shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace());
builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{full_op_x.out()});

builder.SetInsertionPointToEnd(program->block());
auto group_op1 = builder.Build<cinn::dialect::GroupOp>(std::move(block1));

// ------- Group op 2 ---------
std::unique_ptr<::pir::Block> block2(new ::pir::Block());
builder.SetInsertionPointToEnd(block2.get());
auto tan_op_x = builder.Build<paddle::dialect::TanOp>(group_op1->result(0));
auto relu_op_x = builder.Build<paddle::dialect::ReluOp>(tan_op_x->result(0));
auto tan_op_y = builder.Build<paddle::dialect::TanOp>(relu_op_x->result(0));
auto relu_op_y = builder.Build<paddle::dialect::ReluOp>(tan_op_y->result(0));
builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()});

builder.SetInsertionPointToEnd(program->block());
auto group_op2 = builder.Build<cinn::dialect::GroupOp>(std::move(block2));

return program;
}

TEST(GroupOp, TestBuildByBlock) {
// Step 1: Construct pir::Program
std::shared_ptr<::pir::Program> program = BuildGroupProgramByBlock();
std::stringstream ss;
program->Print(ss);
LOG(INFO) << ss.str();

EXPECT_EQ(program->block()->size(), 2u);
LOG(INFO) << program->block()->size();
std::vector<uint32_t> op_num = {2, 5};
int i = 0;
for (auto* sub_op : *(program->block())) {
EXPECT_TRUE(sub_op->isa<cinn::dialect::GroupOp>());
EXPECT_EQ(sub_op->dyn_cast<cinn::dialect::GroupOp>().ops().size(),
op_num[i]);
++i;
}
}

0 comments on commit 37cab3a

Please sign in to comment.