From 01135059cb24e2946f7168339b304cd1fcacfdd9 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 4 Jan 2024 12:09:13 +0000 Subject: [PATCH] [PIR] Add unittest for Operation::Clone and Group::Clone --- paddle/cinn/hlir/framework/pir/group.h | 12 +-- paddle/pir/core/operation.cc | 4 +- test/cpp/pir/cinn/group_op_test.cc | 110 +++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 8 deletions(-) diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 2cd3b9b9deddaa..535260c2ee96b9 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -68,22 +68,22 @@ struct Group { // Mapper from original to new ops. std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper; ::pir::CloneOptions clone_options(false, true); - for (auto* op : this->ops_set) { + for (auto* op : ops) { + VLOG(4) << "clone op :" << op->name(); auto* new_op = op->Clone(ir_mapping, clone_options); - // NOTE(dev): Must call MoveTo to deal with ownership, otherwise it + // NOTE(dev): Must call block.insert to deal with ownership, otherwise it // will lead memory-leak. - new_op->MoveTo(target_block, target_block->end()); + target_block->insert(target_block->end(), new_op); new_ops.push_back(new_op); ops_mapper[op] = new_op; } // Construct Base information for new Group auto new_group = std::make_shared(new_ops); - this->CollectOps(); for (auto& iter : this->input_ops) { - new_group->input_ops[ops_mapper[iter.first]] = iter.second; + new_group->input_ops[ops_mapper.at(iter.first)] = iter.second; } for (auto* op : this->output_ops) { - new_group->output_ops.insert(ops_mapper[op]); + new_group->output_ops.insert(ops_mapper.at(op)); } return new_group; diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 0a8e26d788ca15..4d14213dd9f910 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -138,8 +138,8 @@ Operation *Operation::Create(const std::vector &inputs, } Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) { - IR_ENFORCE(options.IsCloneRegions() || num_regions_ > 0, - "Operation CloneOperands is unimplemented currently."); + IR_ENFORCE(!options.IsCloneRegions() || num_regions_ <= 0, + "Operation CloneRegions is unimplemented currently."); IR_ENFORCE(num_successors_ == 0, "Operation::Clone is not unimplemented for multiple successors."); diff --git a/test/cpp/pir/cinn/group_op_test.cc b/test/cpp/pir/cinn/group_op_test.cc index 7fcc8ae6bef317..9f9643c85f4a85 100644 --- a/test/cpp/pir/cinn/group_op_test.cc +++ b/test/cpp/pir/cinn/group_op_test.cc @@ -20,6 +20,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h" +#include "paddle/cinn/hlir/framework/pir/group.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -243,3 +244,112 @@ TEST(GroupOp, CINNLowering) { EXPECT_EQ(res2, true); EXPECT_EQ(res3, true); } + +class GroupOpPattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + using Group = cinn::hlir::framework::pir::Group; + + bool MatchAndRewrite(cinn::dialect::GroupOp group_op, + pir::PatternRewriter& rewriter) const override { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + auto* program = group_op->GetParentProgram(); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + VLOG(4) << "Before GroupOpPattern: " << *program; + std::vector<::pir::Operation*> group_ops = group_op.ops(); + auto yeild_op = group_ops.back(); + std::vector<::pir::Type> output_type{yeild_op->operand_source(0).type()}; + + // construct hlir::Group + Group group({group_ops.begin(), group_ops.end() - 1}); + group.input_ops[group_ops[0]] = 0; // first tan + auto last_op_idx = group_ops.size() - 2; + group.output_ops.insert(group_ops[last_op_idx]); // last relu + + // clone group and sync their op into new GroupOp + builder.SetInsertionPointAfter(group_op.operation()); + auto new_group_op = builder.Build(output_type); + + // prepare IrMapping + ::pir::IrMapping ir_mapping; + auto depend_value = group_ops[0]->operand_source(0); + ir_mapping.Add(depend_value, depend_value); + std::shared_ptr new_group = + group.Clone(new_group_op.block(), ir_mapping); + + EXPECT_EQ(new_group->ops.size(), group.ops.size()); + EXPECT_EQ(new_group->input_ops.size(), group.input_ops.size()); + EXPECT_EQ(new_group->output_ops.size(), group.output_ops.size()); + + // Add yield op + builder.SetInsertionPointToBlockEnd(new_group_op.block()); + std::vector<::pir::Value> yield_inputs{ + new_group_op.ops().back()->result(0)}; + builder.Build<::pir::YieldOp>(yield_inputs); + EXPECT_EQ(new_group_op.ops().size(), group_ops.size()); + + // replace result UD between GroupOp + rewriter.ReplaceAllUsesWith(group_op->result(0), new_group_op->result(0)); + rewriter.EraseOp(group_op); + VLOG(4) << "After GroupOpPattern.EraseOp: " << *program; + return true; + } +}; + +class TestGroupClonePass : public pir::PatternRewritePass { + public: + TestGroupClonePass() : pir::PatternRewritePass("test_group_clone", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + + return ps; + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->isa() && op->num_regions() > 0; + } +}; + +std::shared_ptr<::pir::Program> BuildSingleGroupProgram() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + const std::vector shape = {64, 128}; + // full op + auto full_x = builder.Build( + shape, 0.5, phi::DataType::FLOAT32, phi::GPUPlace()); + + // group op + auto group_op = builder.Build( + CreateDenseTensorTypes(common::make_ddim(shape))); + pir::Block* block = group_op.block(); + builder.SetInsertionPointToBlockEnd(block); + + auto tan_op_x = builder.Build(full_x->result(0)); + auto relu_op_x = builder.Build(tan_op_x->result(0)); + auto tan_op_y = builder.Build(relu_op_x->result(0)); + auto relu_op_y = builder.Build(tan_op_y->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()}); + + // tan op + builder.SetInsertionPointToBlockEnd(program->block()); + auto final_op = builder.Build(group_op->result(0)); + + return program; +} + +TEST(Group, Clone) { + // Step 1: Construct pir::Program + std::shared_ptr<::pir::Program> program = BuildSingleGroupProgram(); + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ::pir::PassManager pm(ctx); + // Step 2: Run TestGroupClonePass + pm.AddPass(std::make_unique()); + pm.Run(program.get()); +}