diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp index 148038877649..7976ebfe2e5a 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp @@ -63,11 +63,12 @@ class SCFWhileLoop { SCFWhileLoop(cir::WhileOp op, cir::WhileOp::Adaptor adaptor, mlir::ConversionPatternRewriter *rewriter) : whileOp(op), adaptor(adaptor), rewriter(rewriter) {} - void transferToSCFWhileOp(); + mlir::scf::WhileOp transferToSCFWhileOp(); private: cir::WhileOp whileOp; cir::WhileOp::Adaptor adaptor; + mlir::scf::WhileOp scfWhileOp; mlir::ConversionPatternRewriter *rewriter; }; @@ -337,7 +338,7 @@ void SCFLoop::transformToSCFWhileOp() { scfWhileOp.getAfterBody()->end()); } -void SCFWhileLoop::transferToSCFWhileOp() { +mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp() { auto scfWhileOp = rewriter->create( whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands()); rewriter->createBlock(&scfWhileOp.getBefore()); @@ -348,6 +349,7 @@ void SCFWhileLoop::transferToSCFWhileOp() { rewriter->inlineBlockBefore(&whileOp.getBody().front(), scfWhileOp.getAfterBody(), scfWhileOp.getAfterBody()->end()); + return scfWhileOp; } void SCFDoLoop::transferToSCFWhileOp() { @@ -393,6 +395,44 @@ class CIRForOpLowering : public mlir::OpConversionPattern { }; class CIRWhileOpLowering : public mlir::OpConversionPattern { + void rewriteContinue(mlir::scf::WhileOp whileOp, + mlir::ConversionPatternRewriter &rewriter) const { + // Collect all ContinueOp inside this while. + llvm::SmallVector continues; + whileOp->walk([&](mlir::Operation *op) { + if (auto continueOp = dyn_cast(op)) + continues.push_back(continueOp); + }); + + if (continues.empty()) + return; + + for (auto continueOp : continues) { + // When the break is under an IfOp, a direct replacement of `scf.yield` + // won't work: the yield would jump out of that IfOp instead. We might + // need to change the whileOp itself to achieve the same effect. + for (mlir::Operation *parent = continueOp->getParentOp(); + parent != whileOp; parent = parent->getParentOp()) { + if (isa(parent) || isa(parent)) + llvm_unreachable("NYI"); + } + + // Operations after this break has to be removed. + for (mlir::Operation *runner = continueOp->getNextNode(); runner;) { + mlir::Operation *next = runner->getNextNode(); + runner->erase(); + runner = next; + } + + // Blocks after this break also has to be removed. + for (mlir::Block *block = continueOp->getBlock()->getNextNode(); block;) { + mlir::Block *next = block->getNextNode(); + block->erase(); + block = next; + } + } + } + public: using OpConversionPattern::OpConversionPattern; @@ -400,7 +440,8 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern { matchAndRewrite(cir::WhileOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { SCFWhileLoop loop(op, adaptor, &rewriter); - loop.transferToSCFWhileOp(); + auto whileOp = loop.transferToSCFWhileOp(); + rewriteContinue(whileOp, rewriter); rewriter.eraseOp(op); return mlir::success(); } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp b/clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp new file mode 100644 index 000000000000..e016a2650c97 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +void for_with_break() { + int i = 0; + while (i < 100) { + i++; + continue; + i++; + } + // Only the first `i++` will be emitted. + + // CHECK: scf.while : () -> () { + // CHECK: %[[TMP0:.+]] = memref.load %alloca[] + // CHECK: %[[HUNDRED:.+]] = arith.constant 100 + // CHECK: %[[TMP1:.+]] = arith.cmpi slt, %[[TMP0]], %[[HUNDRED]] + // CHECK: scf.condition(%[[TMP1]]) + // CHECK: } do { + // CHECK: memref.alloca_scope { + // CHECK: %[[TMP2:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]] + // CHECK: memref.store %[[TMP3]], %alloca[] + // CHECK: } + // CHECK: scf.yield + // CHECK: } +}