Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][scf] Allow unrolling loops with integer-typed IV. #106164

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
int64_t divisor) {
assert(divisor > 0 && "expected positive divisor");
assert(dividend.getType().isIndex() && "expected index-typed value");
assert(dividend.getType().isIntOrIndex() &&
"expected integer or index-typed value");

Value divisorMinusOneCst =
builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we also have a ConstantIntOp specialization that will make this slightly less verbose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that ConstantIntOp cannot take index type.

loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
Value divisorCst = builder.create<arith::ConstantOp>(
loc, builder.getIntegerAttr(dividend.getType(), divisor));
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}
Expand All @@ -279,9 +281,10 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
Value divisor) {
assert(dividend.getType().isIndex() && "expected index-typed value");

Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
assert(dividend.getType().isIntOrIndex() &&
"expected integer or index-typed value");
Value cstOne = builder.create<arith::ConstantOp>(
loc, builder.getOneAttr(dividend.getType()));
Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
return builder.create<arith::DivUIOp>(loc, sum, divisor);
Expand Down Expand Up @@ -388,16 +391,18 @@ LogicalResult mlir::loopUnrollByFactor(
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
if (generateEpilogueLoop)
upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
loc, upperBoundUnrolledCst);
upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
upperBoundUnrolledCst));
else
upperBoundUnrolled = forOp.getUpperBound();

// Create constant for 'stepUnrolled'.
stepUnrolled = stepCst == stepUnrolledCst
? step
: boundsBuilder.create<arith::ConstantIndexOp>(
loc, stepUnrolledCst);
: boundsBuilder.create<arith::ConstantOp>(
loc, boundsBuilder.getIntegerAttr(
step.getType(), stepUnrolledCst));
} else {
// Dynamic loop bounds computation.
// TODO: Add dynamic asserts for negative lb/ub/step, or
Expand All @@ -407,8 +412,8 @@ LogicalResult mlir::loopUnrollByFactor(
Value diff =
boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
Value unrollFactorCst =
boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
Value tripCountRem =
boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
// Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
Expand Down Expand Up @@ -455,7 +460,9 @@ LogicalResult mlir::loopUnrollByFactor(
[&](unsigned i, Value iv, OpBuilder b) {
// iv' = iv + step * i;
auto stride = b.create<arith::MulIOp>(
loc, step, b.create<arith::ConstantIndexOp>(loc, i));
loc, step,
b.create<arith::ConstantOp>(loc,
b.getIntegerAttr(iv.getType(), i)));
return b.create<arith::AddIOp>(loc, iv, stride);
},
annotateFn, iterArgs, yieldedValues);
Expand Down
22 changes: 11 additions & 11 deletions mlir/test/Dialect/SCF/loop-unroll.mlir
htyu marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,10 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
// Test that epilogue's arguments are correctly renamed.
func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
%0 = arith.constant 7.0 : f32
%lb = arith.constant 0 : index
%ub = arith.constant 20 : index
%step = arith.constant 1 : index
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) {
%lb = arith.constant 0 : i32
%ub = arith.constant 20 : i32
%step = arith.constant 1 : i32
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{
%add = arith.addf %arg0, %arg1 : f32
%mul = arith.mulf %arg0, %arg1 : f32
scf.yield %add, %mul : f32, f32
Expand All @@ -324,13 +324,13 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
// UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments
//
// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index
// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32
// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32
// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32
// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32
// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32
// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) {
// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 {
// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
Expand All @@ -340,7 +340,7 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) {
// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 {
// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32
Expand Down
Loading