Skip to content

Commit

Permalink
[PIR] New ir supprt cond part 1 (#57265)
Browse files Browse the repository at this point in the history
* [IR] add control flow dialect

* fix error

* add error

* test

* fix

* fix

* fix

* fix

* fix

* fix

* update

* fix

* update

* update

* refactor pd op to kernel pass

* fix bug

* update

* add code

* update

* fix merge bug

* fix bug

* update

* update

* update

* fix bug

* update

* update

* fix bug

* update

* update

* fix bug

* remove useless code

* update

* update

* update

* new ir support cond part 1

* fix compile bug

* fix compile bug

* fix compile bug

* fix bug

* open test case

---------

Co-authored-by: winter-wang <1030748926@qq.com>
  • Loading branch information
phlrain and winter-wang authored Sep 15, 2023
1 parent 70c4634 commit ae188d1
Show file tree
Hide file tree
Showing 20 changed files with 271 additions and 96 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/framework/executor_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,12 @@ std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
std::shared_ptr<InterpreterCore> core = nullptr;

core.reset(new InterpreterCore(
place, {}, std::move(ir_program), scope, execution_config));
place, {}, ir_program->block(), scope, execution_config));

auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, scope, is_grad);
cached_value.core_ = core;
cached_value.ir_prog_ = std::move(ir_program);
return core;
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/executor_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class InterpreterCoreInfo {
struct CacheValue {
std::shared_ptr<InterpreterCore> core_{nullptr};
std::set<std::string> skip_eager_delete_vars_;
std::unique_ptr<::pir::Program> ir_prog_{nullptr};
};

bool IsAvailable(bool is_grad) {
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/new_executor/interpreter/plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ std::shared_ptr<::pir::Program> Plan::IrProgram(
return type_to_ir_program_.at(job_type);
}

void Plan::UpdateIrProgram(const std::string& job_type,
std::shared_ptr<::pir::Program> ir_prog) {
type_to_ir_program_[job_type] = ir_prog;
}

int64_t Plan::MicroBatchNum() const { return micro_batch_num_; }

} // namespace interpreter
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/new_executor/interpreter/plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ class Plan final {
const ProgramDesc* Program(const std::string& job_type) const;
std::shared_ptr<::pir::Program> IrProgram(const std::string& job_type) const;

void UpdateIrProgram(const std::string& job_type,
std::shared_ptr<::pir::Program> ir_prog);

int64_t MicroBatchNum() const;

private:
const std::vector<std::shared_ptr<Job>> job_list_;
const std::unordered_map<std::string, ProgramDesc*> type_to_program_;
const std::unordered_map<std::string, std::shared_ptr<::pir::Program>>
std::unordered_map<std::string, std::shared_ptr<::pir::Program>>
type_to_ir_program_;
int64_t micro_batch_num_;
};
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
InterpreterCore::InterpreterCore(
const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_block,
framework::Scope* scope,
const ExecutionConfig& execution_config) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place;
impl_ = std::make_unique<NewIRInterpreter>(
place, fetch_var_names, std::move(ir_prog), scope, execution_config);
place, fetch_var_names, ir_block, scope, execution_config);
}

InterpreterCore::~InterpreterCore() {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PD_DECLARE_bool(new_executor_use_local_scope);

namespace pir {
class Program;
class Block;
} // namespace pir

namespace paddle {
Expand All @@ -38,7 +38,7 @@ class InterpreterCore {
// This constructor is for New IR.
InterpreterCore(const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_prog,
Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig());
~InterpreterCore();
Expand Down
21 changes: 13 additions & 8 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ namespace framework {
NewIRInterpreter::NewIRInterpreter(
const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_block,
framework::Scope* scope,
const ExecutionConfig& execution_config)
: place_(place),
execution_config_(execution_config),
var_scope_(scope),
scope_(scope),
ir_program_(std::move(ir_prog)),
ir_block_(ir_block),
ir_stream_analyzer_(place),
fetch_var_names_(fetch_var_names) {
VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_;
Expand All @@ -92,8 +92,7 @@ NewIRInterpreter::NewIRInterpreter(
// TODO(zhangbo): delete var_scope
var_scope_.SetLocalScope(local_scope_);

execution_config_.AnalyzeThreadPoolConfig(place,
ir_program_->block()->size());
execution_config_.AnalyzeThreadPoolConfig(place, 1);
execution_config_.Log(/*log_level=*/8);

ir_instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
Expand Down Expand Up @@ -334,6 +333,10 @@ Scope* NewIRInterpreter::InnerScope() {
return local_scope_ != nullptr ? local_scope_ : scope_;
}

std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const {
return value_2_var_name_.at(value);
}

void NewIRInterpreter::UpdateSyncOpNum() {
int64_t sync_op_num = 0;
for (auto& ins : vec_instruction_base_) {
Expand Down Expand Up @@ -501,13 +504,15 @@ void NewIRInterpreter::BuildInstruction() {
VLOG(6) << "Build Instructions for new ir ... ";
vec_instruction_base_.clear();
size_t op_idx = 0;
for (auto& op : *ir_program_->block()) {
for (auto& op : *ir_block_) {
VLOG(6) << "Build Instruction for op: " << op_idx;
if (op->dialect()->name() == "builtin") {
if (interpreter::GetSpecialOpNames().count(op->name())) {
VLOG(6) << "skip process " << op->name();
continue;
}
} else if (op->dialect()->name() == "cf") {
continue;
} else if (op->dialect()->name() == "pd_kernel") {
auto op_name = op->attributes()
.at("op_name")
Expand Down Expand Up @@ -923,13 +928,14 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
// Build
std::stringstream ss;
ss << this;
::pir::BuildScope(*ir_program_->block(),
::pir::BuildScope(*ir_block_,
InnerScope(),
ss.str(),
&value_2_var_name_,
&variable_2_var_name_,
&var_name_2_id_,
&variable_list_);
&variable_list_,
&sub_blocks_);

interpreter::BuildId2VarName(var_name_2_id_, &id_2_var_name_);

Expand Down Expand Up @@ -977,7 +983,6 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
if (HasLocalScope()) {
ClearLoDTensorArrayInLocalScope();
}

// return Fetch Tensors
Scope* inner_scope = InnerScope();
if (FLAGS_enable_new_ir_in_executor) {
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "paddle/pir/core/value.h"

namespace ir {
class Program;
class Block;
} // namespace ir

namespace paddle {
Expand All @@ -36,7 +36,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
public:
NewIRInterpreter(const platform::Place& place,
const std::vector<std::string>& fetch_var_names,
std::unique_ptr<::pir::Program> ir_prog,
const ::pir::Block* ir_block,
Scope* scope,
const ExecutionConfig& execution_config = ExecutionConfig());

Expand Down Expand Up @@ -81,6 +81,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {

int GetIdByName(const std::string& name) const;

std::string GetNameByValue(::pir::Value value) const;

private:
// build graph
void UpdateSyncOpNum();
Expand Down Expand Up @@ -198,7 +200,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {

InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;

std::unique_ptr<::pir::Program> ir_program_{nullptr};
const ::pir::Block* ir_block_{nullptr};

std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;

Expand All @@ -211,6 +213,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::unordered_map<int, std::string> id_2_var_name_;

std::vector<Variable*> variable_list_;
std::map<pir::Block*, paddle::framework::Scope*> sub_blocks_;

std::vector<int> var_ref_count_;

Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/new_executor/standalone_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,19 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place);
std::shared_ptr<pir::Program> shared_program = std::move(kernel_program);
plan_.UpdateIrProgram("base", shared_program);

if (FLAGS_new_ir_apply_inplace_pass) {
pir::PassManager pm(pir::IrContext::Instance(), 3);
pm.AddPass(pir::CreateInplacePass());
pm.Run(kernel_program.get());
pm.Run(shared_program.get());
}

interpretercores_.emplace_back(
std::make_shared<InterpreterCore>(place_,
fetch_var_names_,
std::move(kernel_program),
shared_program->block(),
scope_,
execution_config));
} else {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/standalone_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class StandaloneExecutor {
private:
bool is_interpretercore_build_result_shared_{false};
const platform::Place place_;
const interpreter::Plan plan_;
interpreter::Plan plan_;

std::vector<framework::Scope*> micro_batch_scopes_;
std::vector<std::shared_ptr<InterpreterCore>> interpretercores_;
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class PhiKernelAdaptor {
variable_2_var_name;
std::map<std::string, int> var_name_2_id;
std::vector<paddle::framework::Variable*> variable_list;
std::map<pir::Block*, paddle::framework::Scope*> sub_blocks;
std::stringstream ss;
ss << this;

Expand All @@ -69,7 +70,8 @@ class PhiKernelAdaptor {
&value_2_var_name,
&variable_2_var_name,
&var_name_2_id,
&variable_list);
&variable_list,
&sub_blocks);
pir::IrContext* ctx = pir::IrContext::Instance();

ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand Down
59 changes: 48 additions & 11 deletions paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,21 @@
#include "glog/logging.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"

namespace pir {

const std::unordered_set<std::string> SpecialOps = {"pd_op.feed",
"pd_op.fetch",
"builtin.combine",
"builtin.set_parameter",
"builtin.get_parameter",
"builtin.slice",
"builtin.split",
"pd_op.data",
"pd_op.shadow_output",
"pd_op.if"};

void AddNewData(pir::Value value,
std::string name,
paddle::framework::Variable* var,
Expand Down Expand Up @@ -226,7 +238,8 @@ void HandleForSpecialOp(
std::unordered_map<const paddle::framework::Variable*, std::string>*
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
std::vector<paddle::framework::Variable*>* variable_list,
std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks) {
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
Expand Down Expand Up @@ -425,6 +438,33 @@ void HandleForSpecialOp(
value_2_var_name->emplace(out_value, var_name);
}
}

if (op_name == "pd_op.if") {
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();

auto true_block = if_op.true_block();

auto false_block = if_op.false_block();

auto& true_branch_scope = inner_scope->NewScope();
sub_blocks->emplace(true_block, &true_branch_scope);

auto& false_branch_scope = inner_scope->NewScope();
sub_blocks->emplace(false_block, &false_branch_scope);

for (size_t i = 0; i < if_op->num_results(); ++i) {
// auto true_value = true_yeid_op->operand_source(i);

auto if_op_out_value = if_op->result(i);
BuildValue(if_op_out_value,
inner_scope,
var_name_prefix,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
}
}
}

void HandleForInplaceOp(
Expand Down Expand Up @@ -483,16 +523,17 @@ void HandleForInplaceOp(
}
}

// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other is
// created in inner_scope.
// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other
// is created in inner_scope.
void BuildScope(const pir::Block& block,
paddle::framework::Scope* inner_scope,
const std::string& var_name_prefix,
std::unordered_map<pir::Value, std::string>* value_2_var_name,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
std::vector<paddle::framework::Variable*>* variable_list,
std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks) {
VLOG(4) << "***** [before build] scope"
<< "(" << inner_scope << ") ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
Expand All @@ -507,19 +548,15 @@ void BuildScope(const pir::Block& block,
.AsString();
}
VLOG(4) << "build op:" << op_name;

if (op_name == "pd_op.feed" || op_name == "pd_op.fetch" ||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "builtin.split" || op_name == "pd_op.data" ||
op_name == "pd_op.shadow_output") {
if (SpecialOps.count(op_name)) {
HandleForSpecialOp(op,
inner_scope,
var_name_prefix,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list);
variable_list,
sub_blocks);
continue;
}

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ void BuildScope(const pir::Block& block,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list);
std::vector<paddle::framework::Variable*>* variable_list,
std::map<pir::Block*, paddle::framework::Scope*>* sub_blocks);

void BuildRuntimeContext(
pir::Operation* op,
Expand Down
13 changes: 7 additions & 6 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,13 @@ class ConstantFoldingPattern : public pir::RewritePattern {

// Execute program
exe_config_.create_local_scope = false;
paddle::framework::InterpreterCore core(
phi::CPUPlace{},
fetch_var_names,
paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
&scope_,
exe_config_);
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(temp_program.get());
paddle::framework::InterpreterCore core(phi::CPUPlace{},
fetch_var_names,
kernel_program->block(),
&scope_,
exe_config_);

paddle::framework::FetchList fetch_list = core.Run({});

Expand Down
Loading

0 comments on commit ae188d1

Please sign in to comment.