Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ class SCFWhileLoop {
SCFWhileLoop(cir::WhileOp op, cir::WhileOp::Adaptor adaptor,
mlir::ConversionPatternRewriter *rewriter)
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
void transferToSCFWhileOp();
mlir::scf::WhileOp transferToSCFWhileOp();

private:
cir::WhileOp whileOp;
cir::WhileOp::Adaptor adaptor;
mlir::scf::WhileOp scfWhileOp;
mlir::ConversionPatternRewriter *rewriter;
};

Expand Down Expand Up @@ -337,7 +338,7 @@ void SCFLoop::transformToSCFWhileOp() {
scfWhileOp.getAfterBody()->end());
}

void SCFWhileLoop::transferToSCFWhileOp() {
mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp() {
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
rewriter->createBlock(&scfWhileOp.getBefore());
Expand All @@ -348,6 +349,7 @@ void SCFWhileLoop::transferToSCFWhileOp() {
rewriter->inlineBlockBefore(&whileOp.getBody().front(),
scfWhileOp.getAfterBody(),
scfWhileOp.getAfterBody()->end());
return scfWhileOp;
}

void SCFDoLoop::transferToSCFWhileOp() {
Expand Down Expand Up @@ -393,14 +395,53 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
};

class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
void rewriteContinue(mlir::scf::WhileOp whileOp,
mlir::ConversionPatternRewriter &rewriter) const {
// Collect all ContinueOp inside this while.
llvm::SmallVector<cir::ContinueOp> continues;
whileOp->walk([&](mlir::Operation *op) {
if (auto continueOp = dyn_cast<ContinueOp>(op))
continues.push_back(continueOp);
});

if (continues.empty())
return;

for (auto continueOp : continues) {
// When the break is under an IfOp, a direct replacement of `scf.yield`
// won't work: the yield would jump out of that IfOp instead. We might
// need to change the whileOp itself to achieve the same effect.
for (mlir::Operation *parent = continueOp->getParentOp();
parent != whileOp; parent = parent->getParentOp()) {
if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
llvm_unreachable("NYI");
}

// Operations after this break has to be removed.
for (mlir::Operation *runner = continueOp->getNextNode(); runner;) {
mlir::Operation *next = runner->getNextNode();
runner->erase();
runner = next;
}

// Blocks after this break also has to be removed.
for (mlir::Block *block = continueOp->getBlock()->getNextNode(); block;) {
mlir::Block *next = block->getNextNode();
block->erase();
block = next;
}
}
}

public:
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::WhileOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
SCFWhileLoop loop(op, adaptor, &rewriter);
loop.transferToSCFWhileOp();
auto whileOp = loop.transferToSCFWhileOp();
rewriteContinue(whileOp, rewriter);
rewriter.eraseOp(op);
return mlir::success();
}
Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

void for_with_break() {
int i = 0;
while (i < 100) {
i++;
continue;
i++;
}
// Only the first `i++` will be emitted.

// CHECK: scf.while : () -> () {
// CHECK: %[[TMP0:.+]] = memref.load %alloca[]
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
// CHECK: %[[TMP1:.+]] = arith.cmpi slt, %[[TMP0]], %[[HUNDRED]]
// CHECK: scf.condition(%[[TMP1]])
// CHECK: } do {
// CHECK: memref.alloca_scope {
// CHECK: %[[TMP2:.+]] = memref.load %alloca[]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]]
// CHECK: memref.store %[[TMP3]], %alloca[]
// CHECK: }
// CHECK: scf.yield
// CHECK: }
}
Loading