diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc index 91adf4c14fba4..31a32d7494b0f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc @@ -140,10 +140,13 @@ std::unique_ptr CINNGroupLoweringPass(::pir::Program* program) { auto group_list = cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion); - PADDLE_ENFORCE_EQ(group_list.size(), - 1u, - phi::errors::Unimplemented( - "Only support one group after group fusion")); + // using yield op to sort + std::unordered_map<::pir::Value, size_t> value2id; + auto yeild_op = group_op.ops().back(); + for (size_t i = 0; i < yeild_op->num_operands(); ++i) { + value2id[yeild_op->operand_source(i)] = i; + } + for (auto group : group_list) { auto ir_compiler = std::make_shared( *program, target, scope); @@ -162,26 +165,23 @@ std::unique_ptr CINNGroupLoweringPass(::pir::Program* program) { vec_new_ins.push_back(value_map.at(vec_ins[i])); } - // using yield op to sort - std::unordered_map<::pir::Value, size_t> value2id; - auto yeild_op = group_op.ops().back(); - for (size_t i = 0; i < yeild_op->num_operands(); ++i) { - value2id[yeild_op->operand_source(i)] = i; - } - std::unordered_map codegen2orig; std::vector vec_types; for (size_t i = 0; i < group->output_values.size(); ++i) { vec_types.push_back(group->output_values[i].type()); - codegen2orig[value2id.at(group->output_values[i])] = i; } ::pir::Operation* cinn_op = ::pir::Operation::Create(vec_new_ins, op_attrs, vec_types, op_info); - for (size_t i = 0; i < group_op.num_results(); ++i) { - value_map[group_op.result(i)] = cinn_op->result(codegen2orig.at(i)); + for (size_t i = 0; i < cinn_op->num_results(); ++i) { + auto find_it = value2id.find(group->output_values[i]); + if (find_it == value2id.end()) { + value_map[group->output_values[i]] = cinn_op->result(i); + } else { + value_map[group_op.result(find_it->second)] = cinn_op->result(i); + } } ir_program->block()->push_back(cinn_op); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc index fcf7f2242f09b..f64045046e71d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc @@ -17,6 +17,7 @@ #include #include "paddle/cinn/hlir/dialect/operator/transforms/op_group.h" +#include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/value.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h" diff --git a/test/cpp/pir/cinn/pir_all_path_test.cc b/test/cpp/pir/cinn/pir_all_path_test.cc index 03f481c956a04..6814990428461 100644 --- a/test/cpp/pir/cinn/pir_all_path_test.cc +++ b/test/cpp/pir/cinn/pir_all_path_test.cc @@ -220,7 +220,7 @@ TEST(GroupOp, TestBuildLayerNorm) { // executor.Run({}, true); // auto out_tensor = - // executor.local_scope()->FindVar("out@fetch")->Get(); + // executor.local_scope()->FindVar("out@fetch")->Get(); } std::shared_ptr<::pir::Program> BuildDropOutProgram() { @@ -495,3 +495,75 @@ TEST(GroupOp, TestBuildPower) { bool res0 = simple_cmp(out_tensor.data()[0], 4.0); EXPECT_EQ(res0, true); } + +std::shared_ptr<::pir::Program> BuildSum2GroupProgram() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + + auto x = builder + .Build(std::vector({16, 16}), + 0.0, + phi::DataType::FLOAT32, + phi::GPUPlace()) + .result(0); + + auto cos = builder.Build(x).result(0); + + auto y = builder + .Build(std::vector({8, 8}), + 0.0, + phi::DataType::FLOAT32, + phi::GPUPlace()) + .result(0); + + auto sin = builder.Build(y).result(0); + + builder.Build(cos, "out", 0); + builder.Build(sin, "out2", 0); + return program; +} + +TEST(GroupOp, TestBuildSum2Group) { + // Step 1: Construct pir::Program + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + std::shared_ptr<::pir::Program> program = BuildSum2GroupProgram(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); + + pir::PassManager pm(ctx); + pm.AddPass( + std::make_unique()); + pm.AddPass(pir::CreateBuildCinnPass()); + CHECK_EQ(pm.Run(program.get()), true); + + auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get()); + + paddle::platform::Place place = paddle::platform::CUDAPlace(0); + + res->Print(std::cout); + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(res.get(), place); + + paddle::framework::Scope exe_scope; + + paddle::framework::InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + executor.Run({}, true); + + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + auto out_tensor2 = + executor.local_scope()->FindVar("out2@fetch")->Get(); + + bool res0 = simple_cmp(out_tensor.data()[0], 1.0); + EXPECT_EQ(res0, true); + + bool res1 = (out_tensor2.data()[0] == 0.0); + EXPECT_EQ(res1, true); +}