Skip to content

Commit

Permalink
Use upstream dataflow tooling to build an arithmetic opt pass. (iree-…
Browse files Browse the repository at this point in the history
…org#18702)

This combines several things into one fixpoint iteration:

* Upstream IntRangeOptimizations for taking care of things like constant
replacement for unit ranges.
* Arith canonicalizations.
* Local adaptation of signed->unsigned conversion (upstream's version
can't compose since it is based on dialect conversion for some reason).
It also has 32bit bugs that have been corrected locally.
* Int64/unsigned index conversion.
* Common factor elision for integer division.
* Making the util.assume ops implement InferIntRangeInterface.

I have some additional advanced patterns to the side which simplify a
lot of torch cases, but they need some more baking/testing, so I'm just
landing the basic pass for now to start.

---------

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
  • Loading branch information
stellaraccident authored Oct 8, 2024
1 parent 0889d13 commit 0f28d44
Show file tree
Hide file tree
Showing 13 changed files with 762 additions and 1 deletion.
1 change: 1 addition & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, repo_map: Dict[str, str]):
],
# MLIR
"@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"],
"@llvm-project//mlir:ArithOpsIncGen": ["MLIRArithDialect"],
"@llvm-project//mlir:BufferizationInterfaces": [""],
"@llvm-project//mlir:CommonFolders": [""],
"@llvm-project//mlir:ConversionPasses": [""],
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_td_library(
"@llvm-project//mlir:CallInterfacesTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:FunctionInterfacesTdFiles",
"@llvm-project//mlir:InferIntRangeInterfaceTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
Expand Down
74 changes: 74 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
Expand All @@ -22,6 +23,8 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

#include <numeric>

namespace mlir::iree_compiler {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1102,6 +1105,77 @@ namespace mlir::iree_compiler::IREE::Util {
// util.assume.int
//===----------------------------------------------------------------------===//

SmallVector<IntAssumptionAttr>
AssumeIntOp::getOperandAssumptions(unsigned operandIndex) {
assert(operandIndex < getNumOperands() &&
"getUnionedUnsignedRange operand out of range");
auto assumptions = cast<ArrayAttr>(getAssumptions()[operandIndex]);
SmallVector<IntAssumptionAttr> results;
for (auto assumption : assumptions) {
results.push_back(cast<IntAssumptionAttr>(assumption));
}
return results;
}

std::pair<std::optional<uint64_t>, std::optional<uint64_t>>
AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) {
auto assumptions = getOperandAssumptions(operandIndex);
std::optional<uint64_t> uminUnion;
std::optional<uint64_t> umaxUnion;

for (auto assumption : assumptions) {
auto umin = assumption.getUmin();
auto umax = assumption.getUmax();
if (umin) {
uminUnion = std::min(
*umin, uminUnion ? *uminUnion : std::numeric_limits<uint64_t>::max());
}
if (umax) {
umaxUnion = std::max(
*umax, umaxUnion ? *umaxUnion : std::numeric_limits<uint64_t>::min());
}
}
return std::make_pair(uminUnion, umaxUnion);
}

// Gets the unioned divisor for an operand. If there are multiple divisor
// assumptions, the gcd of all of them is returned. If there are no
// divisor assumptions, std::nullopt is returned.
std::optional<uint64_t> AssumeIntOp::getUnionedDivisor(unsigned operandIndex) {
auto assumptions = getOperandAssumptions(operandIndex);
std::optional<uint64_t> divisorUnion;
for (auto assumption : assumptions) {
auto divisor = assumption.getDivisor();
if (divisor) {
if (divisorUnion)
divisorUnion = std::gcd(*divisor, *divisorUnion);
else
divisorUnion = *divisor;
}
}
return divisorUnion;
}

void AssumeIntOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
for (auto [index, result] : llvm::enumerate(getResults())) {
Type type = result.getType();
unsigned bitWidth;
if (isa<IndexType>(type))
bitWidth = 64;
else if (auto intType = dyn_cast<IntegerType>(type))
bitWidth = intType.getWidth();
else
continue;
auto [umin, umax] = getUnionedUnsignedRange(index);
if (umin && umax) {
APInt uminAp(bitWidth, *umin);
APInt umaxAp(bitWidth, *umax);
setResultRange(result, ConstantIntRanges::fromUnsigned(uminAp, umaxAp));
}
}
}

void AssumeIntOp::build(OpBuilder &builder, OperationState &state,
Value singleOperand,
IntAssumptionAttr singleAssumption) {
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
Expand Down
22 changes: 21 additions & 1 deletion compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
Expand Down Expand Up @@ -458,7 +459,9 @@ def OpGroupCompilerHintOps : OpDocGroup {

let opDocGroup = OpGroupCompilerHintOps in {

def Util_AssumeIntOp : Util_PureOp<"assume.int", []> {
def Util_AssumeIntOp : Util_PureOp<"assume.int", [
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "memorializes assumptions about index/integer values.";
let description = [{
This op is used to memorialize the result of some integer analysis or
Expand Down Expand Up @@ -490,6 +493,23 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", []> {
)>,
];

let extraClassDeclaration = [{
// Gets the list of assumptions for an operand.
SmallVector<IntAssumptionAttr> getOperandAssumptions(unsigned operandIndex);

// Gets the unioned unsigned range for an operand. If there are multiple
// assumptions for the operand, this will return the bounding range for
// them all. If there is no umin/umax, then std::nullopt will be returned
// for that position.
std::pair<std::optional<uint64_t>, std::optional<uint64_t>>
getUnionedUnsignedRange(unsigned operandIndex);

// Gets the unioned divisor for an operand. If there are multiple divisor
// assumptions, the gcd of all of them is returned. If there are no
// divisor assumptions, std::nullopt is returned.
std::optional<uint64_t> getUnionedDivisor(unsigned operandIndex);
}];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_compiler_cc_library(
"HoistIntoGlobals.cpp",
"IPO.cpp",
"ImportResources.cpp",
"OptimizeIntArithmetic.cpp",
"PassDetail.h",
"Passes.cpp",
"Patterns.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_cc_library(
"HoistIntoGlobals.cpp"
"IPO.cpp"
"ImportResources.cpp"
"OptimizeIntArithmetic.cpp"
"PassDetail.h"
"Passes.cpp"
"Patterns.cpp"
Expand Down
Loading

0 comments on commit 0f28d44

Please sign in to comment.