diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index ff5e3a002263d3..f52c0f3e620897 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -264,11 +264,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp, static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor) { assert(divisor > 0 && "expected positive divisor"); - assert(dividend.getType().isIndex() && "expected index-typed value"); + assert(dividend.getType().isIntOrIndex() && + "expected integer or index-typed value"); - Value divisorMinusOneCst = - builder.create(loc, divisor - 1); - Value divisorCst = builder.create(loc, divisor); + Value divisorMinusOneCst = builder.create( + loc, builder.getIntegerAttr(dividend.getType(), divisor - 1)); + Value divisorCst = builder.create( + loc, builder.getIntegerAttr(dividend.getType(), divisor)); Value sum = builder.create(loc, dividend, divisorMinusOneCst); return builder.create(loc, sum, divisorCst); } @@ -279,9 +281,10 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, // where divis is rounding-to-zero division. static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, Value divisor) { - assert(dividend.getType().isIndex() && "expected index-typed value"); - - Value cstOne = builder.create(loc, 1); + assert(dividend.getType().isIntOrIndex() && + "expected integer or index-typed value"); + Value cstOne = builder.create( + loc, builder.getOneAttr(dividend.getType())); Value divisorMinusOne = builder.create(loc, divisor, cstOne); Value sum = builder.create(loc, dividend, divisorMinusOne); return builder.create(loc, sum, divisor); @@ -388,16 +391,18 @@ LogicalResult mlir::loopUnrollByFactor( // Create constant for 'upperBoundUnrolled' and set epilogue loop flag. generateEpilogueLoop = upperBoundUnrolledCst < ubCst; if (generateEpilogueLoop) - upperBoundUnrolled = boundsBuilder.create( - loc, upperBoundUnrolledCst); + upperBoundUnrolled = boundsBuilder.create( + loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(), + upperBoundUnrolledCst)); else upperBoundUnrolled = forOp.getUpperBound(); // Create constant for 'stepUnrolled'. stepUnrolled = stepCst == stepUnrolledCst ? step - : boundsBuilder.create( - loc, stepUnrolledCst); + : boundsBuilder.create( + loc, boundsBuilder.getIntegerAttr( + step.getType(), stepUnrolledCst)); } else { // Dynamic loop bounds computation. // TODO: Add dynamic asserts for negative lb/ub/step, or @@ -407,8 +412,8 @@ LogicalResult mlir::loopUnrollByFactor( Value diff = boundsBuilder.create(loc, upperBound, lowerBound); Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step); - Value unrollFactorCst = - boundsBuilder.create(loc, unrollFactor); + Value unrollFactorCst = boundsBuilder.create( + loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor)); Value tripCountRem = boundsBuilder.create(loc, tripCount, unrollFactorCst); // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor) @@ -455,7 +460,9 @@ LogicalResult mlir::loopUnrollByFactor( [&](unsigned i, Value iv, OpBuilder b) { // iv' = iv + step * i; auto stride = b.create( - loc, step, b.create(loc, i)); + loc, step, + b.create(loc, + b.getIntegerAttr(iv.getType(), i))); return b.create(loc, iv, stride); }, annotateFn, iterArgs, yieldedValues); diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir index e28efbb6ec2b91..68a11fb6a72c64 100644 --- a/mlir/test/Dialect/SCF/loop-unroll.mlir +++ b/mlir/test/Dialect/SCF/loop-unroll.mlir @@ -448,3 +448,44 @@ func.func @loop_unroll_yield_iter_arg() { // CHECK-NEXT: affine.yield %[[ITER_ARG]] : index // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +// Test the loop unroller works with integer IV type. +func.func @static_loop_unroll_with_integer_iv() -> (f32, f32) { + %0 = arith.constant 7.0 : f32 + %lb = arith.constant 0 : i32 + %ub = arith.constant 20 : i32 + %step = arith.constant 1 : i32 + %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{ + %add = arith.addf %arg0, %arg1 : f32 + %mul = arith.mulf %arg0, %arg1 : f32 + scf.yield %add, %mul : f32, f32 + } + return %result#0, %result#1 : f32, f32 +} +// UNROLL-BY-3-LABEL: func @static_loop_unroll_with_integer_iv +// +// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32 +// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32 +// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32 +// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32 +// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32 +// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]] +// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 { +// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32 +// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32 +// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32 +// UNROLL-BY-3-NEXT: } +// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]] +// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 { +// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32 +// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32 +// UNROLL-BY-3-NEXT: } +// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32