Skip to content

Commit

Permalink
[PIR] adjust the definition of pd_op.while. (PaddlePaddle#58024)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored and Frida-a committed Oct 14, 2023
1 parent b2c36e5 commit f6cf4fe
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 37 deletions.
7 changes: 4 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ void IfOp::Verify() {}

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types) {
const std::vector<pir::Value> &inputs) {
argument.AddInputs(inputs);
argument.AddOutputs(output_types);
for (auto val : inputs) {
argument.AddOutput(val.type());
}
argument.AddRegions(2u);
}
pir::Block *WhileOp::cond_block() {
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ class WhileOp : public pir::Op<WhileOp> {

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types);
const std::vector<pir::Value> &inputs);
pir::Block *cond_block();
pir::Block *body_block();
void Print(pir::IrPrinter &printer); // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/dialect/control_flow/ir/cf_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

namespace pir {
void ControlFlowDialect::initialize() { RegisterOps<YieldOp, CondYieldOp>(); }
void ControlFlowDialect::initialize() { RegisterOps<YieldOp>(); }
} // namespace pir
IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect)
1 change: 0 additions & 1 deletion paddle/pir/dialect/control_flow/ir/cf_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,3 @@ void YieldOp::Build(Builder &builder,
} // namespace pir

IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::CondYieldOp)
25 changes: 0 additions & 25 deletions paddle/pir/dialect/control_flow/ir/cf_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,6 @@ class IR_API YieldOp : public Op<YieldOp> {
const std::vector<Value> &Value);
void Verify() {}
};

class IR_API CondYieldOp : public Op<CondYieldOp> {
public:
using Op::Op;
static const char *name() { return "cf.cond_yield"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;

template <class ValueContainer>
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value cond,
const ValueContainer &inputs);
void Verify() {}
};

template <class ValueContainer>
void CondYieldOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value cond,
const ValueContainer &inputs) {
argument.AddInput(cond);
argument.AddInputs(inputs);
}
} // namespace pir

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CondYieldOp);
7 changes: 2 additions & 5 deletions test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ TEST(while_op_test, base) {
builder.Build<FullOp>(std::vector<int64_t>{1}, 10, phi::DataType::INT32)
.out();

auto while_op = builder.Build<WhileOp>(
std::vector<pir::Value>{i, ten},
std::vector<pir::Type>{builder.int32_type(), builder.int32_type()});
auto while_op = builder.Build<WhileOp>(std::vector<pir::Value>{i, ten});

// while(i < ten)
pir::Block* cond_block = while_op.cond_block();
Expand All @@ -52,8 +50,7 @@ TEST(while_op_test, base) {
builder.SetInsertionPointToStart(cond_block);
auto cond_value =
builder.Build<LessThanOp>(cond_i_argument, cond_ten_argument).out();
builder.Build<pir::CondYieldOp>(
cond_value, std::vector<pir::Value>{cond_i_argument, cond_ten_argument});
builder.Build<pir::YieldOp>(std::vector<pir::Value>{cond_value});

// { i = i + 1}
pir::Block* body_block = while_op.body_block();
Expand Down

0 comments on commit f6cf4fe

Please sign in to comment.