diff --git a/lite/core/mir/generate_program_pass.h b/lite/core/mir/generate_program_pass.h index 2ef4d035710..2827e48fa1f 100644 --- a/lite/core/mir/generate_program_pass.h +++ b/lite/core/mir/generate_program_pass.h @@ -14,12 +14,16 @@ #pragma once +#include #include #include +#include #include #include #include "lite/core/kernel.h" #include "lite/core/mir/pass.h" +#include "lite/kernels/host/conditional_block_compute.h" +#include "lite/kernels/host/while_compute.h" namespace paddle { namespace lite { @@ -31,10 +35,51 @@ namespace mir { */ class GenerateProgramPass : public ProgramPass { public: - void Apply(const std::unique_ptr &graph) override; + void Apply(const std::unique_ptr& graph) override; std::unique_ptr GenProgram() { - LOG(INFO) << "insts.size " << insts_.size(); + LOG(INFO) << "insts.size: " << insts_.size(); + +#ifdef LITE_WITH_XPU + // generate RuntimeProgram for sub_block and set RuntimeProgram into + // sub_block kernel + // sub_block: while, conditional_block + std::vector sub_block_ops{"while", "conditional_block"}; + for (int i = static_cast(insts_.size()) - 2; i >= 0; i--) { + for (auto& inst : insts_[i]) { + std::string op_name = inst.op()->Type(); + if (std::find(sub_block_ops.begin(), sub_block_ops.end(), op_name) == + sub_block_ops.end()) { + continue; + } + + CHECK(inst.op()->op_info()->HasAttr("sub_block")) + << op_name << " op should have attr 'sub_block'"; + int block_idx = inst.op()->op_info()->GetAttr("sub_block"); + CHECK_LT(block_idx, static_cast(insts_.size())) + << op_name + << " op's attr 'sub_block' should be less than number of blocks."; + std::vector> sub_insts; + sub_insts.emplace_back(std::move(insts_[block_idx])); + std::unique_ptr sub_program( + new RuntimeProgram(std::move(sub_insts))); + + if (op_name == "while") { + auto* kernel = + static_cast(inst.mutable_kernel()); + kernel->SetRuntimeProgram(&sub_program); + } else if (op_name == "conditional_block") { + auto* kernel = static_cast( + inst.mutable_kernel()); + kernel->SetRuntimeProgram(&sub_program); + } else { + LOG(FATAL) << "unsupported sub_block op: " << op_name; + } + } + } +#endif + + // generate RuntimeProgram for main block std::unique_ptr program( new RuntimeProgram(std::move(insts_))); diff --git a/lite/kernels/host/conditional_block_compute.cc b/lite/kernels/host/conditional_block_compute.cc index 4124eb6bf59..01396e33b82 100644 --- a/lite/kernels/host/conditional_block_compute.cc +++ b/lite/kernels/host/conditional_block_compute.cc @@ -21,8 +21,10 @@ namespace host { void ConditionalBlockCompute::PrepareForRun() { auto& param = this->Param(); - program_.reset(new RuntimeProgram( - param.program_desc, param.exec_scope, param.block_idx)); + if (program_ == nullptr) { + program_.reset(new RuntimeProgram( + param.program_desc, param.exec_scope, param.block_idx)); + } } void ConditionalBlockCompute::Run() { diff --git a/lite/kernels/host/conditional_block_compute.h b/lite/kernels/host/conditional_block_compute.h index 8d3381ce3c4..c28120750ac 100644 --- a/lite/kernels/host/conditional_block_compute.h +++ b/lite/kernels/host/conditional_block_compute.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include "lite/core/kernel.h" #include "lite/core/op_registry.h" @@ -32,8 +33,13 @@ class ConditionalBlockCompute using param_t = operators::ConditionalBlockParam; void PrepareForRun() override; + void Run() override; + void SetRuntimeProgram(std::unique_ptr* program) { + program_ = std::move(*program); + } + private: std::unique_ptr program_; }; diff --git a/lite/kernels/host/while_compute.cc b/lite/kernels/host/while_compute.cc index ce67e25db85..a3583d29ba9 100644 --- a/lite/kernels/host/while_compute.cc +++ b/lite/kernels/host/while_compute.cc @@ -23,8 +23,10 @@ namespace host { void WhileCompute::PrepareForRun() { auto ¶m = this->Param(); - program_.reset(new RuntimeProgram( - param.program_desc, param.exec_scope, param.block_idx)); + if (program_ == nullptr) { + program_.reset(new RuntimeProgram( + param.program_desc, param.exec_scope, param.block_idx)); + } } void WhileCompute::Run() { auto ¶m = this->Param(); diff --git a/lite/kernels/host/while_compute.h b/lite/kernels/host/while_compute.h index 42065865e45..46e8f42377e 100644 --- a/lite/kernels/host/while_compute.h +++ b/lite/kernels/host/while_compute.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include "lite/core/kernel.h" #include "lite/core/op_registry.h" @@ -32,8 +33,13 @@ class WhileCompute using param_t = operators::WhileParam; void Run() override; + void PrepareForRun() override; + void SetRuntimeProgram(std::unique_ptr* program) { + program_ = std::move(*program); + } + virtual ~WhileCompute() = default; private: