Skip to content

Commit

Permalink
set RuntimeProgram for sub_block when use Cxx Config (#5776)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored Mar 30, 2021
1 parent 4fd3fe0 commit e348a8a
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 6 deletions.
49 changes: 47 additions & 2 deletions lite/core/mir/generate_program_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@

#pragma once

#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/mir/pass.h"
#include "lite/kernels/host/conditional_block_compute.h"
#include "lite/kernels/host/while_compute.h"

namespace paddle {
namespace lite {
Expand All @@ -31,10 +35,51 @@ namespace mir {
*/
class GenerateProgramPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph> &graph) override;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;

std::unique_ptr<RuntimeProgram> GenProgram() {
LOG(INFO) << "insts.size " << insts_.size();
LOG(INFO) << "insts.size: " << insts_.size();

#ifdef LITE_WITH_XPU
// generate RuntimeProgram for sub_block and set RuntimeProgram into
// sub_block kernel
// sub_block: while, conditional_block
std::vector<std::string> sub_block_ops{"while", "conditional_block"};
for (int i = static_cast<int>(insts_.size()) - 2; i >= 0; i--) {
for (auto& inst : insts_[i]) {
std::string op_name = inst.op()->Type();
if (std::find(sub_block_ops.begin(), sub_block_ops.end(), op_name) ==
sub_block_ops.end()) {
continue;
}

CHECK(inst.op()->op_info()->HasAttr("sub_block"))
<< op_name << " op should have attr 'sub_block'";
int block_idx = inst.op()->op_info()->GetAttr<int>("sub_block");
CHECK_LT(block_idx, static_cast<int>(insts_.size()))
<< op_name
<< " op's attr 'sub_block' should be less than number of blocks.";
std::vector<std::vector<Instruction>> sub_insts;
sub_insts.emplace_back(std::move(insts_[block_idx]));
std::unique_ptr<RuntimeProgram> sub_program(
new RuntimeProgram(std::move(sub_insts)));

if (op_name == "while") {
auto* kernel =
static_cast<kernels::host::WhileCompute*>(inst.mutable_kernel());
kernel->SetRuntimeProgram(&sub_program);
} else if (op_name == "conditional_block") {
auto* kernel = static_cast<kernels::host::ConditionalBlockCompute*>(
inst.mutable_kernel());
kernel->SetRuntimeProgram(&sub_program);
} else {
LOG(FATAL) << "unsupported sub_block op: " << op_name;
}
}
}
#endif

// generate RuntimeProgram for main block
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));

Expand Down
6 changes: 4 additions & 2 deletions lite/kernels/host/conditional_block_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ namespace host {

void ConditionalBlockCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
program_.reset(new RuntimeProgram(
param.program_desc, param.exec_scope, param.block_idx));
if (program_ == nullptr) {
program_.reset(new RuntimeProgram(
param.program_desc, param.exec_scope, param.block_idx));
}
}

void ConditionalBlockCompute::Run() {
Expand Down
6 changes: 6 additions & 0 deletions lite/kernels/host/conditional_block_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
Expand All @@ -32,8 +33,13 @@ class ConditionalBlockCompute
using param_t = operators::ConditionalBlockParam;

void PrepareForRun() override;

void Run() override;

void SetRuntimeProgram(std::unique_ptr<RuntimeProgram>* program) {
program_ = std::move(*program);
}

private:
std::unique_ptr<RuntimeProgram> program_;
};
Expand Down
6 changes: 4 additions & 2 deletions lite/kernels/host/while_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ namespace host {

void WhileCompute::PrepareForRun() {
auto &param = this->Param<param_t>();
program_.reset(new RuntimeProgram(
param.program_desc, param.exec_scope, param.block_idx));
if (program_ == nullptr) {
program_.reset(new RuntimeProgram(
param.program_desc, param.exec_scope, param.block_idx));
}
}
void WhileCompute::Run() {
auto &param = this->Param<param_t>();
Expand Down
6 changes: 6 additions & 0 deletions lite/kernels/host/while_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
Expand All @@ -32,8 +33,13 @@ class WhileCompute
using param_t = operators::WhileParam;

void Run() override;

void PrepareForRun() override;

void SetRuntimeProgram(std::unique_ptr<RuntimeProgram>* program) {
program_ = std::move(*program);
}

virtual ~WhileCompute() = default;

private:
Expand Down

0 comments on commit e348a8a

Please sign in to comment.