From ae188d1885ff8c733fde5578fb09ab7c9ffab725 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Fri, 15 Sep 2023 16:37:16 +0800 Subject: [PATCH] [PIR] New ir supprt cond part 1 (#57265) * [IR] add control flow dialect * fix error * add error * test * fix * fix * fix * fix * fix * fix * update * fix * update * update * refactor pd op to kernel pass * fix bug * update * add code * update * fix merge bug * fix bug * update * update * update * fix bug * update * update * fix bug * update * update * fix bug * remove useless code * update * update * update * new ir support cond part 1 * fix compile bug * fix compile bug * fix compile bug * fix bug * open test case --------- Co-authored-by: winter-wang <1030748926@qq.com> --- paddle/fluid/framework/executor_cache.cc | 3 +- paddle/fluid/framework/executor_cache.h | 1 + .../new_executor/interpreter/plan.cc | 5 + .../framework/new_executor/interpreter/plan.h | 5 +- .../framework/new_executor/interpretercore.cc | 4 +- .../framework/new_executor/interpretercore.h | 4 +- .../new_executor/new_ir_interpreter.cc | 21 ++- .../new_executor/new_ir_interpreter.h | 9 +- .../new_executor/standalone_executor.cc | 6 +- .../new_executor/standalone_executor.h | 2 +- .../phi_kernel_adaptor/phi_kernel_adaptor.h | 4 +- .../pir/phi_kernel_adaptor/phi_kernel_util.cc | 59 +++++-- .../pir/phi_kernel_adaptor/phi_kernel_util.h | 3 +- .../pir/transforms/constant_folding_pass.cc | 13 +- .../pir/transforms/pd_op_to_kernel_pass.cc | 162 +++++++++++++----- .../pir/transforms/pd_op_to_kernel_pass.h | 7 + .../standalone_executor_new_ir_test.cc | 4 +- test/cpp/pir/cinn/jit_instruction_test.cc | 3 +- .../ir_kernel_dialect_pass_test.cc | 38 ++++ test/cpp/prim/test_vjp.cc | 14 +- 20 files changed, 271 insertions(+), 96 deletions(-) diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index c03c8542b49c6..3f750d961bfbb 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -344,11 +344,12 @@ std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( std::shared_ptr core = nullptr; core.reset(new InterpreterCore( - place, {}, std::move(ir_program), scope, execution_config)); + place, {}, ir_program->block(), scope, execution_config)); auto &cached_value = interpretercore_info_cache.GetMutable(program_id, scope, is_grad); cached_value.core_ = core; + cached_value.ir_prog_ = std::move(ir_program); return core; } diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index 1c5602a31f872..9fadcab929fd4 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -168,6 +168,7 @@ class InterpreterCoreInfo { struct CacheValue { std::shared_ptr core_{nullptr}; std::set skip_eager_delete_vars_; + std::unique_ptr<::pir::Program> ir_prog_{nullptr}; }; bool IsAvailable(bool is_grad) { diff --git a/paddle/fluid/framework/new_executor/interpreter/plan.cc b/paddle/fluid/framework/new_executor/interpreter/plan.cc index ce2f8b2718ff3..ab05c6216426a 100644 --- a/paddle/fluid/framework/new_executor/interpreter/plan.cc +++ b/paddle/fluid/framework/new_executor/interpreter/plan.cc @@ -74,6 +74,11 @@ std::shared_ptr<::pir::Program> Plan::IrProgram( return type_to_ir_program_.at(job_type); } +void Plan::UpdateIrProgram(const std::string& job_type, + std::shared_ptr<::pir::Program> ir_prog) { + type_to_ir_program_[job_type] = ir_prog; +} + int64_t Plan::MicroBatchNum() const { return micro_batch_num_; } } // namespace interpreter diff --git a/paddle/fluid/framework/new_executor/interpreter/plan.h b/paddle/fluid/framework/new_executor/interpreter/plan.h index 8ce66db821305..389eb5c9df84e 100644 --- a/paddle/fluid/framework/new_executor/interpreter/plan.h +++ b/paddle/fluid/framework/new_executor/interpreter/plan.h @@ -43,12 +43,15 @@ class Plan final { const ProgramDesc* Program(const std::string& job_type) const; std::shared_ptr<::pir::Program> IrProgram(const std::string& job_type) const; + void UpdateIrProgram(const std::string& job_type, + std::shared_ptr<::pir::Program> ir_prog); + int64_t MicroBatchNum() const; private: const std::vector> job_list_; const std::unordered_map type_to_program_; - const std::unordered_map> + std::unordered_map> type_to_ir_program_; int64_t micro_batch_num_; }; diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index a2c3c49e1c634..dc8110331a176 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -50,12 +50,12 @@ InterpreterCore::InterpreterCore(const platform::Place& place, InterpreterCore::InterpreterCore( const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::pir::Program> ir_prog, + const ::pir::Block* ir_block, framework::Scope* scope, const ExecutionConfig& execution_config) { VLOG(4) << "InterpreterCore(): " << this << " on " << place; impl_ = std::make_unique( - place, fetch_var_names, std::move(ir_prog), scope, execution_config); + place, fetch_var_names, ir_block, scope, execution_config); } InterpreterCore::~InterpreterCore() { diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 52df30cbfd976..47f2d9c6a3378 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -18,7 +18,7 @@ PD_DECLARE_bool(new_executor_use_local_scope); namespace pir { -class Program; +class Block; } // namespace pir namespace paddle { @@ -38,7 +38,7 @@ class InterpreterCore { // This constructor is for New IR. InterpreterCore(const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::pir::Program> ir_prog, + const ::pir::Block* ir_prog, Scope* scope, const ExecutionConfig& execution_config = ExecutionConfig()); ~InterpreterCore(); diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index 193dc4bdf1a52..aea21f5dc17aa 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -58,14 +58,14 @@ namespace framework { NewIRInterpreter::NewIRInterpreter( const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::pir::Program> ir_prog, + const ::pir::Block* ir_block, framework::Scope* scope, const ExecutionConfig& execution_config) : place_(place), execution_config_(execution_config), var_scope_(scope), scope_(scope), - ir_program_(std::move(ir_prog)), + ir_block_(ir_block), ir_stream_analyzer_(place), fetch_var_names_(fetch_var_names) { VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_; @@ -92,8 +92,7 @@ NewIRInterpreter::NewIRInterpreter( // TODO(zhangbo): delete var_scope var_scope_.SetLocalScope(local_scope_); - execution_config_.AnalyzeThreadPoolConfig(place, - ir_program_->block()->size()); + execution_config_.AnalyzeThreadPoolConfig(place, 1); execution_config_.Log(/*log_level=*/8); ir_instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) { @@ -334,6 +333,10 @@ Scope* NewIRInterpreter::InnerScope() { return local_scope_ != nullptr ? local_scope_ : scope_; } +std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const { + return value_2_var_name_.at(value); +} + void NewIRInterpreter::UpdateSyncOpNum() { int64_t sync_op_num = 0; for (auto& ins : vec_instruction_base_) { @@ -501,13 +504,15 @@ void NewIRInterpreter::BuildInstruction() { VLOG(6) << "Build Instructions for new ir ... "; vec_instruction_base_.clear(); size_t op_idx = 0; - for (auto& op : *ir_program_->block()) { + for (auto& op : *ir_block_) { VLOG(6) << "Build Instruction for op: " << op_idx; if (op->dialect()->name() == "builtin") { if (interpreter::GetSpecialOpNames().count(op->name())) { VLOG(6) << "skip process " << op->name(); continue; } + } else if (op->dialect()->name() == "cf") { + continue; } else if (op->dialect()->name() == "pd_kernel") { auto op_name = op->attributes() .at("op_name") @@ -923,13 +928,14 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, // Build std::stringstream ss; ss << this; - ::pir::BuildScope(*ir_program_->block(), + ::pir::BuildScope(*ir_block_, InnerScope(), ss.str(), &value_2_var_name_, &variable_2_var_name_, &var_name_2_id_, - &variable_list_); + &variable_list_, + &sub_blocks_); interpreter::BuildId2VarName(var_name_2_id_, &id_2_var_name_); @@ -977,7 +983,6 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, if (HasLocalScope()) { ClearLoDTensorArrayInLocalScope(); } - // return Fetch Tensors Scope* inner_scope = InnerScope(); if (FLAGS_enable_new_ir_in_executor) { diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h index c0681a277d5f7..cf5cb21ce81aa 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h @@ -19,7 +19,7 @@ #include "paddle/pir/core/value.h" namespace ir { -class Program; +class Block; } // namespace ir namespace paddle { @@ -36,7 +36,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { public: NewIRInterpreter(const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::pir::Program> ir_prog, + const ::pir::Block* ir_block, Scope* scope, const ExecutionConfig& execution_config = ExecutionConfig()); @@ -81,6 +81,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { int GetIdByName(const std::string& name) const; + std::string GetNameByValue(::pir::Value value) const; + private: // build graph void UpdateSyncOpNum(); @@ -198,7 +200,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less; - std::unique_ptr<::pir::Program> ir_program_{nullptr}; + const ::pir::Block* ir_block_{nullptr}; std::vector> vec_instruction_base_; @@ -211,6 +213,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { std::unordered_map id_2_var_name_; std::vector variable_list_; + std::map sub_blocks_; std::vector var_ref_count_; diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index a2ae422b814a3..44dc782e64b27 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -107,17 +107,19 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, } auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place); + std::shared_ptr shared_program = std::move(kernel_program); + plan_.UpdateIrProgram("base", shared_program); if (FLAGS_new_ir_apply_inplace_pass) { pir::PassManager pm(pir::IrContext::Instance(), 3); pm.AddPass(pir::CreateInplacePass()); - pm.Run(kernel_program.get()); + pm.Run(shared_program.get()); } interpretercores_.emplace_back( std::make_shared(place_, fetch_var_names_, - std::move(kernel_program), + shared_program->block(), scope_, execution_config)); } else { diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index e9ee5509d20be..cb10648855181 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -44,7 +44,7 @@ class StandaloneExecutor { private: bool is_interpretercore_build_result_shared_{false}; const platform::Place place_; - const interpreter::Plan plan_; + interpreter::Plan plan_; std::vector micro_batch_scopes_; std::vector> interpretercores_; diff --git a/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h index 47c0d39856d2f..c8a72797318c7 100644 --- a/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h +++ b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h @@ -60,6 +60,7 @@ class PhiKernelAdaptor { variable_2_var_name; std::map var_name_2_id; std::vector variable_list; + std::map sub_blocks; std::stringstream ss; ss << this; @@ -69,7 +70,8 @@ class PhiKernelAdaptor { &value_2_var_name, &variable_2_var_name, &var_name_2_id, - &variable_list); + &variable_list, + &sub_blocks); pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); diff --git a/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc index 609037c3ecd0f..a6cd7b553492e 100644 --- a/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc @@ -43,9 +43,21 @@ #include "glog/logging.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" namespace pir { +const std::unordered_set SpecialOps = {"pd_op.feed", + "pd_op.fetch", + "builtin.combine", + "builtin.set_parameter", + "builtin.get_parameter", + "builtin.slice", + "builtin.split", + "pd_op.data", + "pd_op.shadow_output", + "pd_op.if"}; + void AddNewData(pir::Value value, std::string name, paddle::framework::Variable* var, @@ -226,7 +238,8 @@ void HandleForSpecialOp( std::unordered_map* variable_2_var_name, std::map* var_name_2_id, - std::vector* variable_list) { + std::vector* variable_list, + std::map* sub_blocks) { std::string op_name = op->name(); if (op->attributes().count("op_name")) { op_name = @@ -425,6 +438,33 @@ void HandleForSpecialOp( value_2_var_name->emplace(out_value, var_name); } } + + if (op_name == "pd_op.if") { + auto if_op = op->dyn_cast(); + + auto true_block = if_op.true_block(); + + auto false_block = if_op.false_block(); + + auto& true_branch_scope = inner_scope->NewScope(); + sub_blocks->emplace(true_block, &true_branch_scope); + + auto& false_branch_scope = inner_scope->NewScope(); + sub_blocks->emplace(false_block, &false_branch_scope); + + for (size_t i = 0; i < if_op->num_results(); ++i) { + // auto true_value = true_yeid_op->operand_source(i); + + auto if_op_out_value = if_op->result(i); + BuildValue(if_op_out_value, + inner_scope, + var_name_prefix, + value_2_var_name, + variable_2_var_name, + var_name_2_id, + variable_list); + } + } } void HandleForInplaceOp( @@ -483,8 +523,8 @@ void HandleForInplaceOp( } } -// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other is -// created in inner_scope. +// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other +// is created in inner_scope. void BuildScope(const pir::Block& block, paddle::framework::Scope* inner_scope, const std::string& var_name_prefix, @@ -492,7 +532,8 @@ void BuildScope(const pir::Block& block, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, - std::vector* variable_list) { + std::vector* variable_list, + std::map* sub_blocks) { VLOG(4) << "***** [before build] scope" << "(" << inner_scope << ") ******\n" << paddle::framework::GenScopeTreeDebugInfo( @@ -507,19 +548,15 @@ void BuildScope(const pir::Block& block, .AsString(); } VLOG(4) << "build op:" << op_name; - - if (op_name == "pd_op.feed" || op_name == "pd_op.fetch" || - op_name == "builtin.combine" || op_name == "builtin.set_parameter" || - op_name == "builtin.get_parameter" || op_name == "builtin.slice" || - op_name == "builtin.split" || op_name == "pd_op.data" || - op_name == "pd_op.shadow_output") { + if (SpecialOps.count(op_name)) { HandleForSpecialOp(op, inner_scope, var_name_prefix, value_2_var_name, variable_2_var_name, var_name_2_id, - variable_list); + variable_list, + sub_blocks); continue; } diff --git a/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h index 037674467bc67..c3a409b8027b6 100644 --- a/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h @@ -50,7 +50,8 @@ void BuildScope(const pir::Block& block, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, - std::vector* variable_list); + std::vector* variable_list, + std::map* sub_blocks); void BuildRuntimeContext( pir::Operation* op, diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index d3f78787841f0..fe393344eb967 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -96,12 +96,13 @@ class ConstantFoldingPattern : public pir::RewritePattern { // Execute program exe_config_.create_local_scope = false; - paddle::framework::InterpreterCore core( - phi::CPUPlace{}, - fetch_var_names, - paddle::dialect::PdOpLowerToKernelPass(temp_program.get()), - &scope_, - exe_config_); + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(temp_program.get()); + paddle::framework::InterpreterCore core(phi::CPUPlace{}, + fetch_var_names, + kernel_program->block(), + &scope_, + exe_config_); paddle::framework::FetchList fetch_list = core.Run({}); diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 6c3460910e241..31d8be857204c 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" @@ -33,6 +34,7 @@ #include "paddle/phi/common/place.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/kernel_factory.h" + namespace paddle { namespace dialect { @@ -61,11 +63,11 @@ const std::unordered_set UnchangeOutputOps = { "builtin.get_parameter", "pd_op.shadow_output"}; -const std::unordered_set SpecialLowerOps = { - "builtin.combine", - "builtin.slice", - "builtin.split", -}; +const std::unordered_set SpecialLowerOps = {"builtin.combine", + "builtin.slice", + "builtin.split", + "pd_op.if", + "cf.yield"}; bool NeedFallBackCpu(const pir::Operation* op, const std::string& kernel_fn_name, @@ -225,7 +227,7 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in, const phi::Place& src_place, const phi::Place& dst_place, const phi::KernelKey& kernel_key, - pir::Program* program) { + pir::Block* block) { pir::IrContext* ctx = pir::IrContext::Instance(); std::string op_name = paddle::dialect::PhiKernelOp::name(); @@ -248,7 +250,7 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in, op->set_attribute(kAttrIsPersisable, in.GetDefiningOp()->attribute(kAttrIsPersisable)); } - program->block()->push_back(op); + block->push_back(op); auto new_in = op->result(0); @@ -266,7 +268,7 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in, pir::Operation* op = pir::Operation::Create({in}, op_attribute, {out_type}, op_info); - program->block()->push_back(op); + block->push_back(op); auto new_in = op->result(0); return new_in; @@ -647,6 +649,50 @@ phi::KernelKey GetKernelKey( return res; } +void HandleForIfOp( + const phi::Place& place, + pir::Operation* op_item, + pir::Block* block, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { + auto cur_in = op_item->operand_source(0); + PADDLE_ENFORCE_EQ( + map_value_pair->count(cur_in), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST in map pair", 0, op_item->name())); + auto new_in = map_value_pair->at(cur_in); + + pir::Builder builder(ctx, block); + + auto base_if_op = op_item->dyn_cast(); + auto allocated_dense_tensor_dtype = + paddle::dialect::AllocatedDenseTensorType::get( + ctx, + place, + base_if_op.result(0).type().dyn_cast()); + auto new_if_op = builder.Build( + new_in, std::vector{allocated_dense_tensor_dtype}); + // process true block + pir::Block* true_block = new_if_op.true_block(); + ProcessBlock(place, + base_if_op.true_block(), + true_block, + ctx, + map_op_pair, + map_value_pair); + + // process false block + pir::Block* false_block = new_if_op.false_block(); + ProcessBlock(place, + base_if_op.false_block(), + false_block, + ctx, + map_op_pair, + map_value_pair); +} + pir::OpResult GetNewInput( const pir::Value cur_in, const std::unordered_map& map_value_pair, @@ -662,11 +708,16 @@ pir::OpResult GetNewInput( } void HandleForSpecialOp( + const phi::Place& place, pir::Operation* op_item, - pir::Program* program, + pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, std::unordered_map* map_value_pair) { + if (op_item->name() == "pd_op.if") { + HandleForIfOp(place, op_item, block, ctx, map_op_pair, map_value_pair); + return; + } std::vector vec_inputs; std::vector op_output_types; if (op_item->name() == "builtin.combine") { @@ -740,11 +791,25 @@ void HandleForSpecialOp( } } + if (op_item->name() == "cf.yield") { + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + vec_inputs.push_back(new_in); + } + } + } + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); // Generate new op pir::Operation* op = pir::Operation::Create( vec_inputs, op_item->attributes(), op_output_types, op_info); - program->block()->push_back(op); + block->push_back(op); (*map_op_pair)[op_item] = op; // only deal with single output if (op_item->num_results() > 0) { @@ -843,7 +908,7 @@ std::vector BuildOpInputList( pir::IrContext* ctx, std::unordered_map* map_op_pair, std::unordered_map* map_value_pair, - pir::Program* program) { + pir::Block* block) { if (op_item->num_operands() == 0) { return {}; } @@ -908,7 +973,7 @@ std::vector BuildOpInputList( new_in_alloc_type.lod(), new_in_alloc_type.offset()); new_in = AddPlaceTransferOp( - new_in, out_type, in_place, out_place, kernel_key, program); + new_in, out_type, in_place, out_place, kernel_key, block); } } else if (new_in_type.isa()) { // [ todo need update here, support combine data transfomer] @@ -973,7 +1038,7 @@ std::vector BuildOpInputList( "VectorType")); } in_i = AddPlaceTransferOp( - in_i, out_type, place, out_place, kernel_key, program); + in_i, out_type, place, out_place, kernel_key, block); is_trans = true; } @@ -991,7 +1056,7 @@ std::vector BuildOpInputList( inner_inputs, {}, {target_vec_type}, op_info); new_in = operation->result(0); - program->block()->push_back(operation); + block->push_back(operation); } } @@ -1012,7 +1077,7 @@ void AddShadowFeed( const phi::Place& place, pir::Operation* op_item, pir::Operation* kernel_op, - pir::Program* program, + pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, std::unordered_map* map_value_pair) { @@ -1049,7 +1114,7 @@ void AddShadowFeed( {kernel_op->result(0)}, attr_map, {out_type}, phi_kernel_op_info); (*map_op_pair)[op_item] = shadow_op; - program->block()->push_back(shadow_op); + block->push_back(shadow_op); if (op_item->num_results() > 0) { for (size_t i = 0; i < shadow_op->num_results(); ++i) { (*map_value_pair)[op_item->result(i)] = shadow_op->result(i); @@ -1093,7 +1158,7 @@ pir::Operation* BuildPhiKernelOp( const std::vector& vec_inputs, const std::vector& op_output_types, pir::Operation* op_item, - pir::Program* program, + pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, std::unordered_map* map_value_pair) { @@ -1133,30 +1198,18 @@ pir::Operation* BuildPhiKernelOp( (*map_value_pair)[op_item->result(i)] = op->result(i); } } - program->block()->push_back(op); + block->push_back(op); return op; } -std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, - phi::Place place) { - if (VLOG_IS_ON(2)) { - std::stringstream ss; - prog->Print(ss); - VLOG(2) << "Program after lowering to kernel pass : " << ss.str(); - } - - auto program = std::make_unique(pir::IrContext::Instance()); - - auto block = prog->block(); - - pir::IrContext* ctx = pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - std::unordered_map map_op_pair; - std::unordered_map map_value_pair; - +void ProcessBlock( + const phi::Place& place, + pir::Block* block, + pir::Block* new_block, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { auto skip_feed_names = GetSkipFeedNames(block); for (auto op_item : *block) { @@ -1169,7 +1222,7 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, // HandleSpecialOp if (SpecialLowerOps.count(op_item->name())) { HandleForSpecialOp( - op_item, program.get(), ctx, &map_op_pair, &map_value_pair); + place, op_item, new_block, ctx, map_op_pair, map_value_pair); continue; } @@ -1180,7 +1233,7 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, auto kernel_fn_str = GetKernelFnStr(op_info_parser.get(), op_item); auto kernel_key = GetKernelKey( - op_item, place, kernel_fn_str, map_value_pair, op_info_parser.get()); + op_item, place, kernel_fn_str, *map_value_pair, op_info_parser.get()); VLOG(6) << "kernel type " << kernel_key; // build output type @@ -1194,9 +1247,9 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, place, op_info_parser.get(), ctx, - &map_op_pair, - &map_value_pair, - program.get()); + map_op_pair, + map_value_pair, + new_block); // build op pir::Operation* op = BuildPhiKernelOp(kernel_fn_str, @@ -1204,14 +1257,31 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, vec_inputs, op_output_types, op_item, - program.get(), + new_block, ctx, - &map_op_pair, - &map_value_pair); + map_op_pair, + map_value_pair); AddShadowFeed( - place, op_item, op, program.get(), ctx, &map_op_pair, &map_value_pair); + place, op_item, op, new_block, ctx, map_op_pair, map_value_pair); } +} + +std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, + phi::Place place) { + auto program = std::make_unique(pir::IrContext::Instance()); + + auto block = prog->block(); + + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + std::unordered_map map_op_pair; + std::unordered_map map_value_pair; + + ProcessBlock( + place, block, program->block(), ctx, &map_op_pair, &map_value_pair); if (VLOG_IS_ON(2)) { std::stringstream ss1; diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h index acf839391b8c5..35b5484508a6f 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h @@ -22,5 +22,12 @@ namespace dialect { std::unique_ptr PdOpLowerToKernelPass( pir::Program* prog, phi::Place place = phi::CPUPlace()); +void ProcessBlock( + const phi::Place& place, + pir::Block* block, + pir::Block* new_block, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair); } // namespace dialect } // namespace paddle diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc index b865dc61d1c4a..47f47cd625c7c 100644 --- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc +++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc @@ -69,7 +69,7 @@ TEST(StandaloneExecutor, run) { Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( @@ -110,7 +110,7 @@ TEST(StandaloneExecutor, run_inplace_sqrt) { auto place = platform::CPUPlace(); Scope scope; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( diff --git a/test/cpp/pir/cinn/jit_instruction_test.cc b/test/cpp/pir/cinn/jit_instruction_test.cc index 456b1719b2f65..2996bf17c962a 100644 --- a/test/cpp/pir/cinn/jit_instruction_test.cc +++ b/test/cpp/pir/cinn/jit_instruction_test.cc @@ -83,8 +83,7 @@ TEST(CinnJitInstruction, Run) { platform::Place place = platform::CUDAPlace(0); Scope exe_scope; - InterpreterCore executor( - place, {}, std::move(ir_runtime_program), &exe_scope); + InterpreterCore executor(place, {}, ir_runtime_program->block(), &exe_scope); executor.SetSkipGcVars(out_names); executor.Run({}); diff --git a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc index 97aad7062292c..9efb3f2329e88 100644 --- a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc +++ b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc @@ -43,6 +43,8 @@ #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" #include "paddle/pir/core/utils.h" +#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT); @@ -170,3 +172,39 @@ TEST(kernel_dialect, legacy_op_test) { EXPECT_EQ(kernel_key, op->dyn_cast().kernel_key()); } + +TEST(kernel_dialect, cond_op_test) { + // (1) Init environment. + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + pir::Block* block = program.block(); + pir::Builder builder(ctx, block); + + auto full_op = builder.Build( + std::vector{1}, true, phi::DataType::BOOL); + + auto if_op = builder.Build( + full_op.out(), std::vector{full_op.result(0).type()}); + + pir::Block* true_block = if_op.true_block(); + + builder.SetInsertionPointToStart(true_block); + + auto full_op_1 = builder.Build( + std::vector{2}, true, phi::DataType::BOOL); + builder.Build(std::vector{full_op_1.out()}); + + pir::Block* false_block = if_op.false_block(); + + builder.SetInsertionPointToStart(false_block); + + auto full_op_2 = builder.Build( + std::vector{3}, true, phi::DataType::BOOL); + builder.Build(std::vector{full_op_2.out()}); + + program.Print(std::cout); + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program); +} diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index 39ed39e3e7f7b..74aadb868ff64 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -72,7 +72,7 @@ TEST(VJP, TanhBackwardTest) { Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( const_cast(test_core.Impl())); @@ -127,7 +127,7 @@ TEST(VJP, Tanh_BackwardTest) { Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( const_cast(test_core.Impl())); @@ -182,7 +182,7 @@ TEST(VJP, MeanBackwardTest) { Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( const_cast(test_core.Impl())); @@ -238,7 +238,7 @@ TEST(VJP, ConcatBackwardTest) { Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( const_cast(test_core.Impl())); @@ -304,7 +304,7 @@ TEST(VJP, AddBackwardTest) { Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( const_cast(test_core.Impl())); @@ -370,7 +370,7 @@ TEST(VJP, Add_BackwardTest) { Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( const_cast(test_core.Impl())); @@ -434,7 +434,7 @@ TEST(VJP, SplitBackwardTest) { auto place = platform::CPUPlace(); Scope scope; ProgramDesc prog_desc; - InterpreterCore test_core(place, {}, std::move(kernel_program), &scope); + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); std::stringstream os; os << reinterpret_cast( const_cast(test_core.Impl()));