diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index d6f131cb9d06c4..cbb7e878e04a0b 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -77,6 +77,9 @@ const std::unordered_set SpecialLowerOps = { pir::YieldOp::name(), IfOp::name(), WhileOp::name(), + pir::CreateStackOp::name(), + pir::PushBackOp::name(), + pir::PopBackOp::name(), "cinn_runtime.jit_kernel"}; static bool NeedFallBackCpu(const pir::Operation* op, @@ -1017,6 +1020,45 @@ void HandleForSpecialOp( } } + if (op_item->isa<::pir::CreateStackOp>() || + op_item->isa<::pir::PushBackOp>()) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + vec_inputs.push_back(new_in); + } + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_output_types.push_back(op_item->result(i).type()); + } + } + + if (op_item->isa<::pir::PopBackOp>()) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + vec_inputs.push_back(new_in); + } + + auto pop_back_op = op_item->dyn_cast<::pir::PopBackOp>(); + for (size_t i = 0; i < op_item->num_results(); ++i) { + auto cur_inlet_element = pop_back_op.inlet_element(i); + PADDLE_ENFORCE_EQ(map_value_pair->count(cur_inlet_element), + true, + phi::errors::PreconditionNotMet( + "[%d]'s output of [%s] op MUST be in map pair", + i, + op_item->name())); + auto new_inlet_element = map_value_pair->at(cur_inlet_element); + + op_output_types.push_back(new_inlet_element.type()); + } + } if (op_item->name() == "cinn_runtime.jit_kernel") { if (op_item->num_operands() > 0) { for (size_t i = 0; i < op_item->num_operands(); ++i) { @@ -1049,7 +1091,7 @@ void HandleForSpecialOp( (*map_value_pair)[op_item->result(i)] = op->result(i); } } - VLOG(6) << "Deep copy a new builtin op: " << op_item->name(); + VLOG(6) << "Deep copy a new special op: " << op_item->name(); } std::vector BuildOutputs(pir::Operation* op_item, diff --git a/paddle/pir/dialect/control_flow/ir/cf_op.cc b/paddle/pir/dialect/control_flow/ir/cf_op.cc index ee34872f54799b..6828701d5961d9 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_op.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_op.cc @@ -169,8 +169,10 @@ void PopBackOp::VerifySig() { "The pop elements size must equal to push elements size."); for (size_t index = 0; index < inlet_size; ++index) { IR_ENFORCE(outlet_element(index).type() == inlet_element(index).type(), - "The %d element's push type isn't equal to pop type", - index); + "The %d element's push type (%s) isn't equal to pop type (%s)", + index, + outlet_element(index).type(), + inlet_element(index).type()); } VLOG(4) << "End Verifying for PopBackOp."; } diff --git a/test/cpp/pir/control_flow_dialect/CMakeLists.txt b/test/cpp/pir/control_flow_dialect/CMakeLists.txt index d295fa37b6b730..9309a8a774b297 100644 --- a/test/cpp/pir/control_flow_dialect/CMakeLists.txt +++ b/test/cpp/pir/control_flow_dialect/CMakeLists.txt @@ -5,6 +5,8 @@ cc_test_old( DEPS pir op_dialect_vjp + pir_transforms + op_dialect gtest) cc_test_old( @@ -14,4 +16,6 @@ cc_test_old( DEPS pir op_dialect_vjp + pir_transforms + op_dialect gtest) diff --git a/test/cpp/pir/control_flow_dialect/if_op_test.cc b/test/cpp/pir/control_flow_dialect/if_op_test.cc index b1b6b3abe005dd..64ccbcdf270224 100644 --- a/test/cpp/pir/control_flow_dialect/if_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/if_op_test.cc @@ -14,15 +14,25 @@ #include #include +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/program.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" +PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul_grad, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(matmul, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT); + using namespace paddle::dialect; // NOLINT TEST(if_op_test, base) { @@ -99,6 +109,7 @@ TEST(if_op_test, network_with_backward) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); pir::Program program(ctx); pir::Block* block = program.block(); @@ -169,4 +180,6 @@ TEST(if_op_test, network_with_backward) { builder.SetInsertionPointToEnd(block); LOG(INFO) << program; + + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); }