[PIR] Support if op exe (PaddlePaddle#57801)
* add
* add
* fix
* fix
* refine
* delete sub_blocks
* refine
* refine
* add ut
* fix
zhangbo9674 authored and jiahy0825 committed Oct 16, 2023
1 parent 570a637 commit 18c2fb2
Showing 19 changed files with 575 additions and 628 deletions.
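
The diff below replaces CondInstruction's separate scope and name-map constructor arguments (value_2_var_name, var_name_2_id, variable_2_var_name, plus the external sub_blocks map) with a single ValueExecutionInfo, lets the instruction create its own branch scopes and child interpreters, and makes Run() dispatch on the boolean cond variable. The sketch that follows is not part of the commit; it only illustrates the kind of user-level program that reaches this code path, assuming the public paddle.static.nn.cond API and that the PIR executor is enabled in this branch.

# A minimal sketch, not taken from this commit: a static-graph conditional that
# is lowered to an if op and, with the PIR executor enabled, is run by
# CondInstruction.
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.full(shape=[1], fill_value=0.3, dtype='float32')
    y = paddle.full(shape=[1], fill_value=0.5, dtype='float32')
    pred = paddle.less_than(x, y)                 # becomes the if op's cond input
    out = paddle.static.nn.cond(pred,
                                lambda: x + 1.0,  # lowered into the true block
                                lambda: y - 1.0)  # lowered into the false block

exe = paddle.static.Executor(paddle.CPUPlace())
print(exe.run(main, fetch_list=[out]))            # pred holds True, so roughly [1.3]
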
@@ -2,7 +2,7 @@ cc_library(
instruction_base
SRCS instruction_base.cc phi_kernel_instruction.cc
legacy_kernel_instruction.cc cond_instruction.cc instruction_util.cc
DEPS phi framework_proto)
DEPS pir_adaptor phi framework_proto)

if(WITH_CINN AND NOT CINN_ONLY)
cc_library(
173 changes: 63 additions & 110 deletions paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
@@ -48,45 +48,31 @@ std::vector<pir::Value> GetYiedOpInputs(pir::Block* block) {
}
}
}

return vec_res;
}

void GetInputIds(
pir::Operation* op,
Scope* inner_scope,
const std::unordered_map<::pir::Value, std::string>& value_2_var_name,
const std::map<std::string, int>& var_name_2_id,
const std::unordered_map<const paddle::framework::Variable*, std::string>&
variable_2_var_name,
std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
void GetInputIds(pir::Operation* op,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
for (size_t i = 0; i < op->num_operands(); i++) {
pir::Value value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
value_2_var_name.end(),
if (value && value.type()) {
PADDLE_ENFORCE_EQ(
value_exec_info.HasValue(value),
true,
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
std::vector<int> inputs_id = GetValueIds(value,
inner_scope,
value_2_var_name,
var_name_2_id,
variable_2_var_name);
std::vector<int> inputs_id = GetValueIds(value, value_exec_info);
input_ids->emplace(value, inputs_id);
}
}
}

void GetOutsideOpInputs(
pir::Block* block,
Scope* inner_scope,
const std::unordered_map<::pir::Value, std::string>& value_2_var_name,
const std::map<std::string, int>& var_name_2_id,
const std::unordered_map<const paddle::framework::Variable*, std::string>&
variable_2_var_name,
const ValueExecutionInfo& value_exec_info,
std::unordered_map<pir::Value, std::vector<int>>* input_ids) {
std::unordered_set<pir::Value> inner_outputs;
for (auto op : (*block)) {
@@ -99,161 +85,128 @@ void GetOutsideOpInputs(
for (size_t i = 0; i < op->num_operands(); ++i) {
pir::Value value = op->operand_source(i);
if (value && (!inner_outputs.count(value))) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
value_2_var_name.end(),
PADDLE_ENFORCE_EQ(
value_exec_info.HasValue(value),
true,
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
std::vector<int> inputs_id = GetValueIds(value,
inner_scope,
value_2_var_name,
var_name_2_id,
variable_2_var_name);
op->name()));
std::vector<int> inputs_id = GetValueIds(value, value_exec_info);

input_ids->emplace(value, inputs_id);
}
}
}
}

CondInstruction::CondInstruction(
size_t id,
const platform::Place& place,
pir::Operation* op,
Scope* scope,
Scope* local_scope,
ValueExecutionInfo* parent_exe_info,
const std::map<pir::Block*, paddle::framework::Scope*>& sub_blocks)
CondInstruction::CondInstruction(size_t id,
const platform::Place& place,
pir::Operation* op,
ValueExecutionInfo* value_exec_info)
: InstructionBase(id, place) {
op_ = op;
VLOG(6) << "finish process dist attributes";

SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";

Scope* inner_scope = local_scope == nullptr ? scope : local_scope;

VLOG(6) << "finish process inputs outputs index";

PADDLE_ENFORCE(
op->isa<paddle::dialect::IfOp>(),
phi::errors::PreconditionNotMet("Cond instruction only support if op"));

auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
op_ = op;

for (size_t i = 0; i < if_op.num_results(); ++i) {
if_op_outputs_.push_back(inner_scope->GetVar(
parent_exe_info->GetValue2VarName().at(if_op.result(i))));
}
SetKernelType(AnalyseOpFuncType(op, place));
VLOG(6) << "finish process analyse kernel type";

auto cond_value = if_op.operand_source(0);
auto var_name = parent_exe_info->GetValue2VarName().at(cond_value);
cond_var = inner_scope->FindVar(var_name);
cond_var_ = value_exec_info->GetScope()->FindVar(
value_exec_info->GetValue2VarName().at(cond_value));
for (size_t i = 0; i < if_op.num_results(); ++i) {
output_vars_.push_back(value_exec_info->GetScope()->GetVar(
value_exec_info->GetValue2VarName().at(if_op.result(i))));
}
VLOG(6) << "finish process cond_var and output_vars";

auto true_branch_block = if_op.true_block();
auto false_branch_block = if_op.false_block();

auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block);
auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block);

auto true_scope = sub_blocks.at(true_branch_block);
true_branch_inter =
Scope* true_scope = &(value_exec_info->GetScope()->NewScope());
true_branch_inter_ =
new NewIRInterpreter(place,
{},
true_branch_block,
true_scope,
parent_exe_info->NewChild(true_scope),
value_exec_info->NewChild(true_scope),
{});

std::set<std::string> true_skip_gc_names_set;
for (auto value : true_branch_yied_inputs) {
true_skip_gc_names_.push_back(true_branch_inter->GetNameByValue(value));
true_skip_gc_names_set.insert(true_branch_inter->GetNameByValue(value));
true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value));
}
true_branch_inter->SetSkipGcVars(true_skip_gc_names_set);
true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set);
VLOG(6) << "finish process true branch interpreter";

auto false_scope = sub_blocks.at(false_branch_block);
false_branch_inter =
auto false_branch_block = if_op.false_block();
auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block);
Scope* false_scope = &(value_exec_info->GetScope()->NewScope());
false_branch_inter_ =
new NewIRInterpreter(place,
{},
false_branch_block,
false_scope,
parent_exe_info->NewChild(false_scope),
value_exec_info->NewChild(false_scope),
{});

std::set<std::string> false_skip_gc_names_set;
for (auto value : false_branch_yied_inputs) {
false_skip_gc_names_.push_back(false_branch_inter->GetNameByValue(value));
false_skip_gc_names_set.insert(false_branch_inter->GetNameByValue(value));
false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
}
false_branch_inter->SetSkipGcVars(false_skip_gc_names_set);

// the true branch and false branch input will be the if_op inputs
false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set);
VLOG(6) << "finish process false branch interpreter";

// NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is
// OpOperand of IfOp, and the other is external Values used in true_block or
// false_block.
std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op,
inner_scope,
parent_exe_info->GetValue2VarName(),
parent_exe_info->GetVarName2Id(),
parent_exe_info->GetVar2VarName(),
&inputs);
GetOutsideOpInputs(true_branch_block,
inner_scope,
parent_exe_info->GetValue2VarName(),
parent_exe_info->GetVarName2Id(),
parent_exe_info->GetVar2VarName(),
&inputs);

GetOutsideOpInputs(false_branch_block,
inner_scope,
parent_exe_info->GetValue2VarName(),
parent_exe_info->GetVarName2Id(),
parent_exe_info->GetVar2VarName(),
&inputs);
GetInputIds(op, *value_exec_info, &inputs);
GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs);
GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs);
SetInputs(inputs);

std::unordered_map<pir::Value, std::vector<int>> outputs;
for (size_t i = 0; i < op->num_results(); i++) {
pir::Value value = op->result(i);
if (value && value.type()) {
PADDLE_ENFORCE_NE(
parent_exe_info->GetValue2VarName().find(value),
parent_exe_info->GetValue2VarName().end(),
PADDLE_ENFORCE_EQ(
value_exec_info->HasValue(value),
true,
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
"if op"));
std::vector<int> outputs_id =
GetValueIds(value,
inner_scope,
parent_exe_info->GetValue2VarName(),
parent_exe_info->GetVarName2Id(),
parent_exe_info->GetVar2VarName());
std::vector<int> outputs_id = GetValueIds(value, *value_exec_info);
outputs.emplace(value, outputs_id);
}
}
SetOutputs(outputs);
VLOG(6) << "finish process inputs outputs index";
}

void CondInstruction::CopyBranchOutput(
const std::vector<std::string>& var_names, const NewIRInterpreter* inter) {
for (size_t i = 0; i < var_names.size(); ++i) {
auto* inner_var = inter->local_scope()->GetVar(var_names[i]);
auto* inner_var = inter->InnerScope()->GetVar(var_names[i]);

if_op_outputs_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
output_vars_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
inner_var->Get<phi::DenseTensor>());
}
}

void CondInstruction::Run() {
if (cond_var->Get<phi::DenseTensor>().data<bool>()[0]) {
true_branch_inter->Run({}, false);
CopyBranchOutput(true_skip_gc_names_, true_branch_inter);
DeviceContext().Wait();
if (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
true_branch_inter_->Run({}, false);
CopyBranchOutput(true_skip_gc_names_, true_branch_inter_);
} else {
false_branch_inter->Run({}, false);
CopyBranchOutput(false_skip_gc_names_, false_branch_inter);
false_branch_inter_->Run({}, false);
CopyBranchOutput(false_skip_gc_names_, false_branch_inter_);
}

// copy ouptut
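
The NOTE in the constructor above distinguishes two kinds of sub-block inputs: values that are operands of the if op itself, and outer values captured inside true_block or false_block, which GetOutsideOpInputs registers. A hedged illustration of the second kind, again not taken from this commit and using only the public static-graph API:

# A minimal sketch, not from this commit: `pred` feeds the if op directly, while
# `a` and `b` are captured from the enclosing block inside the branch lambdas,
# i.e. the "external Values" case handled by GetOutsideOpInputs.
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    a = paddle.full(shape=[1], fill_value=2.0, dtype='float32')
    b = paddle.full(shape=[1], fill_value=3.0, dtype='float32')
    pred = paddle.less_than(a, b)

    # true_fn/false_fn take no arguments; they close over `a` and `b`, so those
    # values become outside inputs of the lowered sub-blocks.
    out = paddle.static.nn.cond(pred, lambda: a * 2.0, lambda: b * 2.0)

exe = paddle.static.Executor(paddle.CPUPlace())
print(exe.run(main, fetch_list=[out]))            # a < b, so roughly [4.0]
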
28 changes: 14 additions & 14 deletions paddle/fluid/framework/new_executor/instruction/cond_instruction.h
@@ -29,14 +29,10 @@ class ValueExecutionInfo;

class CondInstruction : public InstructionBase {
public:
CondInstruction(
size_t id,
const platform::Place& place,
::pir::Operation* op,
Scope* scope,
Scope* local_scope,
ValueExecutionInfo* parent_exe_info,
const std::map<pir::Block*, paddle::framework::Scope*>& sub_blocks);
CondInstruction(size_t id,
const platform::Place& place,
::pir::Operation* op,
ValueExecutionInfo* value_exe_info);

void Run() override;

@@ -48,19 +44,23 @@ class CondInstruction : public InstructionBase {
void CopyBranchOutput(const std::vector<std::string>& var_names,
const NewIRInterpreter* inter);

::pir::Operation* op_;

std::string cond_name_{"cond_instruction"};

Variable* cond_var;
Variable* cond_var_;

std::vector<Variable*> output_vars_;

std::vector<Variable*> if_op_outputs_;
NewIRInterpreter* true_branch_inter_;

NewIRInterpreter* true_branch_inter;
NewIRInterpreter* false_branch_inter;
NewIRInterpreter* false_branch_inter_;

// TODO(zhangbo): Currently, only the output of IfOp is included. In the
// future, need to consider how to support IfGradOp using IfOp value.
std::vector<std::string> true_skip_gc_names_;
std::vector<std::string> false_skip_gc_names_;

::pir::Operation* op_;
std::vector<std::string> false_skip_gc_names_;
};

} // namespace framework
32 changes: 10 additions & 22 deletions paddle/fluid/framework/new_executor/instruction/instruction_base.cc
@@ -16,6 +16,7 @@

#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
@@ -214,31 +215,22 @@ void InstructionBase::SetOutputs(
}

void InstructionBase::InitInputsOutputsIds(
::pir::Operation* op,
Scope* inner_scope,
const std::unordered_map<pir::Value, std::string>& value_2_var_name,
const std::map<std::string, int>& var_name_2_id,
const std::unordered_map<const paddle::framework::Variable*, std::string>&
variable_2_var_name) {
::pir::Operation* op, const ValueExecutionInfo& value_exec_info) {
auto op_attributes = op->attributes();
auto op_name =
op_attributes.at("op_name").dyn_cast<pir::StrAttribute>().AsString();
std::unordered_map<pir::Value, std::vector<int>> inputs;
for (size_t i = 0; i < op->num_operands(); i++) {
pir::Value value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
value_2_var_name.end(),
PADDLE_ENFORCE_EQ(
value_exec_info.HasValue(value),
true,
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
op_name));
std::vector<int> inputs_id = GetValueIds(value,
inner_scope,
value_2_var_name,
var_name_2_id,
variable_2_var_name);
std::vector<int> inputs_id = GetValueIds(value, value_exec_info);
inputs.emplace(value, inputs_id);
}
}
@@ -248,18 +240,14 @@ void InstructionBase::InitInputsOutputsIds(
for (size_t i = 0; i < op->num_results(); i++) {
pir::Value value = op->result(i);
if (value && value.type()) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
value_2_var_name.end(),
PADDLE_ENFORCE_EQ(
value_exec_info.HasValue(value),
true,
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
i,
op_name));
std::vector<int> outputs_id = GetValueIds(value,
inner_scope,
value_2_var_name,
var_name_2_id,
variable_2_var_name);
std::vector<int> outputs_id = GetValueIds(value, value_exec_info);
outputs.emplace(value, outputs_id);
}
}