From f6cf4fe2c34314bd4fbbbd38dcf207a5f6e7e72f Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Thu, 12 Oct 2023 16:46:46 +0800 Subject: [PATCH] [PIR] adjust the definition of pd_op.while. (#58024) --- .../dialect/operator/ir/control_flow_op.cc | 7 +++--- .../pir/dialect/operator/ir/control_flow_op.h | 3 +-- .../pir/dialect/control_flow/ir/cf_dialect.cc | 2 +- paddle/pir/dialect/control_flow/ir/cf_ops.cc | 1 - paddle/pir/dialect/control_flow/ir/cf_ops.h | 25 ------------------- .../pir/control_flow_dialect/while_op_test.cc | 7 ++---- 6 files changed, 8 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 557f8c7106000..c47dd600bace4 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -113,10 +113,11 @@ void IfOp::Verify() {} void WhileOp::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, - const std::vector &output_types) { + const std::vector &inputs) { argument.AddInputs(inputs); - argument.AddOutputs(output_types); + for (auto val : inputs) { + argument.AddOutput(val.type()); + } argument.AddRegions(2u); } pir::Block *WhileOp::cond_block() { diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 99444f78da568..48571d7e501ef 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -53,8 +53,7 @@ class WhileOp : public pir::Op { static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, - const std::vector &output_types); + const std::vector &inputs); pir::Block *cond_block(); pir::Block *body_block(); void Print(pir::IrPrinter &printer); // NOLINT diff --git a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc index 7166af2ece636..ed36c0c81cca6 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc @@ -15,6 +15,6 @@ #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" namespace pir { -void ControlFlowDialect::initialize() { RegisterOps(); } +void ControlFlowDialect::initialize() { RegisterOps(); } } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect) diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.cc b/paddle/pir/dialect/control_flow/ir/cf_ops.cc index 69dce41e62bad..7981a6ab96396 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.cc @@ -24,4 +24,3 @@ void YieldOp::Build(Builder &builder, } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::CondYieldOp) diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.h b/paddle/pir/dialect/control_flow/ir/cf_ops.h index 898f954e09d5f..fe3e965fede8f 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.h +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.h @@ -30,31 +30,6 @@ class IR_API YieldOp : public Op { const std::vector &Value); void Verify() {} }; - -class IR_API CondYieldOp : public Op { - public: - using Op::Op; - static const char *name() { return "cf.cond_yield"; } - static constexpr uint32_t attributes_num = 0; - static constexpr const char **attributes_name = nullptr; - - template - static void Build(Builder &builder, // NOLINT - OperationArgument &argument, // NOLINT - Value cond, - const ValueContainer &inputs); - void Verify() {} -}; - -template -void CondYieldOp::Build(Builder &builder, // NOLINT - OperationArgument &argument, // NOLINT - Value cond, - const ValueContainer &inputs) { - argument.AddInput(cond); - argument.AddInputs(inputs); -} } // namespace pir IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CondYieldOp); diff --git a/test/cpp/pir/control_flow_dialect/while_op_test.cc b/test/cpp/pir/control_flow_dialect/while_op_test.cc index 609f1f8eb8d2e..9b3d9a5548b8a 100644 --- a/test/cpp/pir/control_flow_dialect/while_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/while_op_test.cc @@ -41,9 +41,7 @@ TEST(while_op_test, base) { builder.Build(std::vector{1}, 10, phi::DataType::INT32) .out(); - auto while_op = builder.Build( - std::vector{i, ten}, - std::vector{builder.int32_type(), builder.int32_type()}); + auto while_op = builder.Build(std::vector{i, ten}); // while(i < ten) pir::Block* cond_block = while_op.cond_block(); @@ -52,8 +50,7 @@ TEST(while_op_test, base) { builder.SetInsertionPointToStart(cond_block); auto cond_value = builder.Build(cond_i_argument, cond_ten_argument).out(); - builder.Build( - cond_value, std::vector{cond_i_argument, cond_ten_argument}); + builder.Build(std::vector{cond_value}); // { i = i + 1} pir::Block* body_block = while_op.body_block();