From dcdfdb65ef7b5c01191e9a45923e15c7af303768 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Sun, 12 Nov 2023 14:07:04 +0800 Subject: [PATCH] [PIR]Pir cinn support elementwise pow and reshape (#58908) * pir cinn support scale op * revert code * fix bug * pir_cinn_support_elementwise_pow_and_reshape * update * rever some code * update * update --- paddle/cinn/hlir/dialect/operator/ir/ops.yaml | 8 +++ .../add_broadcast_to_elementwise_pass.cc | 2 + .../transforms/op_with_group_merge_pass.cc | 5 +- .../operator/transforms/pd_to_cinn_pass.cc | 44 ++++++++++++ paddle/cinn/hlir/framework/pir/utils.cc | 14 +++- .../op_generator/op_creator_drr_gen.py | 5 +- test/cpp/pir/cinn/pir_all_path_test.cc | 68 +++++++++++++++++++ 7 files changed, 141 insertions(+), 5 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index d5923179cf49e1..38fe8674f88fae 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -24,6 +24,14 @@ kernel : func : frobenius_norm +- op : reshape + args : (Tensor x, int[] shape) + output : Tensor(out) + infer_meta : + func : ReshapeInferMeta + kernel : + func : reshape + - op : scale args : (Tensor x, float scale=1.0, float bias=0.0, bool bias_after_scale=true) output : Tensor(out) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc index 8cbd9916715099..30dfe442d2bdcf 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -151,6 +151,8 @@ bool AddBroadcastToElementwisePass::Initialize(pir::IrContext* context) { ps.Add>(context); ps.Add>(context); ps.Add>(context); + ps.Add>( + context); patterns_ = 
::pir::FrozenRewritePatternSet(std::move(ps)); return true; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc index 62fa82ddb9b568..7cb6519b974acd 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc @@ -42,9 +42,12 @@ std::unordered_map OpKindMap = { {"pd_op.exp", OpPatternKind::kElementWise}, {"pd_op.sin", OpPatternKind::kElementWise}, {"pd_op.cos", OpPatternKind::kElementWise}, + {"pd_op.pow", OpPatternKind::kElementWise}, + {"pd_op.elementwise_pow", OpPatternKind::kElementWise}, + {"pd_op.sum", OpPatternKind::kReduction}, + {"cinn_op.reshape", OpPatternKind::kElementWise}, {"pd_op.cast", OpPatternKind::kElementWise}, {"pd_op.greater_than", OpPatternKind::kElementWise}, - {"pd_op.sum", OpPatternKind::kReduction}, {"cinn_op.scale", OpPatternKind::kElementWise}, {"cinn_op.reduce_sum", OpPatternKind::kReduction}, {"cinn_op.reduce_max", OpPatternKind::kReduction}, diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index f558c007d53f17..6a7362c5520b51 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -130,6 +130,49 @@ class ScaleOpPattern : public pir::OpRewritePattern { } }; +class ReshapeOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::ReshapeOp op, + pir::PatternRewriter &rewriter) const override { + auto scale_factor_gen_op = + op->operand_source(1).dyn_cast().owner(); + + if (auto full_op = + scale_factor_gen_op->dyn_cast()) { + // scale is generated by full op + // get attribute value from full op + + auto out_shape_attr = 
full_op.attribute("value").dyn_cast().AsVector(); + + std::vector vec_out_shape; + if (out_shape_attr.size() > 0) { + PADDLE_ENFORCE_EQ( + out_shape_attr[0].isa<::pir::Int64Attribute>(), + true, + phi::errors::Unimplemented( + "the 0th elementwise MUST be ir::Int64Attribute")); + for (size_t i = 0; i < out_shape_attr.size(); ++i) { + vec_out_shape.push_back( + out_shape_attr[i].dyn_cast<::pir::Int64Attribute>().data()); + } + } + + auto cinn_reshape = rewriter.Build( + op->operand_source(0).dyn_cast(), vec_out_shape); + rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0)); + rewriter.EraseOp(op); + rewriter.EraseOp(full_op); + + return true; + } + return false; + } +}; + class UniformOpPattern : public pir::drr::DrrPatternBase { public: void operator()(pir::drr::DrrPatternContext *ctx) const override { @@ -186,6 +229,7 @@ bool PdOpToCinnOpPass::Initialize(pir::IrContext *context) { context); // NOTE, scale op pattern should before AddBroadcastTo ps.Add(SumOpPattern().Build(context)); ps.Add(MaxOpPattern().Build(context)); + ps.Add(context); // ps.Add(UniformOpPattern().Build(context)); patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps)); diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index 80815204a7429b..0e3c2df6f46e40 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -38,13 +38,21 @@ const std::unordered_map CompatibleInfo::OP_NAMES = { {"pd_op.add", "elementwise_add"}, {"pd_op.subtract", "subtract"}, {"pd_op.divide", "divide"}, + {"pd_op.elementwise_pow", "pow"}, {"pd_op.multiply", "elementwise_mul"}, - {"cinn_op.broadcast", "broadcast_to"}, - {"cinn_op.scale", "scale"}}; + {"cinn_op.reshape", "reshape"}, + {"cinn_op.scale", "scale"}, + {"cinn_op.broadcast", "broadcast_to"}}; // Tagging PaddleDialect Op with REGITER_OP_MAPPER(OP) const std::unordered_set CompatibleInfo::CINN_WHITE_OPS = { - "subtract", "divide", "broadcast_to", "multiply", 
"scale"}; + "subtract", + "divide", + "broadcast_to", + "multiply", + "scale", + "elementwise_pow", + "reshape"}; bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) { return CINN_WHITE_OPS.find(CompatibleInfo::OpName(op)) != diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py index a6661a80a9a296..d4bc1d70f6cbb9 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -123,7 +123,10 @@ def gen_cpp_file_code(self, cpp_file_path): if len(op_info_item.attribute_name_list) > 0: params_no_mutable_attr.append("attrs") - if len(op_info_item.mutable_attribute_name_list) == 0: + if ( + self.dialect_name != "pd_op" + or len(op_info_item.mutable_attribute_name_list) == 0 + ): body_code += NORMAL_FUNCTION_TEMPLATE.format( op_name=ir_op_name, namespace=Dialect2NameSpaceMap[self.dialect_name], diff --git a/test/cpp/pir/cinn/pir_all_path_test.cc b/test/cpp/pir/cinn/pir_all_path_test.cc index 47fa42bb25bd54..03f481c956a04e 100644 --- a/test/cpp/pir/cinn/pir_all_path_test.cc +++ b/test/cpp/pir/cinn/pir_all_path_test.cc @@ -427,3 +427,71 @@ TEST(GroupOp, TestBuildScaleTensor) { bool res0 = simple_cmp(out_tensor.data()[0], 0.5); EXPECT_EQ(res0, true); } + +std::shared_ptr<::pir::Program> BuildPowerProgram() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + auto x = builder + .Build(std::vector({16, 16}), + 2.0, + phi::DataType::FLOAT32, + phi::GPUPlace()) + .result(0); + + auto factor = + builder + .Build(std::vector({16, 16}), + 2.0, + phi::DataType::FLOAT32, + phi::GPUPlace()) + .result(0); + + auto power = + builder.Build(x, factor).result(0); + auto out = + builder + .Build(power, std::vector({-1})) + .result(0); + + builder.Build(out, 
"out", 0); + return program; +} + +TEST(GroupOp, TestBuildPower) { + // Step 1: Construct pir::Program + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + std::shared_ptr<::pir::Program> program = BuildPowerProgram(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); + + pir::PassManager pm(ctx); + pm.AddPass( + std::make_unique()); + pm.AddPass(pir::CreateBuildCinnPass()); + CHECK_EQ(pm.Run(program.get()), true); + + auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get()); + + paddle::platform::Place place = paddle::platform::CUDAPlace(0); + + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(res.get(), place); + + paddle::framework::Scope exe_scope; + + paddle::framework::InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + executor.Run({}, true); + + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + bool res0 = simple_cmp(out_tensor.data()[0], 4.0); + EXPECT_EQ(res0, true); +}