From 488b3b62c7607dbb6805b0118b7dc7437d434757 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 26 Aug 2024 16:28:20 -0700 Subject: [PATCH 1/4] [mlir][scf] Allow unrolling loops with integer-typed IV. --- mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 2 ++ mlir/lib/Dialect/Arith/Utils/Utils.cpp | 10 ++++++ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 33 ++++++++++--------- mlir/test/Dialect/SCF/loop-unroll.mlir | 22 ++++++------- 4 files changed, 41 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index 76f5825025739b..2202a2d62ebd92 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -93,6 +93,8 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, int64_t value); Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APFloat &value); +Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type, + int64_t value); /// Returns the int type of the integer in ofr. /// Other attribute types are not supported. 
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index e75db84b75e280..5fb1cda3211828 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -302,6 +302,16 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, return builder.createOrFold(loc, type, splat); } +Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type, + int64_t value) { + assert(type.isIntOrIndex() && + "unexpected type other than integers and index"); + if (type.isIndex()) + return b.create(loc, value); + else + return b.create(loc, b.getIntegerAttr(type, value)); +} + Type mlir::getType(OpFoldResult ofr) { if (auto value = dyn_cast_if_present(ofr)) return value.getType(); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index ff5e3a002263d3..49bda65bad2df2 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -264,11 +264,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp, static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor) { assert(divisor > 0 && "expected positive divisor"); - assert(dividend.getType().isIndex() && "expected index-typed value"); + assert(dividend.getType().isIntOrIndex() && + "expected integer or index-typed value"); Value divisorMinusOneCst = - builder.create(loc, divisor - 1); - Value divisorCst = builder.create(loc, divisor); + createIntOrIndexConstant(builder, loc, dividend.getType(), divisor - 1); + Value divisorCst = + createIntOrIndexConstant(builder, loc, dividend.getType(), divisor); Value sum = builder.create(loc, dividend, divisorMinusOneCst); return builder.create(loc, sum, divisorCst); } @@ -279,9 +281,9 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, // where divis is rounding-to-zero division. 
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, Value divisor) { - assert(dividend.getType().isIndex() && "expected index-typed value"); - - Value cstOne = builder.create(loc, 1); + assert(dividend.getType().isIntOrIndex() && + "expected integer or index-typed value"); + Value cstOne = createIntOrIndexConstant(builder, loc, dividend.getType(), 1); Value divisorMinusOne = builder.create(loc, divisor, cstOne); Value sum = builder.create(loc, dividend, divisorMinusOne); return builder.create(loc, sum, divisor); @@ -388,16 +390,17 @@ LogicalResult mlir::loopUnrollByFactor( // Create constant for 'upperBoundUnrolled' and set epilogue loop flag. generateEpilogueLoop = upperBoundUnrolledCst < ubCst; if (generateEpilogueLoop) - upperBoundUnrolled = boundsBuilder.create( - loc, upperBoundUnrolledCst); + upperBoundUnrolled = createIntOrIndexConstant( + boundsBuilder, loc, forOp.getUpperBound().getType(), upperBoundUnrolledCst); else upperBoundUnrolled = forOp.getUpperBound(); // Create constant for 'stepUnrolled'. - stepUnrolled = stepCst == stepUnrolledCst - ? step - : boundsBuilder.create( - loc, stepUnrolledCst); + stepUnrolled = + stepCst == stepUnrolledCst + ? step + : createIntOrIndexConstant(boundsBuilder, loc, step.getType(), + stepUnrolledCst); } else { // Dynamic loop bounds computation. 
// TODO: Add dynamic asserts for negative lb/ub/step, or @@ -407,8 +410,8 @@ LogicalResult mlir::loopUnrollByFactor( Value diff = boundsBuilder.create(loc, upperBound, lowerBound); Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step); - Value unrollFactorCst = - boundsBuilder.create(loc, unrollFactor); + Value unrollFactorCst = createIntOrIndexConstant( + boundsBuilder, loc, tripCount.getType(), unrollFactor); Value tripCountRem = boundsBuilder.create(loc, tripCount, unrollFactorCst); // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor) @@ -455,7 +458,7 @@ LogicalResult mlir::loopUnrollByFactor( [&](unsigned i, Value iv, OpBuilder b) { // iv' = iv + step * i; auto stride = b.create( - loc, step, b.create(loc, i)); + loc, step, createIntOrIndexConstant(b, loc, iv.getType(), i)); return b.create(loc, iv, stride); }, annotateFn, iterArgs, yieldedValues); diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir index e28efbb6ec2b91..ae994a8e6c5d7b 100644 --- a/mlir/test/Dialect/SCF/loop-unroll.mlir +++ b/mlir/test/Dialect/SCF/loop-unroll.mlir @@ -311,10 +311,10 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref) { // Test that epilogue's arguments are correctly renamed. 
func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) { %0 = arith.constant 7.0 : f32 - %lb = arith.constant 0 : index - %ub = arith.constant 20 : index - %step = arith.constant 1 : index - %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) { + %lb = arith.constant 0 : i32 + %ub = arith.constant 20 : i32 + %step = arith.constant 1 : i32 + %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{ %add = arith.addf %arg0, %arg1 : f32 %mul = arith.mulf %arg0, %arg1 : f32 scf.yield %add, %mul : f32, f32 @@ -324,13 +324,13 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) { // UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments // // UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32 -// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index -// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index -// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index -// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index -// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index +// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32 +// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32 +// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32 +// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32 // UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]] -// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) { +// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 { // UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32 // UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32 // UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32 @@ -340,7 +340,7 @@ func.func 
@static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) { // UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32 // UNROLL-BY-3-NEXT: } // UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]] -// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) { +// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 { // UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32 // UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32 // UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32 From 92cf042759136889bd457855130d18cae6011d01 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 26 Aug 2024 17:50:51 -0700 Subject: [PATCH 2/4] Simplifying, formatting and adding doc. --- mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 3 +++ mlir/lib/Dialect/Arith/Utils/Utils.cpp | 5 +---- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index 2202a2d62ebd92..34abf215aceb62 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -93,6 +93,9 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, int64_t value); Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APFloat &value); + +/// Create a constant of type `type` at location `loc` whose value is `value`. +/// This works for integer type or the index type only. 
Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type, int64_t value); diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 5fb1cda3211828..371d325c40d2da 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -306,10 +306,7 @@ Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type, int64_t value) { assert(type.isIntOrIndex() && "unexpected type other than integers and index"); - if (type.isIndex()) - return b.create(loc, value); - else - return b.create(loc, b.getIntegerAttr(type, value)); + return b.create(loc, b.getIntegerAttr(type, value)); } Type mlir::getType(OpFoldResult ofr) { diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 49bda65bad2df2..7e1feced24aa88 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -391,7 +391,8 @@ LogicalResult mlir::loopUnrollByFactor( generateEpilogueLoop = upperBoundUnrolledCst < ubCst; if (generateEpilogueLoop) upperBoundUnrolled = createIntOrIndexConstant( - boundsBuilder, loc, forOp.getUpperBound().getType(), upperBoundUnrolledCst); + boundsBuilder, loc, forOp.getUpperBound().getType(), + upperBoundUnrolledCst); else upperBoundUnrolled = forOp.getUpperBound(); From e71cb8a7d2398da2f1b718c0778b0bf1fa41d38a Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Mon, 26 Aug 2024 21:00:29 -0700 Subject: [PATCH 3/4] Inline createIntOrIndexConstant --- mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 5 --- mlir/lib/Dialect/Arith/Utils/Utils.cpp | 7 ---- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 35 ++++++++++--------- 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index 34abf215aceb62..76f5825025739b 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ 
-94,11 +94,6 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APFloat &value); -/// Create a constant of type `type` at location `loc` whose value is `value`. -/// This works for integer type or the index type only. -Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type, - int64_t value); - /// Returns the int type of the integer in ofr. /// Other attribute types are not supported. Type getType(OpFoldResult ofr); diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 371d325c40d2da..e75db84b75e280 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -302,13 +302,6 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, return builder.createOrFold(loc, type, splat); } -Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type, - int64_t value) { - assert(type.isIntOrIndex() && - "unexpected type other than integers and index"); - return b.create(loc, b.getIntegerAttr(type, value)); -} - Type mlir::getType(OpFoldResult ofr) { if (auto value = dyn_cast_if_present(ofr)) return value.getType(); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 7e1feced24aa88..f52c0f3e620897 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -267,10 +267,10 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, assert(dividend.getType().isIntOrIndex() && "expected integer or index-typed value"); - Value divisorMinusOneCst = - createIntOrIndexConstant(builder, loc, dividend.getType(), divisor - 1); - Value divisorCst = - createIntOrIndexConstant(builder, loc, dividend.getType(), divisor); + Value divisorMinusOneCst = builder.create( + loc, builder.getIntegerAttr(dividend.getType(), divisor - 1)); + Value 
divisorCst = builder.create( + loc, builder.getIntegerAttr(dividend.getType(), divisor)); Value sum = builder.create(loc, dividend, divisorMinusOneCst); return builder.create(loc, sum, divisorCst); } @@ -283,7 +283,8 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, Value divisor) { assert(dividend.getType().isIntOrIndex() && "expected integer or index-typed value"); - Value cstOne = createIntOrIndexConstant(builder, loc, dividend.getType(), 1); + Value cstOne = builder.create( + loc, builder.getOneAttr(dividend.getType())); Value divisorMinusOne = builder.create(loc, divisor, cstOne); Value sum = builder.create(loc, dividend, divisorMinusOne); return builder.create(loc, sum, divisor); @@ -390,18 +391,18 @@ LogicalResult mlir::loopUnrollByFactor( // Create constant for 'upperBoundUnrolled' and set epilogue loop flag. generateEpilogueLoop = upperBoundUnrolledCst < ubCst; if (generateEpilogueLoop) - upperBoundUnrolled = createIntOrIndexConstant( - boundsBuilder, loc, forOp.getUpperBound().getType(), - upperBoundUnrolledCst); + upperBoundUnrolled = boundsBuilder.create( + loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(), + upperBoundUnrolledCst)); else upperBoundUnrolled = forOp.getUpperBound(); // Create constant for 'stepUnrolled'. - stepUnrolled = - stepCst == stepUnrolledCst - ? step - : createIntOrIndexConstant(boundsBuilder, loc, step.getType(), - stepUnrolledCst); + stepUnrolled = stepCst == stepUnrolledCst + ? step + : boundsBuilder.create( + loc, boundsBuilder.getIntegerAttr( + step.getType(), stepUnrolledCst)); } else { // Dynamic loop bounds computation. 
// TODO: Add dynamic asserts for negative lb/ub/step, or @@ -411,8 +412,8 @@ LogicalResult mlir::loopUnrollByFactor( Value diff = boundsBuilder.create(loc, upperBound, lowerBound); Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step); - Value unrollFactorCst = createIntOrIndexConstant( - boundsBuilder, loc, tripCount.getType(), unrollFactor); + Value unrollFactorCst = boundsBuilder.create( + loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor)); Value tripCountRem = boundsBuilder.create(loc, tripCount, unrollFactorCst); // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor) @@ -459,7 +460,9 @@ LogicalResult mlir::loopUnrollByFactor( [&](unsigned i, Value iv, OpBuilder b) { // iv' = iv + step * i; auto stride = b.create( - loc, step, createIntOrIndexConstant(b, loc, iv.getType(), i)); + loc, step, + b.create(loc, + b.getIntegerAttr(iv.getType(), i))); return b.create(loc, iv, stride); }, annotateFn, iterArgs, yieldedValues); From e42687b5e140b3f62cf2f9206a00f58d0b009cef Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Tue, 27 Aug 2024 16:18:34 -0700 Subject: [PATCH 4/4] improve test --- mlir/test/Dialect/SCF/loop-unroll.mlir | 63 +++++++++++++++++++++----- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir index ae994a8e6c5d7b..68a11fb6a72c64 100644 --- a/mlir/test/Dialect/SCF/loop-unroll.mlir +++ b/mlir/test/Dialect/SCF/loop-unroll.mlir @@ -311,10 +311,10 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref) { // Test that epilogue's arguments are correctly renamed. 
func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) { %0 = arith.constant 7.0 : f32 - %lb = arith.constant 0 : i32 - %ub = arith.constant 20 : i32 - %step = arith.constant 1 : i32 - %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{ + %lb = arith.constant 0 : index + %ub = arith.constant 20 : index + %step = arith.constant 1 : index + %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) { %add = arith.addf %arg0, %arg1 : f32 %mul = arith.mulf %arg0, %arg1 : f32 scf.yield %add, %mul : f32, f32 @@ -324,13 +324,13 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) { // UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments // // UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32 -// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32 -// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32 -// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32 -// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32 -// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32 +// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index +// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index +// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index +// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index +// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index // UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]] -// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 { +// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) { // UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32 // UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32 // UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32 @@ -340,7 +340,7 @@ func.func 
@static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) { // UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32 // UNROLL-BY-3-NEXT: } // UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]] -// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 { +// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) { // UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32 // UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32 // UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32 @@ -448,3 +448,44 @@ func.func @loop_unroll_yield_iter_arg() { // CHECK-NEXT: affine.yield %[[ITER_ARG]] : index // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +// Test the loop unroller works with integer IV type. +func.func @static_loop_unroll_with_integer_iv() -> (f32, f32) { + %0 = arith.constant 7.0 : f32 + %lb = arith.constant 0 : i32 + %ub = arith.constant 20 : i32 + %step = arith.constant 1 : i32 + %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{ + %add = arith.addf %arg0, %arg1 : f32 + %mul = arith.mulf %arg0, %arg1 : f32 + scf.yield %add, %mul : f32, f32 + } + return %result#0, %result#1 : f32, f32 +} +// UNROLL-BY-3-LABEL: func @static_loop_unroll_with_integer_iv +// +// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32 +// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32 +// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32 +// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32 +// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32 +// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]] +// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 { +// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = 
arith.addf %[[ARG0]], %[[ARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32 +// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32 +// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32 +// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32 +// UNROLL-BY-3-NEXT: } +// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]] +// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 { +// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32 +// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32 +// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32 +// UNROLL-BY-3-NEXT: } +// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32