[PIR] adjust the implementation of pd_op.while. #58089

Merged (1 commit) on Oct 16, 2023
116 changes: 29 additions & 87 deletions paddle/fluid/framework/new_executor/instruction/while_instruction.cc
@@ -63,7 +63,9 @@ WhileInstruction::WhileInstruction(size_t id,

auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();

for (size_t i = 0; i < while_op.num_operands(); ++i) {
cond_var_ = inner_scope->GetVar(
parent_exe_info->GetValue2VarName().at(while_op.operand_source(0)));
for (size_t i = 1; i < while_op.num_operands(); ++i) {
while_op_inputs_.push_back(inner_scope->GetVar(
parent_exe_info->GetValue2VarName().at(while_op.operand_source(i))));
}
@@ -73,32 +75,8 @@ WhileInstruction::WhileInstruction(size_t id,
parent_exe_info->GetValue2VarName().at(while_op.result(i))));
}

cond_block_ = while_op.cond_block();
body_block_ = while_op.body_block();

auto cond_yied_inputs = GetYiedOpInputs(cond_block_);
auto body_yied_inputs = GetYiedOpInputs(body_block_);

Scope* cond_scope = &(parent_exe_info->GetScope()->NewScope());
auto cond_exe_info = parent_exe_info->NewChild(cond_scope);
for (size_t i = 0; i < cond_block_->args_size(); ++i) {
auto var_name = "block_arg_" + std::to_string(i);
cond_scope->Var(var_name);
cond_exe_info->Add(cond_block_->argument(i), var_name);
}
cond_inter_ = std::unique_ptr<NewIRInterpreter>(new NewIRInterpreter(
place, {}, cond_block_, cond_scope, cond_exe_info, {}));

std::set<std::string> cond_skip_gc_names_set;
for (auto value : cond_yied_inputs) {
cond_skip_gc_names_.push_back(cond_inter_->GetNameByValue(value));
cond_skip_gc_names_set.insert(cond_inter_->GetNameByValue(value));
}
cond_inter_->SetSkipGcVars(cond_skip_gc_names_set);

auto cond_value = cond_yied_inputs[0];
auto var_name = cond_inter_->GetNameByValue(cond_value);
cond_var = cond_inter_->local_scope()->GetVar(var_name);
auto body_block_outputs = GetYiedOpInputs(body_block_);

Scope* body_scope = &(parent_exe_info->GetScope()->NewScope());
auto body_exe_info = parent_exe_info->NewChild(body_scope);
@@ -111,31 +89,15 @@ WhileInstruction::WhileInstruction(size_t id,
place, {}, body_block_, body_scope, body_exe_info, {}));

std::set<std::string> body_skip_gc_names_set;
for (auto value : body_yied_inputs) {
for (auto value : body_block_outputs) {
body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value));
body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value));
}
body_inter_->SetSkipGcVars(body_skip_gc_names_set);

// the cond block and body block input also be the while_op inputs

std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *parent_exe_info, &inputs);

// TODO(phlrain): process cond and body block
// GetOutsideOpInputs(cond_block_,
// inner_scope,
// parent_exe_info->GetValue2VarName(),
// parent_exe_info->GetVarName2Id(),
// parent_exe_info->GetVar2VarName(),
// &inputs);

// GetOutsideOpInputs(body_block_,
// inner_scope,
// parent_exe_info->GetValue2VarName(),
// parent_exe_info->GetVarName2Id(),
// parent_exe_info->GetVar2VarName(),
// &inputs);
SetInputs(inputs);

std::unordered_map<pir::Value, std::vector<int>> outputs;
@@ -146,72 +108,52 @@ WhileInstruction::WhileInstruction(size_t id,
parent_exe_info->GetValue2VarName().find(value),
parent_exe_info->GetValue2VarName().end(),
phi::errors::PreconditionNotMet(
"input should in name map, [%d] 'th input of [%s] op",
"output should in name map, [%d] 'th output of [%s] op",
i,
"if op"));
"while op"));
std::vector<int> outputs_id = GetValueIds(value, *parent_exe_info);
outputs.emplace(value, outputs_id);
}
}
SetOutputs(outputs);
}

void WhileInstruction::CopyStepOutput() {
for (size_t i = 0; i < body_skip_gc_names_.size(); ++i) {
auto* inner_var =
body_inter_->local_scope()->GetVar(body_skip_gc_names_[i]);

void WhileInstruction::CopyInputsToOutputs() {
for (size_t i = 0; i < while_op_outputs_.size(); ++i) {
while_op_outputs_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
inner_var->Get<phi::DenseTensor>());
while_op_inputs_[i]->Get<phi::DenseTensor>());
}
}

void WhileInstruction::CopyWhileInputToBlockArgs(const NewIRInterpreter* inter,
::pir::Block* block) {
for (size_t i = 0; i < block->args_size(); ++i) {
auto block_arg = block->argument(i);
auto var_name = inter->GetNameByValue(block_arg);
auto* inner_var = inter->local_scope()->GetVar(var_name);
void WhileInstruction::PassArgsToBodyBlock() {
for (size_t i = 0; i < body_block_->args_size(); ++i) {
auto block_arg = body_block_->argument(i);
auto var_name = body_inter_->GetNameByValue(block_arg);
auto* inner_var = body_inter_->local_scope()->GetVar(var_name);
inner_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
while_op_inputs_[i]->Get<phi::DenseTensor>());
while_op_outputs_[i]->Get<phi::DenseTensor>());
}
}

void WhileInstruction::CopyStepOutputToBlockArgs(const NewIRInterpreter* inter,
::pir::Block* block) {
for (size_t i = 0; i < block->args_size(); ++i) {
auto out_var_name = body_skip_gc_names_[i];
void WhileInstruction::GetValueFromBodyBlock() {
cond_var_->GetMutable<phi::DenseTensor>()->ShareDataWith(
body_inter_->local_scope()
->GetVar(body_skip_gc_names_[0])
->Get<phi::DenseTensor>());
for (size_t i = 0; i < while_op_outputs_.size(); ++i) {
auto& out_var_name = body_skip_gc_names_[i + 1];
auto* out_var = body_inter_->local_scope()->GetVar(out_var_name);

auto block_arg = block->argument(i);
auto block_in_var_name = inter->GetNameByValue(block_arg);

auto* inner_var = inter->local_scope()->GetVar(block_in_var_name);

inner_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
while_op_outputs_[i]->GetMutable<phi::DenseTensor>()->ShareDataWith(
out_var->Get<phi::DenseTensor>());
}
}

void WhileInstruction::Run() {
CopyWhileInputToBlockArgs(cond_inter_.get(), cond_block_);
CopyWhileInputToBlockArgs(body_inter_.get(), body_block_);

while (true) {
cond_inter_->Run({}, false);

if (cond_var->Get<phi::DenseTensor>().data<bool>()[0]) {
body_inter_->Run({}, false);

CopyStepOutputToBlockArgs(cond_inter_.get(), cond_block_);
CopyStepOutputToBlockArgs(body_inter_.get(), body_block_);
} else {
break;
}
CopyInputsToOutputs();
while (cond_var_->Get<phi::DenseTensor>().data<bool>()[0]) {
PassArgsToBodyBlock();
body_inter_->Run({}, false);
GetValueFromBodyBlock();
}

// copy output
CopyStepOutput();
}

} // namespace framework
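Note: the copy helpers above (CopyInputsToOutputs, PassArgsToBodyBlock, GetValueFromBodyBlock) all rely on phi::DenseTensor::ShareDataWith, which aliases the underlying buffer instead of copying it. A minimal stand-alone model of that aliasing, using a hypothetical FakeTensor rather than Paddle types, is sketched below.

#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for a DenseTensor: the buffer sits behind a
// shared_ptr so that "ShareDataWith" can alias it instead of copying,
// mirroring how the while instruction wires inputs, block arguments and
// outputs together.
struct FakeTensor {
  std::shared_ptr<std::vector<float>> buffer;
  void ShareDataWith(const FakeTensor& other) { buffer = other.buffer; }
};

int main() {
  FakeTensor input{std::make_shared<std::vector<float>>(4, 1.0f)};
  FakeTensor output;
  output.ShareDataWith(input);  // "output = input" without a data copy
  (*input.buffer)[0] = 42.0f;   // later writes are visible through the alias
  assert((*output.buffer)[0] == 42.0f);
  return 0;
}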
24 changes: 13 additions & 11 deletions paddle/fluid/framework/new_executor/instruction/while_instruction.h
@@ -27,6 +27,12 @@ class Value;
class NewIRInterpreter;
class ValueExecutionInfo;

/// The execution semantics of the while op ['output' = while_op('cond', 'input')]
/// is:
///   'output' = 'input';
///   while('cond') {
///     'cond', 'output' = body_block('output');
///   }
class WhileInstruction : public InstructionBase {
public:
WhileInstruction(size_t id,
@@ -43,28 +49,24 @@ class WhileInstruction : public InstructionBase {
::pir::Operation* Operation() const override { return op_; }

private:
void CopyStepOutput();
// 'output' = 'input'
void CopyInputsToOutputs();

void CopyWhileInputToBlockArgs(const NewIRInterpreter* inter,
::pir::Block* block);
// Pass argument to body_block for execution.
void PassArgsToBodyBlock();

void CopyStepOutputToBlockArgs(const NewIRInterpreter* inter,
::pir::Block* block);
// Get return value from body_block after each execution.
void GetValueFromBodyBlock();

std::string cond_name_{"while_instruction"};

Variable* cond_var;

Variable* cond_var_;
std::vector<Variable*> while_op_inputs_;
std::vector<Variable*> while_op_outputs_;

std::unique_ptr<NewIRInterpreter> cond_inter_;
std::unique_ptr<NewIRInterpreter> body_inter_;

std::vector<std::string> cond_skip_gc_names_;
std::vector<std::string> body_skip_gc_names_;

::pir::Block* cond_block_;
::pir::Block* body_block_;

::pir::Operation* op_;
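The comment block in this header reads almost like ordinary code; a self-contained sketch of the same loop semantics, with std::vector standing in for tensors and a plain callable standing in for the body block (all names here are illustrative, not Paddle APIs):

#include <functional>
#include <utility>
#include <vector>

using FakeTensor = std::vector<float>;
// The body receives the current outputs and returns the next condition
// value plus the next outputs, like body_block('output') in the comment.
using BodyFn = std::function<std::pair<bool, std::vector<FakeTensor>>(
    const std::vector<FakeTensor>&)>;

std::vector<FakeTensor> RunWhile(bool cond,
                                 std::vector<FakeTensor> inputs,
                                 const BodyFn& body) {
  std::vector<FakeTensor> outputs = std::move(inputs);  // 'output' = 'input'
  while (cond) {
    auto step = body(outputs);  // 'cond', 'output' = body_block('output')
    cond = step.first;
    outputs = std::move(step.second);
  }
  return outputs;
}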
38 changes: 21 additions & 17 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -22,6 +22,7 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/operation_utils.h"
#include "paddle/pir/core/utils.h"
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

namespace paddle {
@@ -181,38 +182,41 @@ void IfOp::VerifyRegion() {

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
const std::vector<pir::Value> &inputs) {
argument.AddInput(cond);
argument.AddInputs(inputs);
for (auto val : inputs) {
argument.AddOutput(val.type());
}
argument.AddRegions(2u);
}
pir::Block *WhileOp::cond_block() {
pir::Region &cond_region = (*this)->region(0);
if (cond_region.empty()) cond_region.emplace_back();
return cond_region.front();
argument.AddRegion(nullptr);
}
pir::Block *WhileOp::body_block() {
pir::Region &body_region = (*this)->region(1);
pir::Region &body_region = (*this)->region(0);
if (body_region.empty()) body_region.emplace_back();
return body_region.front();
}
pir::Value WhileOp::cond() { return (*this)->operand_source(0); }

void WhileOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " \"" << name() << "\"";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
os << "{";
for (auto item : *cond_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n } do {";
os << " = \"" << name() << "\"(";
printer.PrintValue(cond());
os << ") [";
auto operands = (*this)->operands_source();
pir::PrintInterleave(
operands.begin() + 1,
operands.end(),
[&](pir::Value v) { printer.PrintValue(v); },
[&]() { os << ", "; });
os << "] { \n ^";
pir::PrintInterleave(
body_block()->args_begin(),
body_block()->args_end(),
[&](pir::Value v) { printer.PrintValue(v); },
[&]() { os << ", "; });
for (auto item : *body_block()) {
os << "\n ";
printer.PrintOperation(item);
13 changes: 12 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
@@ -45,6 +45,16 @@ class IfOp : public pir::Op<IfOp> {
void VerifyRegion();
};

///
/// \brief The WhileOp is an operation that iterates over a loop body based on a
/// condition. It takes two inputs: cond_value and loop_vars. The outputs of the
/// WhileOp must have the same arity (length and structure) as loop_vars. The
/// semantics of WhileOp [outputs = while_op(cond, inputs)] are as below:
///   outputs = inputs
///   while(cond) {
///     cond, outputs = body(outputs)
///   }
///
class WhileOp : public pir::Op<WhileOp> {
public:
using Op::Op;
@@ -54,9 +64,10 @@ class WhileOp : public pir::Op<WhileOp> {

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
const std::vector<pir::Value> &inputs);
pir::Block *cond_block();
pir::Block *body_block();
pir::Value cond();
void Print(pir::IrPrinter &printer); // NOLINT
void VerifySig() {}
void VerifyRegion() {}
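With the cond region gone, building the op follows the pattern used in HandleForWhileOp below: the condition value is the first builder argument and only the single body region carries block arguments. A rough usage sketch, assuming an already-constructed pir::Builder, a boolean pir::Value for the condition, and a vector of loop-carried values; the helper name BuildWhile is hypothetical, the calls mirror the kernel pass code below:

#include <vector>

#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/pir/core/builder.h"

void BuildWhile(pir::Builder& builder,  // NOLINT
                pir::Value cond_val,
                const std::vector<pir::Value>& loop_vars) {
  // Condition first, loop-carried values after; outputs take the input types.
  auto while_op = builder.Build<paddle::dialect::WhileOp>(cond_val, loop_vars);

  // The op's single region holds the body block; its arguments are created
  // by the caller, one per loop-carried value.
  pir::Block* body_block = while_op.body_block();
  for (size_t i = 0; i < loop_vars.size(); ++i) {
    body_block->AddArgument(loop_vars[i].type());
  }
  // Inside the body, the first yielded value is the next condition; the
  // remaining yielded values become the next iteration's block arguments
  // and, on exit, the while_op results.
}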
31 changes: 6 additions & 25 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -776,6 +776,7 @@ void HandleForWhileOp(
std::unordered_map<pir::Operation*, pir::Operation*>* map_op_pair,
std::unordered_map<pir::Value, pir::Value>* map_value_pair) {
std::vector<pir::Value> vec_in;
pir::Value cond_val;
for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);

@@ -785,36 +786,16 @@
phi::errors::PreconditionNotMet(
"[%d]'s input of [%s] op MUST in map pair", 0, op_item->name()));
auto new_in = map_value_pair->at(cur_in);

vec_in.push_back(new_in);
if (i == 0)
cond_val = new_in;
else
vec_in.push_back(new_in);
}

pir::Builder builder(ctx, block);

auto base_while_op = op_item->dyn_cast<paddle::dialect::WhileOp>();
std::vector<pir::Type> op_output_types;
for (size_t i = 0; i < base_while_op.num_results(); ++i) {
op_output_types.push_back(paddle::dialect::AllocatedDenseTensorType::get(
ctx,
place,
base_while_op.result(i).type().dyn_cast<dialect::DenseTensorType>()));
}
auto new_while_op = builder.Build<paddle::dialect::WhileOp>(vec_in);

pir::Block* cond_block = new_while_op.cond_block();
for (size_t i = 0; i < vec_in.size(); ++i) {
auto block_arg = cond_block->AddArgument(vec_in[i].type());
(*map_value_pair)[base_while_op.cond_block()->argument(i)] = block_arg;
}

// process cond block
ProcessBlock(place,
base_while_op.cond_block(),
cond_block,
ctx,
map_op_pair,
map_value_pair);

auto new_while_op = builder.Build<paddle::dialect::WhileOp>(cond_val, vec_in);
pir::Block* body_block = new_while_op.body_block();
for (size_t i = 0; i < vec_in.size(); ++i) {
auto block_arg = body_block->AddArgument(vec_in[i].type());
1 change: 1 addition & 0 deletions paddle/pir/core/block.h
@@ -84,6 +84,7 @@ class IR_API Block {
ArgsIterator args_end() { return arguments_.end(); }
bool args_empty() const { return arguments_.empty(); }
uint32_t args_size() const { return arguments_.size(); }
const BlockArgListType &args() const { return arguments_; }
BlockArgument argument(uint32_t index) { return arguments_[index]; }
Type argument_type(uint32_t index) const { return arguments_[index].type(); }
void ClearArguments();
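The new const args() accessor lets read-only passes range over block arguments without the mutable iterator pair; an illustrative (hypothetical) helper, assuming BlockArgListType is an iterable container of BlockArgument as the neighbouring accessors suggest:

#include <vector>

#include "paddle/pir/core/block.h"

// Collect the types of a block's arguments through the const args() view.
std::vector<pir::Type> BlockArgTypes(const pir::Block& block) {
  std::vector<pir::Type> types;
  types.reserve(block.args_size());
  for (auto arg : block.args()) {
    types.push_back(arg.type());
  }
  return types;
}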