Skip to content

Commit

Permalink
[PIR] adjust the member fucntion name in if_op
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Dec 1, 2023
1 parent e5b5f68 commit a4f32bc
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 57 deletions.
21 changes: 11 additions & 10 deletions paddle/fluid/framework/new_executor/instruction/cond_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ CondInstruction::CondInstruction(size_t id,
// NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is
// OpOperand of IfOp, and the other is external Values used in true_block or
// false_block.
auto true_branch_block = if_op.true_block();
auto false_branch_block = if_op.false_block();
auto& true_branch_block = if_op.true_block();
auto& false_branch_block = if_op.false_block();
std::unordered_map<pir::Value, std::vector<int>> inputs;
GetInputIds(op, *value_exec_info, &inputs);
auto true_outside_inputs =
GetExternalInputs(true_branch_block, *value_exec_info, &inputs);
GetExternalInputs(&true_branch_block, *value_exec_info, &inputs);
auto false_outside_inputs =
GetExternalInputs(false_branch_block, *value_exec_info, &inputs);
GetExternalInputs(&false_branch_block, *value_exec_info, &inputs);
SetInputs(inputs);

std::unordered_map<pir::Value, std::vector<int>> outputs;
Expand All @@ -88,21 +88,22 @@ CondInstruction::CondInstruction(size_t id,
outputs.emplace(value, GetValueIds(value, *value_exec_info));
}
}
InsertTuplePushContinerToOuts(true_branch_block, *value_exec_info, &outputs);
InsertTuplePushContinerToOuts(false_branch_block, *value_exec_info, &outputs);
InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs);
InsertTuplePushContinerToOuts(
&false_branch_block, *value_exec_info, &outputs);
SetOutputs(outputs);
VLOG(6) << "finish process inputs outputs index";

Scope* true_scope = &(value_exec_info->GetScope()->NewScope());
true_branch_inter_ = new PirInterpreter(place,
{},
true_branch_block,
&true_branch_block,
true_scope,
value_exec_info->NewChild(true_scope),
{});

std::set<std::string> true_skip_gc_names_set;
for (auto value : GetYiedOpInputs(true_branch_block)) {
for (auto value : GetYiedOpInputs(&true_branch_block)) {
true_branch_outputs_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value));
true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value));
Expand All @@ -120,13 +121,13 @@ CondInstruction::CondInstruction(size_t id,
false_branch_inter_ =
new PirInterpreter(place,
{},
false_branch_block,
&false_branch_block,
false_scope,
value_exec_info->NewChild(false_scope),
{});

std::set<std::string> false_skip_gc_names_set;
for (auto value : GetYiedOpInputs(false_branch_block)) {
for (auto value : GetYiedOpInputs(&false_branch_block)) {
false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value));
false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value));
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
argument.AddInput(cond);
}

pir::Block *IfOp::true_block() {
pir::Block &IfOp::true_block() {
pir::Region &region = true_region();
if (region.empty()) region.emplace_back();
return &region.front();
return region.front();
}
pir::Block *IfOp::false_block() {
pir::Block &IfOp::false_block() {
pir::Region &region = false_region();
if (region.empty()) region.emplace_back();
return &region.front();
return region.front();
}

void IfOp::Print(pir::IrPrinter &printer) {
Expand All @@ -110,12 +110,12 @@ void IfOp::Print(pir::IrPrinter &printer) {
os << " -> ";
printer.PrintOpReturnType(op);
os << "{";
for (auto &item : *true_block()) {
for (auto &item : true_block()) {
os << "\n ";
printer.PrintOperation(&item);
}
os << "\n } else {";
for (auto &item : *false_block()) {
for (auto &item : false_block()) {
os << "\n ";
printer.PrintOperation(&item);
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class IfOp : public pir::Op<IfOp, VjpInterface> {
std::unique_ptr<pir::Block> &&false_block);

pir::Value cond() { return operand_source(0); }
pir::Block *true_block();
pir::Block *false_block();
pir::Block &true_block();
pir::Block &false_block();
pir::Region &true_region() { return (*this)->region(0); }
pir::Region &false_region() { return (*this)->region(1); }
void Print(pir::IrPrinter &printer); // NOLINT
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ static void GetEagerDelValueOfOp(

if (op.isa<paddle::dialect::IfOp>()) {
auto if_op = op.dyn_cast<paddle::dialect::IfOp>();
GetEagerDelValueOfOp(if_op.true_block(), skip_dels, del_value_2_op);
GetEagerDelValueOfOp(&if_op.true_block(), skip_dels, del_value_2_op);
VLOG(8) << "GetEagerDelValueOfOp for IfOp true block";
GetEagerDelValueOfOp(if_op.false_block(), skip_dels, del_value_2_op);
GetEagerDelValueOfOp(&if_op.false_block(), skip_dels, del_value_2_op);
VLOG(8) << "GetEagerDelValueOfOp for IfOp false block";
}
}
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -875,19 +875,19 @@ void HandleForIfOp(
auto new_ifop = builder.Build<IfOp>(new_cond, std::move(new_ifop_outputs));

// process true block
pir::Block* true_block = new_ifop.true_block();
auto& true_block = new_ifop.true_block();
ProcessBlock(place,
old_ifop.true_block(),
true_block,
&old_ifop.true_block(),
&true_block,
ctx,
map_op_pair,
map_value_pair);

// process false block
pir::Block* false_block = new_ifop.false_block();
auto& false_block = new_ifop.false_block();
ProcessBlock(place,
old_ifop.false_block(),
false_block,
&old_ifop.false_block(),
&false_block,
ctx,
map_op_pair,
map_value_pair);
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/new_executor/standalone_executor_pir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,17 @@ TEST(StandaloneExecutor, if_op) {
auto if_op = builder.Build<paddle::dialect::IfOp>(
full_op.out(), std::vector<pir::Type>{full_op.result(0).type()});

pir::Block* true_block = if_op.true_block();
auto& true_block = if_op.true_block();

builder.SetInsertionPointToStart(true_block);
builder.SetInsertionPointToStart(&true_block);

auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_1.out()});

pir::Block* false_block = if_op.false_block();
auto& false_block = if_op.false_block();

builder.SetInsertionPointToStart(false_block);
builder.SetInsertionPointToStart(&false_block);

auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
Expand Down
20 changes: 10 additions & 10 deletions test/cpp/pir/control_flow_dialect/if_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ TEST(if_op_test, base) {
auto if_op = builder.Build<paddle::dialect::IfOp>(
full_op.out(), std::vector<pir::Type>{builder.bool_type()});

pir::Block* true_block = if_op.true_block();
auto& true_block = if_op.true_block();

builder.SetInsertionPointToStart(true_block);
builder.SetInsertionPointToStart(&true_block);

auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_1.out()});

pir::Block* false_block = if_op.false_block();
auto& false_block = if_op.false_block();

builder.SetInsertionPointToStart(false_block);
builder.SetInsertionPointToStart(&false_block);

auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
Expand Down Expand Up @@ -110,8 +110,8 @@ TEST(if_op_test, build_by_block) {
vec.push_back(&block);
}
EXPECT_EQ(vec.size(), 2u);
EXPECT_EQ(vec[0], if_op.true_block());
EXPECT_EQ(vec[1], if_op.false_block());
EXPECT_EQ(vec[0], &if_op.true_block());
EXPECT_EQ(vec[1], &if_op.false_block());
}

TEST(if_op_test, network_with_backward) {
Expand All @@ -133,7 +133,7 @@ TEST(if_op_test, network_with_backward) {

auto if_op = builder.Build<IfOp>(cond, std::vector<pir::Type>{x.type()});

builder.SetInsertionPointToStart(if_op.true_block());
builder.SetInsertionPointToStart(&if_op.true_block());

auto local1_z = builder.Build<AddOp>(x, y).out();
auto local1_w = builder.Build<AddOp>(local1_z, y).out();
Expand All @@ -142,7 +142,7 @@ TEST(if_op_test, network_with_backward) {

builder.Build<pir::YieldOp>(std::vector<pir::Value>{local1_w});

builder.SetInsertionPointToStart(if_op.false_block());
builder.SetInsertionPointToStart(&if_op.false_block());
auto local2_z = builder.Build<MatmulOp>(x, y).out();
auto local2_w = builder.Build<MatmulOp>(local2_z, y).out();
builder.Build<pir::TuplePushOp>(inlet_1,
Expand All @@ -158,7 +158,7 @@ TEST(if_op_test, network_with_backward) {
builder.Build<IfOp>(cond, std::vector<pir::Type>{x.type(), y.type()});

// construct the true block of if_grad
builder.SetInsertionPointToStart(if_grad.true_block());
builder.SetInsertionPointToStart(&if_grad.true_block());

auto pop_local1_z =
builder.Build<pir::TuplePopOp>(outlet_0).outlet_element(0);
Expand All @@ -174,7 +174,7 @@ TEST(if_op_test, network_with_backward) {
std::vector<pir::Value>{local1_x_grad, local1_y_grad});

// construct the false block of if_grad
builder.SetInsertionPointToStart(if_grad.false_block());
builder.SetInsertionPointToStart(&if_grad.false_block());
auto pop_local2_z =
builder.Build<pir::TuplePopOp>(outlet_1).outlet_element(0);
auto local2_matmul_grad_op =
Expand Down
25 changes: 12 additions & 13 deletions test/cpp/pir/core/program_translator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,21 @@ TEST(OperatorDialectTest, ConditionBlock) {
EXPECT_EQ(op.isa<paddle::dialect::IfOp>(), true);
EXPECT_EQ(op.num_regions(), 2u);
// true block
pir::Block *true_block =
pir::Block &true_block =
op.dyn_cast<paddle::dialect::IfOp>().true_block();
size_t true_id = 0;
for (auto &op1 : *true_block) {
for (auto &op1 : true_block) {
if (true_id == 0 || true_id == 1) {
EXPECT_EQ(op1.isa<paddle::dialect::FullOp>(), true);
}
if (true_id == 2) {
EXPECT_EQ(op1.isa<paddle::dialect::LessThanOp>(), true);
}
if (true_id == 3) {
pir::Block *true_true_block =
auto &true_true_block =
op1.dyn_cast<paddle::dialect::IfOp>().true_block();
size_t true_true_id = 0;
for (auto &op2 : *true_true_block) {
for (auto &op2 : true_true_block) {
if (true_true_id == 0) {
EXPECT_EQ(op2.isa<paddle::dialect::AddOp>(), true);
}
Expand All @@ -121,10 +121,10 @@ TEST(OperatorDialectTest, ConditionBlock) {
}
true_true_id++;
}
pir::Block *false_false_block =
auto &false_false_block =
op1.dyn_cast<paddle::dialect::IfOp>().false_block();
size_t false_false_id = 0;
for (auto &op2 : *false_false_block) {
for (auto &op2 : false_false_block) {
if (false_false_id == 0) {
EXPECT_EQ(op2.isa<paddle::dialect::MultiplyOp>(), true);
}
Expand All @@ -143,10 +143,9 @@ TEST(OperatorDialectTest, ConditionBlock) {
true_id++;
}
// false block
pir::Block *false_block =
op.dyn_cast<paddle::dialect::IfOp>().false_block();
auto &false_block = op.dyn_cast<paddle::dialect::IfOp>().false_block();
size_t false_id = 0;
for (auto &op1 : *false_block) {
for (auto &op1 : false_block) {
if (false_id == 0 || false_id == 1) {
EXPECT_EQ(op1.isa<paddle::dialect::FullOp>(), true);
}
Expand All @@ -156,10 +155,10 @@ TEST(OperatorDialectTest, ConditionBlock) {
if (false_id == 3) {
EXPECT_EQ(op1.isa<paddle::dialect::IfOp>(), true);
// true block
pir::Block *false_true_block =
auto &false_true_block =
op1.dyn_cast<paddle::dialect::IfOp>().true_block();
size_t false_true_id = 0;
for (auto &op2 : *false_true_block) {
for (auto &op2 : false_true_block) {
if (false_true_id == 0) {
EXPECT_EQ(op2.isa<paddle::dialect::AddOp>(), true);
}
Expand All @@ -169,10 +168,10 @@ TEST(OperatorDialectTest, ConditionBlock) {
false_true_id++;
}
// false block
pir::Block *false_false_block =
auto &false_false_block =
op1.dyn_cast<paddle::dialect::IfOp>().true_block();
size_t false_false_id = 0;
for (auto &op2 : *false_false_block) {
for (auto &op2 : false_false_block) {
if (false_false_id == 0) {
EXPECT_EQ(op2.isa<paddle::dialect::AddOp>(), true);
}
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,17 @@ TEST(kernel_dialect, cond_op_test) {
auto if_op = builder.Build<paddle::dialect::IfOp>(
full_op.out(), std::vector<pir::Type>{full_op.result(0).type()});

pir::Block* true_block = if_op.true_block();
auto& true_block = if_op.true_block();

builder.SetInsertionPointToStart(true_block);
builder.SetInsertionPointToStart(&true_block);

auto full_op_1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2}, true, phi::DataType::BOOL);
builder.Build<pir::YieldOp>(std::vector<pir::Value>{full_op_1.out()});

pir::Block* false_block = if_op.false_block();
auto& false_block = if_op.false_block();

builder.SetInsertionPointToStart(false_block);
builder.SetInsertionPointToStart(&false_block);

auto full_op_2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{3}, true, phi::DataType::BOOL);
Expand Down

0 comments on commit a4f32bc

Please sign in to comment.