From c3b007854f87320bfd522e42e76f7f96945f861f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Fri, 15 Sep 2023 17:09:48 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90CINN=E3=80=91Unify=20for,ifthenelse=20?= =?UTF-8?q?expression=20(#57312)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * unify for,ifthenelse expression * delete logic about simplify block in ifthenelse * fix test case * delete comment --- paddle/cinn/backends/codegen_c.cc | 22 ++-------------------- paddle/cinn/backends/ir_schedule_test.cc | 4 ---- paddle/cinn/ir/ir.cc | 8 ++++++-- paddle/cinn/ir/utils/ir_printer.cc | 18 ++---------------- paddle/cinn/optim/ir_simplify.cc | 17 ----------------- 5 files changed, 10 insertions(+), 59 deletions(-) diff --git a/paddle/cinn/backends/codegen_c.cc b/paddle/cinn/backends/codegen_c.cc index 3352a458ceceb..2345bf53d36cd 100644 --- a/paddle/cinn/backends/codegen_c.cc +++ b/paddle/cinn/backends/codegen_c.cc @@ -285,31 +285,13 @@ void CodeGenC::Visit(const ir::Select *op) { void CodeGenC::Visit(const ir::IfThenElse *op) { str_ += "if ("; IrPrinter::Visit(op->condition); - str_ += ") {\n"; + str_ += ") "; - if (!op->true_case.As()) IncIndent(); - DoIndent(); IrPrinter::Visit(op->true_case); - if (!op->true_case.As()) str_ += ";"; - str_ += "\n"; - - if (!op->true_case.As()) DecIndent(); - - DoIndent(); - str_ += "}"; if (op->false_case.defined()) { - str_ += " else {\n"; - - if (!op->true_case.As()) IncIndent(); - DoIndent(); + str_ += " else "; IrPrinter::Visit(op->false_case); - if (!op->false_case.As()) str_ += ";"; - str_ += "\n"; - if (!op->true_case.As()) DecIndent(); - - DoIndent(); - str_ += "}"; } } void CodeGenC::Visit(const ir::Block *op) { diff --git a/paddle/cinn/backends/ir_schedule_test.cc b/paddle/cinn/backends/ir_schedule_test.cc index fa2b7b7299891..a8126993c5eae 100644 --- a/paddle/cinn/backends/ir_schedule_test.cc +++ b/paddle/cinn/backends/ir_schedule_test.cc @@ -794,10 +794,8 @@ void test_simple_compute_at(void* _args, int32_t num_args) for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) { for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) { if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) { - { B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)]; C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)]; - } }; }; }; @@ -869,10 +867,8 @@ void test_compute_at0(void* _args, int32_t num_args) for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) { for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) { if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) { - { B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)]; C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)]; - } }; }; }; diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index a1e92c75c290c..5427a14afa5ba 100644 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -257,7 +257,7 @@ Expr For::Make(Var loop_var, node->min = min; node->extent = extent; node->device_api = device_api; - node->body = body; + node->body = body.As() ? body : ir::Block::Make({body}); node->set_for_type(for_type); node->set_vectorize_info(vector_info); node->set_bind_info(bind_info); @@ -346,6 +346,10 @@ std::vector ScheduleBlockRealize::expr_fields() const { } Expr IfThenElse::Make(Expr condition, Expr true_case, Expr false_case) { + if (true_case.defined() && (!true_case.As())) + true_case = ir::Block::Make({true_case}); + if (false_case.defined() && (!false_case.As())) + false_case = ir::Block::Make({false_case}); auto node = make_shared(condition, true_case, false_case); return Expr(node); } @@ -513,7 +517,7 @@ Expr PolyFor::Make(Var iterator, n->condition = condition; n->inc = inc; n->device_api = device_api; - n->body = body; + n->body = body.As() ? body : ir::Block::Make({body}); n->set_for_type(for_type); n->set_vectorize_info(vectorize_info); n->set_bind_info(bind_info); diff --git a/paddle/cinn/ir/utils/ir_printer.cc b/paddle/cinn/ir/utils/ir_printer.cc index 985214ebdb88a..5b3eb6c20f1cb 100644 --- a/paddle/cinn/ir/utils/ir_printer.cc +++ b/paddle/cinn/ir/utils/ir_printer.cc @@ -229,26 +229,12 @@ void IrPrinter::Visit(const PolyFor *x) { void IrPrinter::Visit(const IfThenElse *x) { str_ += "if ("; Visit(x->condition); - str_ += ") {\n"; - IncIndent(); - DoIndent(); + str_ += ") "; Visit(x->true_case); - DecIndent(); - str_ += "\n"; - DoIndent(); - str_ += "}"; if (x->false_case.defined()) { - str_ += " else {\n"; - IncIndent(); - - DoIndent(); + str_ += " else "; Visit(x->false_case); - str_ += "\n"; - - DecIndent(); - DoIndent(); - str_ += "}"; } } void IrPrinter::Visit(const Block *x) { diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index 6cf3fcf4b7be8..7166f91ac63fd 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -306,23 +306,6 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> { } } - void Visit(const IfThenElse* op, Expr* expr) override { - auto* node = expr->As(); - Visit(&node->condition, &node->condition); - if (node->true_case.As() && - (node->true_case.As()->stmts.size() == 1)) { - node->true_case = node->true_case.As()->stmts[0]; - } - Visit(&node->true_case, &node->true_case); - if (node->false_case.defined()) { - if (node->false_case.As() && - (node->false_case.As()->stmts.size() == 1)) { - node->false_case = node->false_case.As()->stmts[0]; - } - Visit(&node->false_case, &node->false_case); - } - } - void Visit(const ScheduleBlock* op, Expr* expr) override { auto* node = expr->As(); CHECK(node);