From e276cf0831bfc71ef42eab78368aa487037ac27f Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Mon, 10 Jun 2024 10:20:24 -0700 Subject: [PATCH] [mlir][sparse] introduce `sparse_tensor.iterate` operation (#88955) A `sparse_tensor.iterate` iterates over a sparse iteration space extracted from `sparse_tensor.extract_iteration_space` operation introduced in https://github.com/llvm/llvm-project/pull/88554. --- .../Dialect/SparseTensor/IR/SparseTensor.h | 40 +++ .../SparseTensor/IR/SparseTensorAttrDefs.td | 15 + .../SparseTensor/IR/SparseTensorOps.td | 113 +++++++- .../SparseTensor/IR/SparseTensorDialect.cpp | 257 +++++++++++++++++- mlir/test/Dialect/SparseTensor/invalid.mlir | 63 ++++- mlir/test/Dialect/SparseTensor/roundtrip.mlir | 31 ++- .../SparseTensor/sparse_itertion_licm.mlir | 27 ++ 7 files changed, 538 insertions(+), 8 deletions(-) create mode 100644 mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index 3cf81d2e58f21c..04a6386a199de4 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -17,9 +17,13 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TensorEncoding.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/bit.h" + //===----------------------------------------------------------------------===// // // Type aliases to help code be more self-documenting. Unfortunately @@ -54,6 +58,42 @@ struct COOSegment { } }; +/// A simple wrapper to encode a bitset of (at most 64) levels, currently used +/// by `sparse_tensor.iterate` operation for the set of levels on which the +/// coordinates should be loaded. 
+class LevelSet { + uint64_t bits = 0; + +public: + LevelSet() = default; + explicit LevelSet(uint64_t bits) : bits(bits) {} + operator uint64_t() const { return bits; } + + LevelSet &set(unsigned i) { + assert(i < 64); + bits |= static_cast(0x01u) << i; + return *this; + } + + LevelSet &operator|=(LevelSet lhs) { + bits |= static_cast(lhs); + return *this; + } + + LevelSet &lshift(unsigned offset) { + bits = bits << offset; + return *this; + } + + bool operator[](unsigned i) const { + assert(i < 64); + return (bits & (1 << i)) != 0; + } + + unsigned count() const { return llvm::popcount(bits); } + bool empty() const { return bits == 0; } +}; + } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index 53dd8e39438cc6..69b212cce4ceba 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -19,6 +19,21 @@ class SparseTensor_Attr traits = []> : AttrDef; +//===----------------------------------------------------------------------===// +// A simple bitset attribute wrapped around a single int64_t to encode a set of +// sparse tensor levels. 
+//===----------------------------------------------------------------------===// + +def LevelSetAttr : + TypedAttrBase< + I64, "IntegerAttr", + And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>, + "LevelSet attribute"> { + let returnType = [{::mlir::sparse_tensor::LevelSet}]; + let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}]; +} + //===----------------------------------------------------------------------===// // These attributes are just like `IndexAttr` except that they clarify whether // the index refers to a dimension (an axis of the semantic tensor) or a level diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 4e4441c640ed95..5ae6f9f3443f8c 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" //===----------------------------------------------------------------------===// // Base class. 
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator, ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", - "ForeachOp"]>]> { + "ForeachOp", "IterateOp"]>]> { let summary = "Yield from sparse_tensor set-like operations"; let description = [{ Yields a value from within a `binary`, `unary`, `reduce`, @@ -1476,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space", the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the iteration space. - The type of returned the value is automatically inferred to + The type of returned the value is must be `!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`. The returned iteration space can then be iterated over by `sparse_tensor.iterate` operations to visit every stored element @@ -1487,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space", // Extracts a 1-D iteration space from a COO tensor at level 1. %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> + ->!sparse_tensor.iter_space<#COO, lvls = 1> ``` }]; @@ -1499,20 +1502,120 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space", return getHiLvl() - getLoLvl(); } ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() { - return getResultSpace().getType().getLvlTypes(); + return getExtractedSpace().getType().getLvlTypes(); } }]; let arguments = (ins AnySparseTensor:$tensor, Optional:$parentIter, LevelAttr:$loLvl, LevelAttr:$hiLvl); - let results = (outs AnySparseIterSpace:$resultSpace); + let results = (outs AnySparseIterSpace:$extractedSpace); let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom($loLvl, $hiLvl) " - " attr-dict `:` type($tensor) (`,` type($parentIter)^)?"; + " attr-dict `:` type($tensor) (`,` type($parentIter)^)? 
" + "`->` qualified(type($extractedSpace))"; let hasVerifier = 1; } +def IterateOp : SparseTensor_Op<"iterate", + [RecursiveMemoryEffects, RecursivelySpeculatable, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> { + + let summary = "Iterates over a sparse iteration space"; + let description = [{ + The `sparse_tensor.iterate` operation represents a loop (nest) over + the provided iteration space extracted from a specific sparse tensor. + The operation defines an SSA value for a sparse iterator that points + to the current stored element in the sparse tensor and SSA values + for coordinates of the stored element. The coordinates are always + converted to `index` type despite of the underlying sparse tensor + storage. When coordinates are not used, the SSA values can be skipped + by `_` symbols, which usually leads to simpler generated code after + sparsification. For example: + + ```mlir + // The coordinate for level 0 is not used when iterating over a 2-D + // iteration space. + %sparse_tensor.iterate %iterator in %space at(_, %crd_1) + : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> + ``` + + `sparse_tensor.iterate` can also operate on loop-carried variables. + It returns the final values after loop termination. + The initial values of the variables are passed as additional SSA operands + to the iterator SSA value and used coordinate SSA values mentioned + above. The operation region has an argument for the iterator, variadic + arguments for specified (used) coordiates and followed by one argument + for each loop-carried variable, representing the value of the variable + at the current iteration. + The body region must contain exactly one block that terminates with + `sparse_tensor.yield`. + + The results of an `sparse_tensor.iterate` hold the final values after + the last iteration. If the `sparse_tensor.iterate` defines any values, + a yield must be explicitly present. 
+ The number and types of the `sparse_tensor.iterate` results must match + the initial values in the iter_args binding and the yield operands. + + + A nested `sparse_tensor.iterate` example that prints all the coordinates + stored in the sparse input: + + ```mlir + func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) { + // Iterates over the first level of %sp + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 + : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1> + %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0) + : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> { + // Iterates over the second level of %sp + %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 + : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> + -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2> + %r2 = sparse_tensor.iterate %it2 in %l2 at (coord1) + : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> { + vector.print %coord0 : index + vector.print %coord1 : index + } + } + } + + ``` + }]; + + let arguments = (ins AnySparseIterSpace:$iterSpace, + Variadic:$initArgs, + LevelSetAttr:$crdUsedLvls); + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$region); + + let extraClassDeclaration = [{ + unsigned getSpaceDim() { + return getIterSpace().getType().getSpaceDim(); + } + BlockArgument getIterator() { + return getRegion().getArguments().front(); + } + Block::BlockArgListType getCrds() { + // The first block argument is iterator, the remaining arguments are + // referenced coordinates. + return getRegion().getArguments().slice(1, getCrdUsedLvls().count()); + } + unsigned getNumRegionIterArgs() { + return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count(); + } + }]; + + let hasVerifier = 1; + let hasRegionVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Sparse Tensor Debugging and Test-Only Operations. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 4adb1c19096a24..232d25d718c652 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2130,6 +2130,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, printLevelRange(p, lo, hi); } +static ParseResult +parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state, + SmallVectorImpl &iterators, + SmallVectorImpl &iterArgs) { + SmallVector spaces; + SmallVector initArgs; + + // Parse "%iters, ... in %spaces, ..." + if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || + parser.parseOperandList(spaces)) + return failure(); + + if (iterators.size() != spaces.size()) + return parser.emitError( + parser.getNameLoc(), + "mismatch in number of sparse iterators and sparse spaces"); + + // Parse "at(%crd0, _, ...)" + LevelSet crdUsedLvlSet; + bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at")); + unsigned lvlCrdCnt = 0; + if (hasUsedCrds) { + ParseResult crdList = parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + if (parser.parseOptionalKeyword("_")) { + if (parser.parseArgument(iterArgs.emplace_back())) + return failure(); + // Always use IndexType for the coordinate. + crdUsedLvlSet.set(lvlCrdCnt); + iterArgs.back().type = parser.getBuilder().getIndexType(); + } + lvlCrdCnt += 1; + return success(); + }); + if (failed(crdList)) { + return parser.emitError( + parser.getNameLoc(), + "expecting SSA value or \"_\" for level coordinates"); + } + } + // Set the CrdUsedLvl bitset. 
+ state.addAttribute("crdUsedLvls", + parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); + + // Parse "iter_args(%arg = %init, ...)" + bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); + if (hasIterArgs) + if (parser.parseAssignmentList(iterArgs, initArgs)) + return failure(); + + SmallVector iterSpaceTps; + // parse ": sparse_tensor.iter_space -> ret" + if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) + return failure(); + if (iterSpaceTps.size() != spaces.size()) + return parser.emitError(parser.getNameLoc(), + "mismatch in number of iteration space operands " + "and iteration space types"); + + for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { + IterSpaceType spaceTp = llvm::dyn_cast(tp); + if (!spaceTp) + return parser.emitError(parser.getNameLoc(), + "expected sparse_tensor.iter_space type for " + "iteration space operands"); + if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt) + return parser.emitError(parser.getNameLoc(), + "mismatch in number of iteration space dimension " + "and specified coordinates"); + it.type = spaceTp.getIteratorType(); + } + + if (hasIterArgs) + if (parser.parseArrowTypeList(state.types)) + return failure(); + + // Resolves input operands. + if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), + state.operands)) + return failure(); + + if (hasIterArgs) { + unsigned numCrds = crdUsedLvlSet.count(); + // Strip off leading args that used for coordinates. 
+ MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds); + if (args.size() != initArgs.size() || args.size() != state.types.size()) { + return parser.emitError( + parser.getNameLoc(), + "mismatch in number of iteration arguments and return values"); + } + + for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { + it.type = tp; + if (parser.resolveOperand(init, tp, state.operands)) + return failure(); + } + } + return success(); +} + LogicalResult ExtractIterSpaceOp::inferReturnTypes( MLIRContext *ctx, std::optional loc, ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, @@ -2153,7 +2253,7 @@ LogicalResult ExtractIterSpaceOp::verify() { } if (pIter) { - IterSpaceType spaceTp = getResultSpace().getType(); + IterSpaceType spaceTp = getExtractedSpace().getType(); if (pIter.getType().getEncoding() != spaceTp.getEncoding()) return emitOpError( "mismatch in parent iterator encoding and iteration space encoding."); @@ -2166,6 +2266,161 @@ LogicalResult ExtractIterSpaceOp::verify() { return success(); } +ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::Argument iterator; + OpAsmParser::UnresolvedOperand iterSpace; + + SmallVector iters, iterArgs; + if (parseSparseSpaceLoop(parser, result, iters, iterArgs)) + return failure(); + if (iters.size() != 1) + return parser.emitError(parser.getNameLoc(), + "expected only one iterator/iteration space"); + + iters.append(iterArgs); + Region *body = result.addRegion(); + if (parser.parseRegion(*body, iters)) + return failure(); + + IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); + + // Parse the optional attribute list. 
+ if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + return success(); +} + +/// Prints the initialization list in the form of +/// (%inner = %outer, %inner2 = %outer2, <...>) +/// where 'inner' values are assumed to be region arguments and 'outer' values +/// are regular SSA values. +static void printInitializationList(OpAsmPrinter &p, + Block::BlockArgListType blocksArgs, + ValueRange initializers, + StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + if (initializers.empty()) + return; + + p << prefix << '('; + llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { + p << std::get<0>(it) << " = " << std::get<1>(it); + }); + p << ")"; +} + +static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim, + Block::BlockArgListType blocksArgs, + LevelSet crdUsedLvls) { + if (crdUsedLvls.empty()) + return; + + p << " at("; + for (unsigned i = 0; i < spaceDim; i++) { + if (crdUsedLvls[i]) { + p << blocksArgs.front(); + blocksArgs = blocksArgs.drop_front(); + } else { + p << "_"; + } + if (i != spaceDim - 1) + p << ", "; + } + assert(blocksArgs.empty()); + p << ")"; +} + +void IterateOp::print(OpAsmPrinter &p) { + p << " " << getIterator() << " in " << getIterSpace(); + printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); + printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); + + p << " : " << getIterSpace().getType() << " "; + if (!getInitArgs().empty()) + p << "-> (" << getInitArgs().getTypes() << ") "; + + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/!getInitArgs().empty()); +} + +LogicalResult IterateOp::verify() { + if (getInitArgs().size() != getNumResults()) { + return emitOpError( + "mismatch in number of loop-carried values and defined values"); + } + return success(); +} + +LogicalResult IterateOp::verifyRegions() { + if (getIterator().getType() != 
getIterSpace().getType().getIteratorType()) + return emitOpError("mismatch in iterator and iteration space type"); + if (getNumRegionIterArgs() != getNumResults()) + return emitOpError( + "mismatch in number of basic block args and defined values"); + + auto initArgs = getInitArgs(); + auto iterArgs = getRegionIterArgs(); + auto yieldVals = getYieldedValues(); + auto opResults = getResults(); + if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), + opResults.size()})) { + return emitOpError() << "number mismatch between iter args and results."; + } + + for (auto [i, init, iter, yield, ret] : + llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { + if (init.getType() != ret.getType()) + return emitOpError() << "types mismatch between " << i + << "th iter operand and defined value"; + if (iter.getType() != ret.getType()) + return emitOpError() << "types mismatch between " << i + << "th iter region arg and defined value"; + if (yield.getType() != ret.getType()) + return emitOpError() << "types mismatch between " << i + << "th yield value and defined value"; + } + + return success(); +} + +/// OpInterfaces' methods implemented by IterateOp. 
+SmallVector IterateOp::getLoopRegions() { return {&getRegion()}; } + +MutableArrayRef IterateOp::getInitsMutable() { + return getInitArgsMutable(); +} + +Block::BlockArgListType IterateOp::getRegionIterArgs() { + return getRegion().getArguments().take_back(getNumRegionIterArgs()); +} + +std::optional> IterateOp::getYieldedValuesMutable() { + return cast( + getRegion().getBlocks().front().getTerminator()) + .getResultsMutable(); +} + +std::optional IterateOp::getLoopResults() { return getResults(); } + +OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { + return getInitArgs(); +} + +void IterateOp::getSuccessorRegions(RegionBranchPoint point, + SmallVectorImpl ®ions) { + // Both the operation itself and the region may be branching into the body or + // back into the operation itself. + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + // It is possible for loop not to enter the body. + regions.push_back(RegionSuccessor(getResults())); +} + +//===----------------------------------------------------------------------===// +// Sparse Tensor Dialect Setups. +//===----------------------------------------------------------------------===// + /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. 
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index 3fa696e1600a93..eb0dc01be25b93 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -1025,6 +1025,7 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) { func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) { // expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}} %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2> + -> !sparse_tensor.iter_space<#COO, lvls = 0 to 2> return } @@ -1040,6 +1041,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) { // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}} %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> + -> !sparse_tensor.iter_space<#COO, lvls = 1> return } @@ -1054,7 +1056,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) { // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}} - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 1> return } @@ -1077,6 +1079,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) { func.func 
@sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) { // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}} %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0> + -> !sparse_tensor.iter_space<#COO, lvls = 1> return } @@ -1092,5 +1095,63 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) { // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}} %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> + -> !sparse_tensor.iter_space<#COO, lvls = 2> return } + + +// ----- + +#COO = #sparse_tensor.encoding<{ + map = (i, j) -> ( + i : compressed(nonunique), + j : singleton(soa) + ) +}> + +func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index { + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> + // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}} + %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) { + sparse_tensor.yield %si : index + } + return %r1 : index +} + +// ----- + +#COO = #sparse_tensor.encoding<{ + map = (i, j) -> ( + i : compressed(nonunique), + j : singleton(soa) + ) +}> + +// expected-note@+1 {{prior use here}} +func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 { + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> 
!sparse_tensor.iter_space<#COO, lvls = 0> + // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}} + %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 { + sparse_tensor.yield %outer : f32 + } + return %r1 : f32 +} + +// ----- + +#COO = #sparse_tensor.encoding<{ + map = (i, j) -> ( + i : compressed(nonunique), + j : singleton(soa) + ) +}> + +func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index { + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> + // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}} + %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index { + %y = arith.constant 1.0 : f32 + sparse_tensor.yield %y : f32 + } + return %r1 : index +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index d34071279e5129..bce0b41a99828a 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -758,8 +758,37 @@ func.func @sparse_has_runtime() -> i1 { func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) { // Extracting the iteration space for the first level needs no parent iterator. - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> // Extracting the iteration space for the second level needs a parent iterator. 
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> + -> !sparse_tensor.iter_space<#COO, lvls = 1> return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1> } + + +// ----- + +#COO = #sparse_tensor.encoding<{ + map = (i, j) -> ( + i : compressed(nonunique), + j : singleton(soa) + ) +}> + +// CHECK-LABEL: func.func @sparse_iterate( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>, +// CHECK-SAME: %[[VAL_1:.*]]: index, +// CHECK-SAME: %[[VAL_2:.*]]: index) -> index { +// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}> +// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) { +// CHECK: sparse_tensor.yield %[[VAL_7]] : index +// CHECK: } +// CHECK: return %[[VAL_4]] : index +// CHECK: } +func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index { + %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0> + %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index { + sparse_tensor.yield %outer : index + } + return %r1 : index +} diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir new file mode 100644 index 00000000000000..f70fab3b7251df --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s + +#CSR = #sparse_tensor.encoding<{ + map = (i, j) -> ( + i : dense, + j : compressed + ) +}> + +// Make sure that pure instructions are hoisted outside the loop. 
+//
+// CHECK: sparse_tensor.values
+// CHECK: sparse_tensor.positions
+// CHECK: sparse_tensor.coordinate
+// CHECK: sparse_tensor.iterate
+func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
+      -> !sparse_tensor.iter_space<#CSR, lvls = 0>
+  sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
+    %0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    %2 = sparse_tensor.coordinates %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
+    "test.op"(%0, %1, %2) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
+  }
+
+  return
+}